diff --git a/deep_research/README.md b/deep_research/README.md new file mode 100644 index 00000000..c057d85b --- /dev/null +++ b/deep_research/README.md @@ -0,0 +1,604 @@ +# 🔍 ZenML Deep Research Agent + +A production-ready MLOps pipeline for conducting deep, comprehensive research on any topic using LLMs and web search capabilities. + +
+ Research Pipeline Visualization +

ZenML Deep Research pipeline flow

+
+ +## 🎯 Overview + +The ZenML Deep Research Agent is a scalable, modular pipeline that automates in-depth research on any topic. It: + +- Creates a structured outline based on your research query +- Researches each section through targeted web searches and LLM analysis +- Iteratively refines content through reflection cycles +- Produces a comprehensive, well-formatted research report +- Visualizes the research process and report structure in the ZenML dashboard + +This project transforms exploratory notebook-based research into a production-grade, reproducible, and transparent process using the ZenML MLOps framework. + +## 📝 Example Research Results + +The Deep Research Agent produces comprehensive, well-structured reports on any topic. Here's an example of research conducted on quantum computing: + +
+ Sample Research Report +

Sample report generated by the Deep Research Agent

+
+ +## 🚀 Pipeline Architecture + +The pipeline uses a parallel processing architecture for efficiency and breaks down the research process into granular steps for maximum modularity and control: + +1. **Initialize Prompts**: Load and track all prompts as versioned artifacts +2. **Query Decomposition**: Break down the main query into specific sub-questions +3. **Parallel Information Gathering**: Process multiple sub-questions concurrently for faster results +4. **Merge Results**: Combine results from parallel processing into a unified state +5. **Cross-Viewpoint Analysis**: Analyze discrepancies and agreements between different perspectives +6. **Reflection Generation**: Generate recommendations for improving research quality +7. **Human Approval** (optional): Get human approval for additional searches +8. **Execute Approved Searches**: Perform approved additional searches to fill gaps +9. **Final Report Generation**: Compile all synthesized information into a coherent HTML report +10. **Collect Tracing Metadata**: Gather comprehensive metrics about token usage, costs, and performance + +This architecture enables: +- Better reproducibility and caching of intermediate results +- Parallel processing for faster research completion +- Easier debugging and monitoring of specific research stages +- More flexible reconfiguration of individual components +- Enhanced transparency into how the research is conducted +- Human oversight and control over iterative research expansions + +## 💡 Under the Hood + +- **LLM Integration**: Uses litellm for flexible access to various LLM providers +- **Web Research**: Utilizes Tavily API for targeted internet searches +- **ZenML Orchestration**: Manages pipeline flow, artifacts, and caching +- **Reproducibility**: Track every step, parameter, and output via ZenML +- **Visualizations**: Interactive visualizations of the research structure and progress +- **Report Generation**: Uses static HTML templates for consistent, high-quality reports +- **Human-in-the-Loop**: Optional approval mechanism via ZenML alerters (Discord, Slack, etc.) +- **LLM Observability**: Integrated Langfuse tracking for monitoring LLM usage, costs, and performance + +## 🛠️ Getting Started + +### Prerequisites + +- Python 3.9+ +- ZenML installed and configured +- API key for your preferred LLM provider (configured with litellm) +- Tavily API key +- Langfuse account for LLM tracking (optional but recommended) + +### Installation + +```bash +# Clone the repository +git clone +cd zenml_deep_research + +# Install dependencies +pip install -r requirements.txt + +# Set up API keys +export OPENAI_API_KEY=your_openai_key # Or another LLM provider key +export TAVILY_API_KEY=your_tavily_key # For Tavily search (default) +export EXA_API_KEY=your_exa_key # For Exa search (optional) + +# Set up Langfuse for LLM tracking (optional) +export LANGFUSE_PUBLIC_KEY=your_public_key +export LANGFUSE_SECRET_KEY=your_secret_key +export LANGFUSE_HOST=https://cloud.langfuse.com # Or your self-hosted URL + +# Initialize ZenML (if needed) +zenml init +``` + +### Setting up Langfuse for LLM Tracking + +The pipeline integrates with [Langfuse](https://langfuse.com) for comprehensive LLM observability and tracking. This allows you to monitor LLM usage, costs, and performance across all pipeline runs. + +#### 1. Create a Langfuse Account + +1. Sign up at [cloud.langfuse.com](https://cloud.langfuse.com) or set up a self-hosted instance +2. Create a new project in your Langfuse dashboard (e.g., "deep-research") +3. 
Navigate to Settings → API Keys to get your credentials + +#### 2. Configure Environment Variables + +Set the following environment variables with your Langfuse credentials: + +```bash +export LANGFUSE_PUBLIC_KEY=pk-lf-... # Your public key +export LANGFUSE_SECRET_KEY=sk-lf-... # Your secret key +export LANGFUSE_HOST=https://cloud.langfuse.com # Or your self-hosted URL +``` + +#### 3. Configure Project Name + +The Langfuse project name can be configured in any of the pipeline configuration files: + +```yaml +# configs/enhanced_research.yaml +langfuse_project_name: "deep-research" # Change to match your Langfuse project +``` + +**Note**: The project must already exist in your Langfuse dashboard before running the pipeline. + +#### What Gets Tracked + +When Langfuse is configured, the pipeline automatically tracks: + +- **All LLM calls** with their prompts, responses, and token usage +- **Pipeline trace information** including: + - `trace_name`: The ZenML pipeline run name for easy identification + - `trace_id`: The unique ZenML pipeline run ID for correlation +- **Tagged operations** such as: + - `structured_llm_output`: JSON generation calls + - `information_synthesis`: Research synthesis operations + - `find_most_relevant_string`: Relevance matching operations +- **Performance metrics**: Latency, token counts, and costs +- **Project organization**: All traces are organized under your configured project + +This integration provides full observability into your research pipeline's LLM usage, making it easy to optimize performance, track costs, and debug issues. + +### Running the Pipeline + +#### Basic Usage + +```bash +# Run with default configuration +python run.py +``` + +The default configuration and research query are defined in `configs/enhanced_research.yaml`. 
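+For reference, the query is set at the top level of that file (excerpt below; the full configuration is shown in the Customization section later in this README):
+
+```yaml
+# configs/enhanced_research.yaml (excerpt)
+# Research query parameters
+query: "Climate change policy debates"  # replace with your own research topic
+```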
+ +#### Using Research Mode Presets + +The pipeline includes three pre-configured research modes for different use cases: + +```bash +# Rapid mode - Quick overview with minimal depth +python run.py --mode rapid + +# Balanced mode - Standard research depth (default) +python run.py --mode balanced + +# Deep mode - Comprehensive analysis with maximum depth +python run.py --mode deep +``` + +**Mode Comparison:** + +| Mode | Sub-Questions | Search Results* | Additional Searches | Best For | +|------|---------------|----------------|-------------------|----------| +| **Rapid** | 5 | 2 per search | 0 | Quick overviews, time-sensitive research | +| **Balanced** | 10 | 3 per search | 2 | Most research tasks, good depth/speed ratio | +| **Deep** | 15 | 5 per search | 4 | Comprehensive analysis, academic research | + +*Can be overridden with `--num-results` + +#### Using Different Configurations + +```bash +# Run with a custom configuration file +python run.py --config configs/custom_enhanced_config.yaml + +# Override the research query from command line +python run.py --query "My research topic" + +# Specify maximum number of sub-questions to process in parallel +python run.py --max-sub-questions 15 + +# Combine mode with other options +python run.py --mode deep --query "Complex topic" --require-approval + +# Combine multiple options +python run.py --config configs/custom_enhanced_config.yaml --query "My research topic" --max-sub-questions 12 +``` + +### Advanced Options + +```bash +# Enable debug logging +python run.py --debug + +# Disable caching for a fresh run +python run.py --no-cache + +# Specify a log file +python run.py --log-file research.log + +# Enable human-in-the-loop approval for additional research +python run.py --require-approval + +# Set approval timeout (in seconds) +python run.py --require-approval --approval-timeout 7200 + +# Use a different search provider (default: tavily) +python run.py --search-provider exa # Use Exa search +python run.py --search-provider both # Use both providers +python run.py --search-provider exa --search-mode neural # Exa with neural search + +# Control the number of search results per query +python run.py --num-results 5 # Get 5 results per search +python run.py --num-results 10 --search-provider exa # 10 results with Exa +``` + +### Search Providers + +The pipeline supports multiple search providers for flexibility and comparison: + +#### Available Providers + +1. **Tavily** (Default) + - Traditional keyword-based search + - Good for factual information and current events + - Requires `TAVILY_API_KEY` environment variable + +2. **Exa** + - Neural search engine with semantic understanding + - Better for conceptual and research-oriented queries + - Supports three search modes: + - `auto` (default): Automatically chooses between neural and keyword + - `neural`: Semantic search for conceptual understanding + - `keyword`: Traditional keyword matching + - Requires `EXA_API_KEY` environment variable + +3. 
**Both** + - Runs searches on both providers + - Useful for comprehensive research or comparing results + - Requires both API keys + +#### Usage Examples + +```bash +# Use Exa with neural search +python run.py --search-provider exa --search-mode neural + +# Compare results from both providers +python run.py --search-provider both + +# Use Exa with keyword search for exact matches +python run.py --search-provider exa --search-mode keyword + +# Combine with other options +python run.py --mode deep --search-provider exa --require-approval +``` + +### Human-in-the-Loop Approval + +The pipeline supports human approval for additional research queries identified during the reflection phase: + +```bash +# Enable approval with default 1-hour timeout +python run.py --require-approval + +# Custom timeout (2 hours) +python run.py --require-approval --approval-timeout 7200 + +# Approval works with any configuration +python run.py --config configs/thorough_research.yaml --require-approval +``` + +When enabled, the pipeline will: +1. Pause after the initial research phase +2. Send an approval request via your configured ZenML alerter (Discord, Slack, etc.) +3. Present research progress, identified gaps, and proposed additional queries +4. Wait for your approval before conducting additional searches +5. Continue with approved queries or finalize the report based on your decision + +**Note**: You need a ZenML stack with an alerter configured (e.g., Discord or Slack) for approval functionality to work. + +**Tip**: When using `--mode deep`, the pipeline will suggest enabling `--require-approval` for better control over the comprehensive research process. + +## 📊 Visualizing Research Process + +The pipeline includes built-in visualizations to help you understand and monitor the research process: + +### Viewing Visualizations + +After running the pipeline, you can view the visualizations in the ZenML dashboard: + +1. Start the ZenML dashboard: + ```bash + zenml up + ``` + +2. Navigate to the "Runs" tab in the dashboard +3. Select your pipeline run +4. 
Explore visualizations for each step: + - **initialize_prompts_step**: View all prompts used in the pipeline + - **initial_query_decomposition_step**: See how the query was broken down + - **process_sub_question_step**: Track progress for each sub-question + - **cross_viewpoint_analysis_step**: View viewpoint analysis results + - **generate_reflection_step**: See reflection and recommendations + - **get_research_approval_step**: View approval decisions + - **pydantic_final_report_step**: Access the final research state + - **collect_tracing_metadata_step**: View comprehensive cost and performance metrics + +### Visualization Features + +The visualizations provide: +- An overview of the report structure +- Details of each paragraph's research status +- Search history and source information +- Progress through reflection iterations +- Professionally formatted HTML reports with static templates + +### Sample Visualization + +Here's what the report structure visualization looks like: + +``` +Report Structure: +├── Introduction +│ └── Initial understanding of the topic +├── Historical Background +│ └── Evolution and key developments +├── Current State +│ └── Latest advancements and implementations +└── Conclusion + └── Summary and future implications +``` + +## 📁 Project Structure + +``` +zenml_deep_research/ +├── configs/ # Configuration files +│ ├── __init__.py +│ └── enhanced_research.yaml # Main configuration file +├── materializers/ # Custom materializers for artifact storage +│ ├── __init__.py +│ └── pydantic_materializer.py +├── pipelines/ # ZenML pipeline definitions +│ ├── __init__.py +│ └── parallel_research_pipeline.py +├── steps/ # ZenML pipeline steps +│ ├── __init__.py +│ ├── approval_step.py # Human approval step for additional research +│ ├── cross_viewpoint_step.py +│ ├── execute_approved_searches_step.py # Execute approved searches +│ ├── generate_reflection_step.py # Generate reflection without execution +│ ├── iterative_reflection_step.py # Legacy combined reflection step +│ ├── merge_results_step.py +│ ├── process_sub_question_step.py +│ ├── pydantic_final_report_step.py +│ └── query_decomposition_step.py +├── utils/ # Utility functions and helpers +│ ├── __init__.py +│ ├── approval_utils.py # Human approval utilities +│ ├── helper_functions.py +│ ├── llm_utils.py # LLM integration utilities +│ ├── prompts.py # Contains prompt templates and HTML templates +│ ├── pydantic_models.py # Data models using Pydantic +│ └── search_utils.py # Web search functionality +├── __init__.py +├── requirements.txt # Project dependencies +├── logging_config.py # Logging configuration +├── README.md # Project documentation +└── run.py # Main script to run the pipeline +``` + +## 🔧 Customization + +The project supports two levels of customization: + +### 1. Command-Line Parameters + +You can customize the research behavior directly through command-line parameters: + +```bash +# Specify your research query +python run.py --query "Your research topic" + +# Control parallelism with max-sub-questions +python run.py --max-sub-questions 15 + +# Combine multiple options +python run.py --query "Your research topic" --max-sub-questions 12 --no-cache +``` + +These settings control how the parallel pipeline processes your research query. + +### 2. 
Pipeline Configuration + +For more detailed settings, modify the configuration file: + +```yaml +# configs/enhanced_research.yaml + +# Enhanced Deep Research Pipeline Configuration +enable_cache: true + +# Research query parameters +query: "Climate change policy debates" + +# Step configurations +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: ["scientific", "political", "economic", "social", "ethical", "historical"] + + iterative_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_additional_searches: 2 + num_results_per_search: 3 + + # Human approval configuration (when using --require-approval) + get_research_approval_step: + parameters: + timeout: 3600 # 1 hour timeout for approval + max_queries: 2 # Maximum queries to present for approval + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 +``` + +To use a custom configuration file: + +```bash +python run.py --config configs/custom_research.yaml +``` + +### Available Configurations + +**Mode-Based Configurations** (automatically selected when using `--mode`): + +| Config File | Mode | Description | +|-------------|------|-------------| +| `rapid_research.yaml` | `--mode rapid` | Quick overview with minimal depth | +| `balanced_research.yaml` | `--mode balanced` | Standard research with moderate depth | +| `deep_research.yaml` | `--mode deep` | Comprehensive analysis with maximum depth | + +**Specialized Configurations:** + +| Config File | Description | Key Parameters | +|-------------|-------------|----------------| +| `enhanced_research.yaml` | Default research configuration | Standard settings, 2 additional searches | +| `thorough_research.yaml` | In-depth analysis | 12 sub-questions, 5 results per search | +| `quick_research.yaml` | Faster results | 5 sub-questions, 2 results per search | +| `daily_trends.yaml` | Research on recent topics | 24-hour search recency, disable cache | +| `compare_viewpoints.yaml` | Focus on comparing perspectives | Extended viewpoint categories | +| `parallel_research.yaml` | Optimized for parallel execution | Configured for distributed orchestrators | + +You can create additional configuration files by copying and modifying the base configuration files above. + +## 🎯 Prompts Tracking and Management + +The pipeline includes a sophisticated prompts tracking system that allows you to track all prompts as versioned artifacts in ZenML. This provides better observability, version control, and visualization of the prompts used in your research pipeline. + +### Overview + +The prompts tracking system enables: +- **Artifact Tracking**: All prompts are tracked as versioned artifacts in ZenML +- **Beautiful Visualizations**: HTML interface in the dashboard with search, copy, and expand features +- **Version Control**: Prompts are versioned alongside your code +- **Pipeline Integration**: Prompts are passed through the pipeline as artifacts, not hardcoded imports + +### Components + +1. 
**PromptsBundle Model** (`utils/prompt_models.py`) + - Pydantic model containing all prompts used in the pipeline + - Each prompt includes metadata: name, content, description, version, and tags + +2. **PromptsBundleMaterializer** (`materializers/prompts_materializer.py`) + - Custom materializer creating HTML visualizations in the ZenML dashboard + - Features: search, copy-to-clipboard, expandable content, tag categorization + +3. **Prompt Loader** (`utils/prompt_loader.py`) + - Utility to load prompts from `prompts.py` into a PromptsBundle + +### Integration Guide + +To integrate prompts tracking into a pipeline: + +1. **Initialize prompts as the first step:** + ```python + from steps.initialize_prompts_step import initialize_prompts_step + + @pipeline + def my_pipeline(): + prompts_bundle = initialize_prompts_step(pipeline_version="1.0.0") + ``` + +2. **Update steps to receive prompts_bundle:** + ```python + @step + def my_step(state: ResearchState, prompts_bundle: PromptsBundle): + prompt = prompts_bundle.get_prompt_content("synthesis_prompt") + # Use prompt in your step logic + ``` + +3. **Pass prompts_bundle through the pipeline:** + ```python + state = synthesis_step(state=state, prompts_bundle=prompts_bundle) + ``` + +### Benefits + +- **Full Tracking**: Every pipeline run tracks which exact prompts were used +- **Version History**: See how prompts evolved across different runs +- **Debugging**: Easily identify which prompts produced specific outputs +- **A/B Testing**: Compare results using different prompt versions + +### Visualization Features + +The HTML visualization in the ZenML dashboard includes: +- Pipeline version and creation timestamp +- Statistics (total prompts, tagged prompts, custom prompts) +- Search functionality across all prompt content +- Expandable/collapsible prompt content +- One-click copy to clipboard +- Tag-based categorization with visual indicators + +## 📊 Cost and Performance Tracking + +The pipeline includes comprehensive tracking of costs and performance metrics through the `collect_tracing_metadata_step`, which runs at the end of each pipeline execution. + +### Tracked Metrics + +- **LLM Costs**: Detailed breakdown by model and prompt type +- **Search Costs**: Tracking for both Tavily and Exa search providers +- **Token Usage**: Input/output tokens per model and step +- **Performance**: Latency and execution time metrics +- **Cost Attribution**: See which steps and prompts consume the most resources + +### Viewing Metrics + +After pipeline execution, the tracing metadata is available in the ZenML dashboard: + +1. Navigate to your pipeline run +2. Find the `collect_tracing_metadata_step` +3. 
View the comprehensive cost visualization including: + - Total pipeline cost (LLM + Search) + - Cost breakdown by model + - Token usage distribution + - Performance metrics + +This helps you: +- Optimize pipeline costs by identifying expensive operations +- Monitor token usage to stay within limits +- Track performance over time +- Make informed decisions about model selection + +## 📈 Example Use Cases + +- **Academic Research**: Rapidly generate preliminary research on academic topics +- **Business Intelligence**: Stay informed on industry trends and competitive landscape +- **Content Creation**: Develop well-researched content for articles, blogs, or reports +- **Decision Support**: Gather comprehensive information for informed decision-making + +## 🔄 Integration Possibilities + +This pipeline can integrate with: + +- **Document Storage**: Save reports to database or document management systems +- **Web Applications**: Power research functionality in web interfaces +- **Alerting Systems**: Schedule research on key topics and receive regular reports +- **Other ZenML Pipelines**: Chain with downstream analysis or processing + +## 📄 License + +This project is licensed under the Apache License 2.0. diff --git a/deep_research/__init__.py b/deep_research/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/deep_research/configs/balanced_research.yaml b/deep_research/configs/balanced_research.yaml new file mode 100644 index 00000000..4f8bfa23 --- /dev/null +++ b/deep_research/configs/balanced_research.yaml @@ -0,0 +1,79 @@ +# Deep Research Pipeline Configuration - Balanced Mode +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "balanced", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for balanced research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 10 # Balanced number of sub-questions + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + cap_search_length: 20000 # Standard cap for search length + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + ] # Standard viewpoints + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + timeout: 3600 # 1 hour timeout + max_queries: 2 # Moderate additional queries + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + cap_search_length: 20000 + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/compare_viewpoints.yaml b/deep_research/configs/compare_viewpoints.yaml new file mode 100644 index 00000000..26a59dd4 --- /dev/null +++ b/deep_research/configs/compare_viewpoints.yaml @@ -0,0 +1,43 @@ +# Deep Research Pipeline Configuration - Compare Viewpoints +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "viewpoints", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for comparing different viewpoints +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: ["scientific", "political", "economic", "social", "ethical", "historical"] + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/configs/daily_trends.yaml b/deep_research/configs/daily_trends.yaml new file mode 100644 index 00000000..0f2b6587 --- /dev/null +++ b/deep_research/configs/daily_trends.yaml @@ -0,0 +1,45 @@ +# Deep Research Pipeline Configuration - Daily Trends Research +enable_cache: false # Disable cache to always get fresh results for daily trends + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "daily_trends", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for daily trending topics +parameters: + query: "Latest developments in artificial intelligence" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/deep_research.yaml b/deep_research/configs/deep_research.yaml new file mode 100644 index 00000000..61cc4c2b --- /dev/null +++ b/deep_research/configs/deep_research.yaml @@ -0,0 +1,81 @@ +# Deep Research Pipeline Configuration - Deep Comprehensive Mode +enable_cache: false # Disable cache for fresh comprehensive analysis + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "deep", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for deep comprehensive research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 15 # Maximum sub-questions for comprehensive analysis + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + cap_search_length: 30000 # Higher cap for more comprehensive data + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + "technological", + "philosophical", + ] # Extended viewpoints for comprehensive analysis + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + timeout: 7200 # 2 hour timeout for deep research + max_queries: 4 # Maximum additional queries for deep mode + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + cap_search_length: 30000 + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/enhanced_research.yaml b/deep_research/configs/enhanced_research.yaml new file mode 100644 index 00000000..1933efa9 --- /dev/null +++ b/deep_research/configs/enhanced_research.yaml @@ -0,0 +1,71 @@ +# Enhanced Deep Research Pipeline Configuration +enable_cache: false + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "enhanced", + ] + use_cases: "Research on a given query." 
+ +# Research query parameters +query: "Climate change policy debates" + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Step configurations +steps: + initial_query_decomposition_step: + parameters: + llm_model: "openrouter/google/gemini-2.5-flash-preview-05-20" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "openrouter/google/gemini-2.5-flash-preview-05-20" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + ] + + generate_reflection_step: + parameters: + llm_model: "openrouter/google/gemini-2.5-flash-preview-05-20" + + get_research_approval_step: + parameters: + timeout: 3600 + max_queries: 2 + + execute_approved_searches_step: + parameters: + llm_model: "openrouter/google/gemini-2.5-flash-preview-05-20" + + pydantic_final_report_step: + parameters: + llm_model: "openrouter/google/gemini-2.5-flash-preview-05-20" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/configs/enhanced_research_with_approval.yaml b/deep_research/configs/enhanced_research_with_approval.yaml new file mode 100644 index 00000000..73d6fe42 --- /dev/null +++ b/deep_research/configs/enhanced_research_with_approval.yaml @@ -0,0 +1,77 @@ +# Enhanced Deep Research Pipeline Configuration with Human Approval +enable_cache: false + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "enhanced_approval", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research query parameters +query: "Climate change policy debates" + +# Pipeline parameters +parameters: + require_approval: true # Enable human-in-the-loop approval + approval_timeout: 1800 # 30 minutes timeout for approval + max_additional_searches: 3 # Allow up to 3 additional searches + +# Step configurations +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + ] + + # New reflection steps (replacing iterative_reflection_step) + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + alerter_type: "slack" # or "email" if configured + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/parallel_research.yaml b/deep_research/configs/parallel_research.yaml new file mode 100644 index 00000000..25ea6df0 --- /dev/null +++ b/deep_research/configs/parallel_research.yaml @@ -0,0 +1,93 @@ +# Deep Research Pipeline Configuration - Parallelized Version +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "parallel", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Pipeline parameters +parameters: + query: "How are people balancing MLOps and Agents/LLMOps?" 
+ max_sub_questions: 10 + +# Step parameters +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + cap_search_length: 20000 + + merge_sub_question_results_step: + parameters: + step_prefix: "process_question_" + output_name: "output" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: ["scientific", "political", "economic", "social", "ethical", "historical"] + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + timeout: 3600 + max_queries: 2 + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + cap_search_length: 20000 + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 + + # Uncomment and customize these settings when running with orchestrators that support parallelization + # orchestrator.kubeflow: + # synchronous: false + # resources: + # cpu_request: "1" + # memory_request: "2Gi" + # cpu_limit: "2" + # memory_limit: "4Gi" + + # orchestrator.kubernetes: + # synchronous: false + # resources: + # process_sub_question_step: + # cpu_request: "1" + # memory_request: "2Gi" \ No newline at end of file diff --git a/deep_research/configs/pipeline_config.yaml b/deep_research/configs/pipeline_config.yaml new file mode 100644 index 00000000..84e2520b --- /dev/null +++ b/deep_research/configs/pipeline_config.yaml @@ -0,0 +1,58 @@ +# Deep Research Pipeline Configuration +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "pipeline", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters +parameters: + query: "Default research query" # The research query/topic to investigate + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + auto_approve: false + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/quick_research.yaml b/deep_research/configs/quick_research.yaml new file mode 100644 index 00000000..b210f18f --- /dev/null +++ b/deep_research/configs/quick_research.yaml @@ -0,0 +1,59 @@ +# Deep Research Pipeline Configuration - Quick Research +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "quick", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for quick research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 5 # Limit to fewer sub-questions for quick research + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + auto_approve: true # Auto-approve for quick research + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/configs/rapid_research.yaml b/deep_research/configs/rapid_research.yaml new file mode 100644 index 00000000..e69982bf --- /dev/null +++ b/deep_research/configs/rapid_research.yaml @@ -0,0 +1,59 @@ +# Deep Research Pipeline Configuration - Quick Research +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "rapid", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for quick research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 5 # Limit to fewer sub-questions for quick research + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + auto_approve: true # Auto-approve for quick research + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/configs/thorough_research.yaml b/deep_research/configs/thorough_research.yaml new file mode 100644 index 00000000..a798577e --- /dev/null +++ b/deep_research/configs/thorough_research.yaml @@ -0,0 +1,63 @@ +# Deep Research Pipeline Configuration - Thorough Research +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "thorough", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for more thorough research +parameters: + query: "Quantum computing applications in cryptography" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 12 # More sub-questions for thorough analysis + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + auto_approve: false + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/design/budget_test_pipeline.py b/deep_research/design/budget_test_pipeline.py new file mode 100644 index 00000000..9a5026bb --- /dev/null +++ b/deep_research/design/budget_test_pipeline.py @@ -0,0 +1,102 @@ +import logging +import os + +import openai +import requests +from openai import BadRequestError, RateLimitError +from zenml import get_step_context, pipeline, step + +logger = logging.getLogger(__name__) + +LITELLM_API_KEY = os.getenv("LITELLM_BUDGET_TEST_API_KEY") +LITELLM_BASE_URL = os.getenv("LITELLM_BUDGET_TEST_BASE_URL") + +client = openai.OpenAI( + api_key=LITELLM_API_KEY, + base_url=LITELLM_BASE_URL, +) + + +@step +def set_budget(amount: float) -> None: + """Set the budget for the research project.""" + context = get_step_context() + run_name = context.pipeline_run.name 
+ breakpoint() + + # create a user with the name set to run_name + response = requests.post( + f"{LITELLM_BASE_URL}/user/new", + headers={ + "Authorization": f"Bearer {LITELLM_API_KEY}", + "Content-Type": "application/json", + }, + json={"user_id": run_name}, + ) + logger.info(response.json()) + + logger.info(f"Setting budget to ${amount:.2f} for run {run_name}") + + # Update the user with budget settings + budget_response = requests.post( + f"{LITELLM_BASE_URL}/user/update", + headers={ + "Authorization": f"Bearer {LITELLM_API_KEY}", + "Content-Type": "application/json", + }, + json={ + "user_id": run_name, + "max_budget": amount, + "budget_duration": "1d", # Optional: resets daily - can be "1s", "1m", "1h", "1d", "1mo" + }, + ) + logger.info(f"Budget update response: {budget_response.json()}") + return + + +@step +def llm_functionality() -> str: + """Test the LLM functionality by generating a short joke.""" + context = get_step_context() + run_name = context.pipeline_run.name + + prompt = "Tell me a short joke about programming." + + try: + response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + user=run_name, + max_tokens=50, + ) + joke = response.choices[0].message.content + logger.info(joke) + return joke + + except BadRequestError as e: + if "budget" not in str(e).lower() and "exceeded" not in str(e).lower(): + raise # Re-raise if it's a different 400 error + + logger.error(f"Budget exceeded for user {run_name}: {e}") + raise ValueError(f"Budget exceeded for user {run_name}: {str(e)}") + except RateLimitError as e: + # HTTP 429 - Could be rate limits or provider budget exceeded + if "budget" in str(e).lower(): + logger.error(f"Provider budget exceeded: {e}") + raise ValueError(f"Provider budget exceeded: {str(e)}") + else: + logger.error(f"Rate limit hit: {e}") + raise ValueError(f"Rate limit exceeded: {str(e)}") + + +@pipeline(enable_cache=False) +def budget_test_pipeline() -> str: + """Test pipeline to set budget and check LLM functionality.""" + set_budget(amount=0.01) + llm_functionality(after="set_budget") + + +if __name__ == "__main__": + # Configure logging + logging.basicConfig(level=logging.INFO) + budget_test_pipeline() diff --git a/deep_research/design/exa-cheat-sheet.md b/deep_research/design/exa-cheat-sheet.md new file mode 100644 index 00000000..7f5f4df8 --- /dev/null +++ b/deep_research/design/exa-cheat-sheet.md @@ -0,0 +1,449 @@ +# Exa Python SDK Implementation Cheat Sheet + +## Overview + +Exa is a neural search engine designed for AI applications. Unlike traditional keyword-based search, Exa uses embeddings to understand semantic meaning, making it ideal for sophisticated search queries. + +## Installation & Setup + +```python +pip install exa-py +``` + +```python +from exa_py import Exa +import os + +# Initialize client +exa = Exa(os.getenv("EXA_API_KEY")) +``` + +## Core Methods Comparison + +### Search Methods + +| **Method** | **Purpose** | **Returns** | +|------------|-------------|-------------| +| `search()` | Basic search, returns links only | List of Result objects with URLs, titles, scores | +| `search_and_contents()` | Search + content retrieval in one call | Results with full text/highlights | +| `get_contents()` | Get content for specific document IDs | Content for provided IDs | +| `find_similar()` | Find pages similar to a URL | Similar pages | +| `find_similar_and_contents()` | Find similar + get content | Similar pages with content | + +### Key Differences from Tavily + +1. 
**Neural vs Keyword Search**: Exa defaults to neural search but supports keyword search via `type="keyword"` +2. **Content Integration**: Exa can retrieve full content in the same API call +3. **Semantic Similarity**: `find_similar()` functionality for finding related content +4. **Structured Response**: More predictable response structure + +## Basic Search Implementation + +### Simple Search (Links Only) + +```python +def exa_search_basic(query: str, num_results: int = 3) -> Dict[str, Any]: + try: + response = exa.search( + query=query, + num_results=num_results, + type="auto" # "auto", "neural", or "keyword" + ) + + return { + "query": query, + "results": [ + { + "url": result.url, + "title": result.title, + "score": result.score, + "published_date": result.published_date, + "author": result.author, + "id": result.id # Temporary document ID + } + for result in response.results + ] + } + except Exception as e: + return {"query": query, "results": [], "error": str(e)} +``` + +### Search with Content (Primary Method) + +```python +def exa_search_with_content( + query: str, + num_results: int = 3, + max_characters: int = 20000, + include_highlights: bool = False +) -> Dict[str, Any]: + try: + # Configure content options + text_options = {"max_characters": max_characters} + + kwargs = { + "query": query, + "num_results": num_results, + "text": text_options, + "type": "auto" + } + + # Add highlights if requested + if include_highlights: + kwargs["highlights"] = { + "highlights_per_url": 2, + "num_sentences": 3 + } + + response = exa.search_and_contents(**kwargs) + + return { + "query": query, + "results": [ + { + "url": result.url, + "title": result.title, + "content": result.text, # Full text content + "highlights": getattr(result, 'highlights', []), + "score": result.score, + "published_date": result.published_date, + "author": result.author, + "id": result.id + } + for result in response.results + ] + } + except Exception as e: + return {"query": query, "results": [], "error": str(e)} +``` + +## Advanced Search Parameters + +### Date Filtering + +```python +response = exa.search_and_contents( + query="AI research", + start_published_date="2024-01-01", + end_published_date="2024-12-31", + start_crawl_date="2024-01-01", # When Exa crawled the content + end_crawl_date="2024-12-31" +) +``` + +### Domain Filtering + +```python +response = exa.search_and_contents( + query="machine learning", + include_domains=["arxiv.org", "scholar.google.com"], + exclude_domains=["reddit.com", "twitter.com"] +) +``` + +### Search Types + +```python +# Neural search (default, semantic understanding) +neural_results = exa.search(query, type="neural") + +# Keyword search (Google-style) +keyword_results = exa.search(query, type="keyword") + +# Auto (Exa chooses best approach) +auto_results = exa.search(query, type="auto") +``` + +## Content Options Deep Dive + +### Text Content Options + +```python +text_options = { + "max_characters": 5000, # Limit content length + "include_html_tags": True # Keep HTML formatting +} + +response = exa.search_and_contents( + query="AI developments", + text=text_options +) +``` + +### Highlights Options + +```python +highlights_options = { + "highlights_per_url": 3, # Number of highlights per result + "num_sentences": 2, # Sentences per highlight + "query": "custom highlight query" # Override search query for highlights +} + +response = exa.search_and_contents( + query="machine learning", + highlights=highlights_options +) +``` + +### Combined Content Retrieval + +```python +# Get 
both full text and highlights +response = exa.search_and_contents( + query="quantum computing", + text={"max_characters": 10000}, + highlights={ + "highlights_per_url": 2, + "num_sentences": 1 + } +) + +# Results will have both .text and .highlights attributes +for result in response.results: + print(f"Title: {result.title}") + print(f"Full text: {result.text[:200]}...") + print(f"Highlights: {result.highlights}") +``` + +## Response Structure Analysis + +### Basic Result Object + +```python +class Result: + url: str # The webpage URL + id: str # Temporary document ID for get_contents() + title: Optional[str] # Page title + score: Optional[float] # Relevance score (0-1) + published_date: Optional[str] # Estimated publication date + author: Optional[str] # Content author if available +``` + +### Result with Content + +```python +class ResultWithText(Result): + text: str # Full page content (when text=True) + +class ResultWithHighlights(Result): + highlights: List[str] # Key excerpts + highlight_scores: List[float] # Relevance scores for highlights + +class ResultWithTextAndHighlights(Result): + text: str + highlights: List[str] + highlight_scores: List[float] +``` + +## Error Handling & Retry Logic + +```python +def exa_search_with_retry( + query: str, + max_retries: int = 2, + **kwargs +) -> Dict[str, Any]: + """Search with retry logic similar to your Tavily implementation""" + + # Alternative query strategies + query_variants = [ + query, + f'"{query}"', # Exact phrase + f"research about {query}", + f"article on {query}" + ] + + for attempt in range(max_retries + 1): + try: + current_query = query_variants[min(attempt, len(query_variants) - 1)] + + response = exa.search_and_contents( + query=current_query, + **kwargs + ) + + # Check if we got meaningful content + content_results = sum( + 1 for r in response.results + if hasattr(r, 'text') and r.text.strip() + ) + + if content_results > 0: + return { + "query": current_query, + "results": [ + { + "url": r.url, + "content": getattr(r, 'text', ''), + "title": r.title or '', + "snippet": ' '.join(getattr(r, 'highlights', [])[:1]), + "score": r.score, + "published_date": r.published_date, + "author": r.author + } + for r in response.results + ] + } + + except Exception as e: + if attempt == max_retries: + return {"query": query, "results": [], "error": str(e)} + continue + + return {"query": query, "results": []} +``` + +## Mapping to Your SearchResult Model + +Based on your Pydantic `SearchResult` model, here's the mapping: + +```python +def convert_exa_to_search_result(exa_result) -> SearchResult: + """Convert Exa result to your SearchResult format""" + + # Get the best available content + content = "" + if hasattr(exa_result, 'text') and exa_result.text: + content = exa_result.text + elif hasattr(exa_result, 'highlights') and exa_result.highlights: + content = f"Title: {exa_result.title}\n\nHighlights:\n" + "\n".join(exa_result.highlights) + elif exa_result.title: + content = f"Title: {exa_result.title}" + + # Create snippet from highlights or title + snippet = "" + if hasattr(exa_result, 'highlights') and exa_result.highlights: + snippet = exa_result.highlights[0] + elif exa_result.title: + snippet = exa_result.title + + return SearchResult( + url=exa_result.url, + content=content, + title=exa_result.title or "", + snippet=snippet + ) +``` + +## Additional Features + +### Find Similar Content + +```python +def find_similar_content(url: str, num_results: int = 3): + """Find content similar to a given URL""" + response = 
exa.find_similar_and_contents( + url=url, + num_results=num_results, + text=True, + exclude_source_domain=True # Don't include same domain + ) + return response.results +``` + +### Get Content by IDs + +```python +def get_content_by_ids(document_ids: List[str]): + """Retrieve full content for specific document IDs""" + response = exa.get_contents( + ids=document_ids, + text={"max_characters": 10000} + ) + return response.results +``` + +### Answer API (Tavily Alternative) + +```python +def exa_answer_query(query: str, include_full_text: bool = False): + """Get direct answer to question (similar to Tavily's answer feature)""" + response = exa.answer( + query=query, + text=include_full_text # Include full text of citations + ) + + return { + "answer": response.answer, + "citations": [ + { + "url": citation.url, + "title": citation.title, + "text": getattr(citation, 'text', '') if include_full_text else '', + "published_date": citation.published_date + } + for citation in response.citations + ] + } +``` + +## Configuration Toggle Implementation + +```python +class SearchConfig: + TAVILY = "tavily" + EXA = "exa" + +def unified_search( + query: str, + provider: str = SearchConfig.EXA, + **kwargs +) -> List[SearchResult]: + """Unified search interface supporting both providers""" + + if provider == SearchConfig.EXA: + exa_results = exa_search_with_content(query, **kwargs) + return [ + convert_exa_to_search_result(result) + for result in exa_results.get("results", []) + ] + elif provider == SearchConfig.TAVILY: + tavily_results = tavily_search(query, **kwargs) + return extract_search_results(tavily_results) + else: + raise ValueError(f"Unknown provider: {provider}") +``` + +## Performance & Cost Considerations + +### Pricing Structure +- **Neural Search**: $0.005 for 1-25 results, $0.025 for 26-100 results +- **Keyword Search**: $0.0025 for 1-100 results +- **Content**: $0.001 per page for text/highlights/summary +- **Live Crawling**: Automatic fallback for uncached content + +### Optimization Tips + +1. **Use `search_and_contents()`** instead of separate `search()` + `get_contents()` calls +2. **Set appropriate `max_characters`** to control content length and costs +3. **Choose search type wisely**: neural for semantic queries, keyword for exact matches +4. **Use domain filtering** to focus on high-quality sources +5. **Implement caching** for repeated queries + +## Environment Variables + +```bash +EXA_API_KEY=your_exa_api_key_here +``` + +## Testing Queries + +```python +# Test neural search capabilities +test_queries = [ + "fascinating article about machine learning", + "comprehensive guide to Python optimization", + "latest research in quantum computing", + "how to implement search functionality" +] + +for query in test_queries: + results = exa_search_with_content(query, num_results=2) + print(f"Query: {query}") + print(f"Results: {len(results['results'])}") + print("---") +``` + +This cheat sheet provides everything you need to implement Exa as a configurable +alternative to Tavily, maintaining compatibility with your existing +`SearchResult` model while leveraging Exa's advanced neural search capabilities. diff --git a/deep_research/design/exa_cost_tracking_fixes.md b/deep_research/design/exa_cost_tracking_fixes.md new file mode 100644 index 00000000..03fbfee9 --- /dev/null +++ b/deep_research/design/exa_cost_tracking_fixes.md @@ -0,0 +1,38 @@ +# Exa Cost Tracking - Formatting Fixes + +## Issues Fixed + +### 1. 
JavaScript Syntax in F-strings +The main issue was with JavaScript code inside Python f-strings. The curly braces `{}` in JavaScript objects conflicted with f-string syntax. + +**Solution:** +- Moved the f-string variable substitution outside the JavaScript code +- Created a JavaScript variable `totalCombinedCost` to hold the Python value +- This avoided having f-string expressions inside JavaScript function bodies + +### 2. Code Formatting +The formatter (ruff) also made several automatic improvements: +- Fixed line wrapping for long lines +- Adjusted import statement formatting +- Fixed whitespace consistency + +## Files Modified by Formatter + +1. `utils/search_utils.py` - Import formatting +2. `utils/pydantic_models.py` - Line wrapping +3. `steps/process_sub_question_step.py` - Line wrapping +4. `steps/execute_approved_searches_step.py` - Line wrapping +5. `steps/iterative_reflection_step.py` - Import formatting +6. `steps/collect_tracing_metadata_step.py` - Line wrapping +7. `steps/merge_results_step.py` - Line wrapping +8. `materializers/tracing_metadata_materializer.py` - JavaScript syntax fix and formatting + +## Testing + +All tests pass successfully: +- Python syntax validation: ✅ +- Exa API cost extraction: ✅ +- ResearchState cost tracking: ✅ +- Cost aggregation: ✅ + +The implementation is now fully functional and properly formatted. \ No newline at end of file diff --git a/deep_research/design/exa_cost_tracking_summary.md b/deep_research/design/exa_cost_tracking_summary.md new file mode 100644 index 00000000..c2971255 --- /dev/null +++ b/deep_research/design/exa_cost_tracking_summary.md @@ -0,0 +1,94 @@ +# Exa Cost Tracking Implementation Summary + +## Overview +This implementation adds comprehensive cost tracking for Exa search queries throughout the ZenML Deep Research pipeline. The costs are tracked at every step where searches are performed and aggregated for final visualization alongside LLM costs. + +## Key Changes + +### 1. Core Infrastructure +- **`utils/search_utils.py`**: Modified `exa_search()` to extract cost from `response.cost_dollars.total` +- **`utils/search_utils.py`**: Updated `extract_search_results()` and `search_and_extract_results()` to return tuple `(results, cost)` +- **`utils/pydantic_models.py`**: Added search cost tracking fields to `ResearchState`: + - `search_costs: Dict[str, float]` - Total costs by provider + - `search_cost_details: List[Dict[str, Any]]` - Detailed cost logs + +### 2. Pipeline Steps +Updated all steps that perform searches to handle the new cost tracking: + +- **`process_sub_question_step.py`**: Tracks costs for sub-question searches +- **`execute_approved_searches_step.py`**: Tracks costs for reflection-based searches +- **`iterative_reflection_step.py`**: Tracks costs for gap-filling searches +- **`merge_results_step.py`**: Aggregates costs from parallel sub-states + +Each step: +1. Unpacks the tuple from `search_and_extract_results()` +2. Updates `state.search_costs["exa"]` with cumulative cost +3. Appends detailed cost information to `state.search_cost_details` + +### 3. Metadata Collection +- **`utils/pydantic_models.py`**: Added search cost fields to `TracingMetadata` +- **`collect_tracing_metadata_step.py`**: Extracts search costs from final state and includes them in tracing metadata + +### 4. 
Visualization +- **`materializers/tracing_metadata_materializer.py`**: Enhanced to display: + - Individual search provider costs with query counts + - Combined cost summary (LLM + Search) + - Interactive doughnut chart showing cost breakdown + - Percentage calculations for cost distribution + +## Usage + +When running a pipeline with Exa as the search provider: + +```python +# In pipeline configuration +search_provider="exa" # or "both" to use Exa alongside Tavily +``` + +The pipeline will automatically: +1. Track costs for each Exa search query +2. Aggregate costs across all steps +3. Display total costs in the final visualization + +## Cost Information Captured + +For each search, the system captures: +- Provider name (e.g., "exa") +- Search query text +- Cost in dollars +- Timestamp +- Pipeline step name +- Purpose (e.g., "sub_question", "reflection_enhancement", "gap_filling") +- Related sub-question (if applicable) + +## Example Output + +In the final HTML visualization: +``` +Search Provider Costs +EXA Search: $0.0280 +10 queries • $0.0028/query + +Combined Cost Summary +LLM Cost: $0.1234 (81.5% of total) +Search Cost: $0.0280 (18.5% of total) +Total Pipeline Cost: $0.1514 +``` + +## Testing + +Run the test script to verify the implementation: +```bash +python design/test_exa_cost_tracking.py +``` + +All tests should pass, confirming: +- Exa API cost extraction works correctly +- ResearchState properly tracks costs +- Cost aggregation across steps functions properly + +## Notes + +- Tavily doesn't provide cost information in their API, so only Exa costs are tracked +- Costs are tracked even if searches fail (cost is still incurred) +- The implementation is backward compatible - pipelines without Exa will simply show no search costs \ No newline at end of file diff --git a/deep_research/design/hitl.md b/deep_research/design/hitl.md new file mode 100644 index 00000000..20359804 --- /dev/null +++ b/deep_research/design/hitl.md @@ -0,0 +1,475 @@ +# Human Approval Step - Technical Specification + +## Overview + +Add a human-in-the-loop approval mechanism to the Deep Research Pipeline that allows users to review and approve/reject additional research recommendations before they are executed. + +## Purpose + +1. **Cost Control**: Prevent runaway token usage and API calls +2. **Quality Control**: Allow subject matter experts to guide research direction +3. **Transparency**: Show stakeholders what additional research is being considered +4. **Flexibility**: Enable selective approval of specific research queries + +## Pipeline Integration Point + +Since ZenML requires a static DAG, we need to: +1. Split the reflection step into two separate steps +2. Insert an approval step between them +3. 
Pass the approval decision to the second reflection step + +### Current Pipeline Flow: +```python +# In parallel_research_pipeline.py +analyzed_state = cross_viewpoint_analysis_step(state=merged_state) +reflected_state = iterative_reflection_step(state=analyzed_state) +``` + +### New Pipeline Flow with Approval: +```python +# In parallel_research_pipeline.py +analyzed_state = cross_viewpoint_analysis_step(state=merged_state) + +# Step 1: Generate reflection and recommendations (no searches yet) +reflection_output = generate_reflection_step(state=analyzed_state) + +# Step 2: Get approval for recommended searches +approval_decision = get_research_approval_step( + state=reflection_output.state, + proposed_queries=reflection_output.recommended_queries, + critique_points=reflection_output.critique_summary +) + +# Step 3: Execute approved searches (if any) +reflected_state = execute_approved_searches_step( + state=reflection_output.state, + approval_decision=approval_decision, + original_reflection=reflection_output +) +``` + +## Implementation Components + +### 1. New Data Models + +Add to `utils/pydantic_models.py`: +```python +from typing import List, Dict, Any +from pydantic import BaseModel, Field + +class ReflectionOutput(BaseModel): + """Output from the reflection generation step.""" + state: ResearchState + recommended_queries: List[str] = Field(default_factory=list) + critique_summary: List[Dict[str, Any]] = Field(default_factory=list) + additional_questions: List[str] = Field(default_factory=list) + +class ApprovalDecision(BaseModel): + """Approval decision from human reviewer.""" + approved: bool = False + selected_queries: List[str] = Field(default_factory=list) + approval_method: str = "" # "APPROVE_ALL", "SKIP", "SELECT_SPECIFIC" + reviewer_notes: str = "" + timestamp: float = Field(default_factory=lambda: time.time()) +``` + +### 2. Split Reflection into Two Steps + +Create `steps/generate_reflection_step.py`: +```python +from typing import Annotated +from zenml import step +from utils.pydantic_models import ResearchState, ReflectionOutput + +@step +def generate_reflection_step( + state: ResearchState, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + reflection_prompt: str = REFLECTION_PROMPT, +) -> Annotated[ReflectionOutput, "reflection_output"]: + """ + Generate reflection and recommendations WITHOUT executing searches. + + This step only analyzes the current state and produces recommendations. + """ + logger.info("Generating reflection on research") + + # Existing reflection logic (from iterative_reflection_step) + reflection_input = prepare_reflection_input(state) + + reflection_result = get_structured_llm_output( + prompt=json.dumps(reflection_input), + system_prompt=reflection_prompt, + model=llm_model, + fallback_response={"critique": [], "additional_questions": [], "recommended_search_queries": []} + ) + + # Return structured output for next steps + return ReflectionOutput( + state=state, + recommended_queries=reflection_result.get("recommended_search_queries", []), + critique_summary=reflection_result.get("critique", []), + additional_questions=reflection_result.get("additional_questions", []) + ) +``` + +### 3. 
Approval Step

Create `steps/approval_step.py`:
```python
from typing import Annotated

from zenml import step
from zenml.client import Client
from zenml.logger import get_logger

from utils.pydantic_models import ResearchState, ReflectionOutput, ApprovalDecision

logger = get_logger(__name__)


@step(enable_cache=False)  # Never cache approval decisions
def get_research_approval_step(
    reflection_output: ReflectionOutput,
    require_approval: bool = True,
    alerter_type: str = "slack",
    timeout: int = 3600
) -> Annotated[ApprovalDecision, "approval_decision"]:
    """
    Get human approval for additional research queries.

    Always returns an ApprovalDecision object. If require_approval is False,
    automatically approves all queries.
    """

    # If approval not required, auto-approve all
    if not require_approval:
        return ApprovalDecision(
            approved=True,
            selected_queries=reflection_output.recommended_queries,
            approval_method="AUTO_APPROVED",
            reviewer_notes="Approval not required by configuration"
        )

    # If no queries to approve, skip
    if not reflection_output.recommended_queries:
        return ApprovalDecision(
            approved=False,
            selected_queries=[],
            approval_method="NO_QUERIES",
            reviewer_notes="No additional queries recommended"
        )

    # Prepare approval request
    message = format_approval_request(
        main_query=reflection_output.state.main_query,
        progress_summary=summarize_research_progress(reflection_output.state),
        critique_points=reflection_output.critique_summary,
        proposed_queries=reflection_output.recommended_queries
    )

    try:
        # Get the active stack's alerter and send the request
        client = Client()
        response = client.active_stack.alerter.ask(
            question=message,
            params={"timeout": timeout}
        )

        # Parse response
        return parse_approval_response(response, reflection_output.recommended_queries)

    except Exception as e:
        logger.error(f"Approval request failed: {e}")
        # On error, default to not approved
        return ApprovalDecision(
            approved=False,
            selected_queries=[],
            approval_method="ERROR",
            reviewer_notes=f"Approval failed: {str(e)}"
        )
```

The `format_approval_request`, `summarize_research_progress`, and `parse_approval_response` helpers are intended to live in the new `utils/approval_utils.py` module; a sketch of `parse_approval_response` is included at the end of the code listing in section 5 below.

### 4. 
Execute Approved Searches Step

Create `steps/execute_approved_searches_step.py`. First, for reference, the `format_approval_request` helper used by the approval step above (intended for `utils/approval_utils.py`, which the unit tests later import from):

```python
from typing import Any, Dict, List


def format_approval_request(
    main_query: str,
    progress_summary: Dict[str, Any],
    critique_points: List[Dict[str, Any]],
    proposed_queries: List[str],
    timeout: int = 3600,
) -> str:
    """Format the approval request message."""

    # High-priority critiques
    high_priority = [c for c in critique_points if c.get("importance") == "high"]

    message = f"""
📊 **Research Progress Update**

**Main Query:** {main_query}

**Current Status:**
- Sub-questions analyzed: {progress_summary['completed_count']}
- Average confidence: {progress_summary['avg_confidence']}
- Low confidence areas: {progress_summary['low_confidence_count']}

**Key Issues Identified:**
{format_critique_summary(high_priority)}

**Proposed Additional Research** ({len(proposed_queries)} queries):
{format_query_list(proposed_queries)}

**Estimated Additional Time:** ~{len(proposed_queries) * 2} minutes
**Estimated Additional Cost:** ~${calculate_estimated_cost(proposed_queries)}

**Response Options:**
- Reply `APPROVE ALL` to proceed with all queries
- Reply `SKIP` to finish with current findings
- Reply `SELECT 1,3,5` to approve specific queries by number

**Timeout:** Response required within {timeout//60} minutes
"""
    return message
```

The step itself:

```python
from typing import Annotated
from zenml import step
from materializers.pydantic_materializer import ResearchStateMaterializer
from utils.pydantic_models import (
    ResearchState, ReflectionOutput, ApprovalDecision,
    ReflectionMetadata, SynthesizedInfo
)

@step(output_materializers=ResearchStateMaterializer)
def execute_approved_searches_step(
    reflection_output: ReflectionOutput,
    approval_decision: ApprovalDecision,
    num_results_per_search: int = 3,
    cap_search_length: int = 20000,
    llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B",
    additional_synthesis_prompt: str = ADDITIONAL_SYNTHESIS_PROMPT,
) -> Annotated[ResearchState, "updated_state"]:
    """
    Execute approved searches and enhance the research state.

    This step receives the approval decision and only executes
    searches that were approved.
    """
    logger.info(f"Processing approval decision: {approval_decision.approval_method}")

    state = reflection_output.state
    enhanced_info = create_enhanced_info_copy(state.synthesized_info)

    # Check if we should execute searches
    if not approval_decision.approved or not approval_decision.selected_queries:
        logger.info("No additional searches approved")

        # Create metadata indicating no additional research
        reflection_metadata = ReflectionMetadata(
            critique_summary=[c.get("issue", "") for c in reflection_output.critique_summary],
            additional_questions_identified=reflection_output.additional_questions,
            searches_performed=[],
            improvements_made=0,
            user_decision=approval_decision.approval_method,
            reviewer_notes=approval_decision.reviewer_notes
        )

        state.update_after_reflection(enhanced_info, reflection_metadata)
        return state

    # Execute approved searches
    logger.info(f"Executing {len(approval_decision.selected_queries)} approved searches")

    for query in approval_decision.selected_queries:
        logger.info(f"Performing approved search: {query}")

        # Execute search (existing logic from iterative_reflection_step)
        search_results = search_and_extract_results(
            query=query,
            max_results=num_results_per_search,
            cap_content_length=cap_search_length,
        )

        # Find relevant sub-question and enhance
        # ... 
(rest of enhancement logic from original iterative_reflection_step) + + # Create final metadata with approval info + reflection_metadata = ReflectionMetadata( + critique_summary=[c.get("issue", "") for c in reflection_output.critique_summary], + additional_questions_identified=reflection_output.additional_questions, + searches_performed=approval_decision.selected_queries, + improvements_made=count_improvements(enhanced_info), + user_decision=approval_decision.approval_method, + reviewer_notes=approval_decision.reviewer_notes + ) + + state.update_after_reflection(enhanced_info, reflection_metadata) + return state +``` + +### 5. Updated Pipeline Definition + +Update `pipelines/parallel_research_pipeline.py`: +```python +@step(output_materializers=ResearchStateMaterializer) +def iterative_reflection_step( + state: ResearchState, + max_additional_searches: int = 2, + require_approval: bool = False, # NEW + approval_timeout: int = 3600, # NEW + alerter_type: str = "slack", # NEW + # ... other params +) -> Annotated[ResearchState, "updated_state"]: + """Perform iterative reflection with optional human approval.""" + + # ... existing reflection logic ... + + # Get recommended queries + search_queries = reflection_result.get("recommended_search_queries", []) + + # NEW: Approval gate + if require_approval and search_queries: + approved, selected_queries = get_research_approval_step( + state=state, + proposed_queries=search_queries[:max_additional_searches], + reflection_critique=reflection_result.get("critique", []), + alerter_type=alerter_type, + timeout=approval_timeout + ) + + if not approved: + logger.info("Additional research not approved by user") + # Create metadata indicating skipped research + reflection_metadata = ReflectionMetadata( + critique_summary=[item.get("issue", "") for item in reflection_result.get("critique", [])], + additional_questions_identified=reflection_result.get("additional_questions", []), + searches_performed=[], + improvements_made=0, + user_decision="SKIPPED_ADDITIONAL_RESEARCH" + ) + state.update_after_reflection(state.synthesized_info, reflection_metadata) + return state + + # Use only approved queries + search_queries = selected_queries + + # ... continue with approved searches ... 
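
# --- Illustrative sketch (an assumption, not part of the pipeline file above):
# the `parse_approval_response` helper called by `get_research_approval_step`
# in section 3 and exercised by the unit tests below. Intended home:
# `utils/approval_utils.py`; exact implementation details remain open.
from typing import List

from utils.pydantic_models import ApprovalDecision


def parse_approval_response(
    response: str, proposed_queries: List[str]
) -> ApprovalDecision:
    """Parse reviewer replies such as 'APPROVE ALL', 'SKIP' or 'SELECT 1,3'."""
    text = (response or "").strip().upper()

    if text == "APPROVE ALL":
        return ApprovalDecision(
            approved=True,
            selected_queries=list(proposed_queries),
            approval_method="APPROVE_ALL",
        )
    if text == "SKIP":
        return ApprovalDecision(
            approved=False, selected_queries=[], approval_method="SKIP"
        )
    if text.startswith("SELECT"):
        try:
            # 1-based indices, e.g. "SELECT 1,3"; out-of-range indices are dropped
            indices = [int(part) for part in text[len("SELECT"):].split(",")]
        except ValueError:
            return ApprovalDecision(
                approved=False, selected_queries=[], approval_method="PARSE_ERROR"
            )
        selected = [
            proposed_queries[i - 1]
            for i in indices
            if 1 <= i <= len(proposed_queries)
        ]
        return ApprovalDecision(
            approved=True,
            selected_queries=selected,
            approval_method="SELECT_SPECIFIC",
        )
    return ApprovalDecision(
        approved=False, selected_queries=[], approval_method="PARSE_ERROR"
    )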
+``` + +## Testing Strategy + +### Unit Tests for Approval Logic (`tests/test_approval_utils.py`): + +Focus on testing the core approval parsing logic without running actual ZenML steps: + +```python +import pytest +from utils.approval_utils import parse_approval_response +from utils.pydantic_models import ApprovalDecision + +def test_parse_approval_responses(): + """Test parsing different approval responses.""" + queries = ["query1", "query2", "query3"] + + # Test approve all + decision = parse_approval_response("APPROVE ALL", queries) + assert decision.approved == True + assert decision.selected_queries == queries + assert decision.approval_method == "APPROVE_ALL" + + # Test skip + decision = parse_approval_response("skip", queries) # Test case insensitive + assert decision.approved == False + assert decision.selected_queries == [] + assert decision.approval_method == "SKIP" + + # Test selection + decision = parse_approval_response("SELECT 1,3", queries) + assert decision.approved == True + assert decision.selected_queries == ["query1", "query3"] + assert decision.approval_method == "SELECT_SPECIFIC" + + # Test invalid selection + decision = parse_approval_response("SELECT invalid", queries) + assert decision.approved == False + assert decision.approval_method == "PARSE_ERROR" + + # Test out of range indices + decision = parse_approval_response("SELECT 1,5,10", queries) + assert decision.approved == True + assert decision.selected_queries == ["query1"] # Only valid indices + assert decision.approval_method == "SELECT_SPECIFIC" + + +def test_format_approval_request(): + """Test formatting of approval request messages.""" + from utils.approval_utils import format_approval_request + + message = format_approval_request( + main_query="Test query", + progress_summary={ + 'completed_count': 5, + 'avg_confidence': 0.75, + 'low_confidence_count': 2 + }, + critique_points=[ + {"issue": "Missing data", "importance": "high"}, + {"issue": "Minor gap", "importance": "low"} + ], + proposed_queries=["query1", "query2"] + ) + + assert "Test query" in message + assert "5" in message + assert "0.75" in message + assert "2 queries" in message + assert "APPROVE ALL" in message + assert "SKIP" in message + assert "SELECT" in message +``` + +### Note on Testing Approach: +- We focus on unit tests for the approval parsing and formatting logic +- Integration testing with actual ZenML steps and alerters will be done manually during development +- No full end-to-end integration tests are included to keep the test suite lightweight + +## Key Changes Summary + +### What Changed: +1. **Split `iterative_reflection_step` into 3 steps** to comply with ZenML's static DAG requirement +2. **Always execute all steps** - the approval step runs every time but auto-approves when `require_approval=False` +3. **Pass data between steps** using new Pydantic models (`ReflectionOutput`, `ApprovalDecision`) +4. 
**No conditionals in pipeline definition** - all branching logic moved inside steps + +### New Files to Create: +- `steps/generate_reflection_step.py` +- `steps/approval_step.py` +- `steps/execute_approved_searches_step.py` +- `utils/approval_utils.py` +- Updates to `utils/pydantic_models.py` +- Updates to `pipelines/parallel_research_pipeline.py` + +### Configuration: +- Add `require_approval` and `approval_timeout` to pipeline parameters +- Configure alerter in stack or via environment variables + +### Usage: +```bash +# With approval enabled +python run.py --config configs/enhanced_research.yaml + +# Without approval (default behavior) +python run.py --config configs/enhanced_research.yaml --no-approval +``` + +## Implementation Checklist + +- [ ] Add new Pydantic models (`ReflectionOutput`, `ApprovalDecision`) +- [ ] Split existing `iterative_reflection_step` into `generate_reflection_step` +- [ ] Create `get_research_approval_step` +- [ ] Create `execute_approved_searches_step` +- [ ] Update pipeline definition with new steps +- [ ] Add approval utility functions +- [ ] Configure Slack/Email alerter +- [ ] Add unit tests +- [ ] Add integration test with mocked alerter +- [ ] Update documentation +- [ ] Test end-to-end flow diff --git a/deep_research/design/lite_test.py b/deep_research/design/lite_test.py new file mode 100644 index 00000000..7a59b255 --- /dev/null +++ b/deep_research/design/lite_test.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 + +import requests + +# Configuration +LITELLM_BASE_URL = "https://litellm-service-5ikaahlouq-uc.a.run.app" +API_KEY = "zenmllitellm" + + +def test_model(model_name, prompt, max_tokens=100): + """Test a specific model with a prompt""" + + url = f"{LITELLM_BASE_URL}/v1/chat/completions" + headers = { + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json", + } + + data = { + "model": model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": max_tokens, + } + + try: + print(f"\n🤖 Testing {model_name}...") + print(f"📝 Prompt: {prompt}") + print("⏳ Waiting for response...") + + response = requests.post(url, headers=headers, json=data) + response.raise_for_status() + + result = response.json() + content = result["choices"][0]["message"]["content"] + usage = result["usage"] + + print(f"✅ Response: {content}") + print( + f"📊 Tokens: {usage['prompt_tokens']} prompt + {usage['completion_tokens']} completion = {usage['total_tokens']} total" + ) + + return True + + except requests.exceptions.RequestException as e: + print(f"❌ Error with {model_name}: {e}") + return False + except KeyError as e: + print(f"❌ Unexpected response format: {e}") + return False + + +def main(): + """Test multiple models with different prompts""" + + print("🚀 Testing LiteLLM Deployment") + print("=" * 50) + + # Test cases: (model, prompt) + test_cases = [ + ("gpt-4o", "Tell me a short joke about programming."), + ("gpt-4o-mini", "What is the capital of Japan?"), + ("claude-3-5-sonnet", "Explain quantum computing in one sentence."), + ("claude-3-5-haiku", "Write a haiku about artificial intelligence."), + ("gpt-3.5-turbo", "What's 15 * 23?"), + ] + + successful_tests = 0 + total_tests = len(test_cases) + + for model, prompt in test_cases: + if test_model(model, prompt): + successful_tests += 1 + + print("\n" + "=" * 50) + print(f"🎯 Results: {successful_tests}/{total_tests} tests passed") + + if successful_tests == total_tests: + print("🎉 All models are working perfectly!") + else: + print("⚠️ Some models may need attention.") + + +if 
__name__ == "__main__": + main() diff --git a/deep_research/design/prompt_cost_visualization.md b/deep_research/design/prompt_cost_visualization.md new file mode 100644 index 00000000..457d90f4 --- /dev/null +++ b/deep_research/design/prompt_cost_visualization.md @@ -0,0 +1,236 @@ +# Prompt-Based Cost Visualization Design + +## Overview + +This document outlines the design for disaggregating LLM costs by prompt type in the ZenML Deep Research pipeline. The goal is to provide detailed insights into which prompts consume the most tokens and incur the highest costs, enabling better optimization decisions. + +## Current State + +### Tracing Structure +- **Traces**: One per pipeline run, containing all LLM calls +- **Observations**: Individual LLM calls with model, tokens, and cost data +- **Current Visualization**: Aggregates costs by model and step name only + +### Problem +The current visualization shows total costs but doesn't break down spending by the specific prompt templates being used. This makes it difficult to: +1. Identify which prompts are most expensive +2. Optimize token usage for specific prompt types +3. Understand the cost distribution across different pipeline phases + +## Proposed Solution + +### 1. Prompt Type Identification + +Create a mapping of prompt types to unique keywords that appear in each prompt template: + +```python +PROMPT_IDENTIFIERS = { + "query_decomposition": ["MAIN RESEARCH QUERY", "DIFFERENT DIMENSIONS", "sub-questions"], + "search_query": ["Deep Research assistant", "effective search query"], + "synthesis": ["information synthesis", "comprehensive answer", "confidence level"], + "viewpoint_analysis": ["multi-perspective analysis", "viewpoint categories"], + "reflection": ["critique and improve", "information gaps"], + "additional_synthesis": ["enhance the original synthesis"], + "conclusion_generation": ["Synthesis and Integration", "Direct Response to Main Query"], + "executive_summary": ["executive summaries", "Key Findings", "250-400 words"], + "introduction": ["engaging introductions", "Context and Relevance"], +} +``` + +Note: From the sample observation, the system prompt is accessed via: +- `observation.input['messages'][0]['content']` for the system prompt +- `observation.input['messages'][1]['content']` for the user input +- Token usage is in `observation.usage.input` and `observation.usage.output` +- Cost is in `observation.calculated_total_cost` (defaults to 0.0) + +### 2. New Utility Functions + +Add to `utils/tracing_metadata_utils.py`: + +```python +def identify_prompt_type(observation: ObservationsView) -> Optional[str]: + """ + Identify the prompt type based on keywords in the observation's input. + + Examines the system prompt in observation.input['messages'][0]['content'] + for unique keywords that identify each prompt type. + + Returns: + str: The prompt type name, or "unknown" if not identified + """ + +def get_costs_by_prompt_type(trace_id: str) -> Dict[str, Dict[str, float]]: + """ + Get cost breakdown by prompt type for a given trace. + + Uses observation.usage.input/output for token counts and + observation.calculated_total_cost for costs. + + Returns: + Dict mapping prompt_type to { + 'cost': float, + 'input_tokens': int, + 'output_tokens': int, + 'count': int # number of calls + } + """ + +def get_prompt_type_statistics(trace_id: str) -> Dict[str, Dict[str, Any]]: + """ + Get detailed statistics for each prompt type. 
+ + Returns: + Dict mapping prompt_type to { + 'cost': float, + 'input_tokens': int, + 'output_tokens': int, + 'count': int, + 'avg_cost_per_call': float, + 'avg_input_tokens': float, + 'avg_output_tokens': float, + 'percentage_of_total_cost': float + } + """ +``` + +### 3. Visualization Updates + +#### A. Data Collection +Update `steps/collect_tracing_metadata_step.py` to: + +1. Call `get_costs_by_prompt_type()` for the trace +2. Calculate percentages and averages +3. Store prompt-level metrics in the `TracingMetadata` model + +#### B. Add to Pydantic Model +Update `utils/pydantic_models.py`: + +```python +class PromptTypeMetrics(BaseModel): + """Metrics for a specific prompt type.""" + prompt_type: str + total_cost: float + input_tokens: int + output_tokens: int + call_count: int + avg_cost_per_call: float + percentage_of_total_cost: float + +class TracingMetadata(BaseModel): + # ... existing fields ... + prompt_metrics: List[PromptTypeMetrics] = Field( + default_factory=list, + description="Cost breakdown by prompt type" + ) +``` + +#### C. HTML Visualization +Update `materializers/tracing_metadata_materializer.py` to add: + +1. **Bar Chart**: Cost by prompt type + - X-axis: Prompt types + - Y-axis: Cost in USD + - Color-coded bars with hover tooltips + +2. **Token Usage Chart**: Stacked bar chart + - X-axis: Prompt types + - Y-axis: Token count + - Stacked: Input tokens (bottom) and output tokens (top) + +3. **Efficiency Table**: + - Columns: Prompt Type, Total Cost, Calls, Avg Cost/Call, % of Total + - Sortable by any column + - Highlight most expensive prompts + +### 4. Implementation Approach + +#### Phase 1: Core Functionality +1. Implement `identify_prompt_type()` with robust keyword matching + - Access system prompt via `observation.input['messages'][0]['content']` + - Handle cases where messages structure differs + - Add fallback logic for observations without clear prompt type +2. Test with sample traces to ensure accurate categorization + +#### Phase 2: Cost Aggregation +1. Implement `get_costs_by_prompt_type()` + - Use `observation.usage.input` and `observation.usage.output` for token counts + - Use `observation.calculated_total_cost` for cost (fallback to 0.0 if None) + - Handle edge cases (missing cost data, partial tokens) +2. Add caching for performance with large traces + +#### Phase 3: Visualization +1. Update data models +2. Implement chart generation using Chart.js or similar +3. Add interactive features (sorting, filtering) + +### 5. Visualization Mockup + +``` +┌─────────────────────────────────────────────────────────┐ +│ Cost by Prompt Type │ +├─────────────────────────────────────────────────────────┤ +│ │ +│ $0.50 ┤ ████ │ +│ $0.40 ┤ ████ ████ │ +│ $0.30 ┤ ████ ████ ████ │ +│ $0.20 ┤ ████ ████ ████ ████ │ +│ $0.10 ┤ ████ ████ ████ ████ ████ │ +│ $0.00 └────────────────────────────────────── │ +│ Query Synth Search Reflect Exec │ +│ Decomp Summary │ +└─────────────────────────────────────────────────────────┘ + +┌─────────────────────────────────────────────────────────┐ +│ Prompt Type Efficiency │ +├────────────────┬──────┬──────┬───────────┬────────────┤ +│ Prompt Type │ Cost │ Calls│ Avg $/Call│ % of Total │ +├────────────────┼──────┼──────┼───────────┼────────────┤ +│ Query Decomp │$0.45 │ 3 │ $0.15 │ 28% │ +│ Synthesis │$0.38 │ 12 │ $0.03 │ 24% │ +│ Search Query │$0.25 │ 45 │ $0.006 │ 16% │ +│ Reflection │$0.20 │ 8 │ $0.025 │ 13% │ +│ Executive Sum │$0.15 │ 1 │ $0.15 │ 9% │ +└────────────────┴──────┴──────┴───────────┴────────────┘ +``` + +### 6. 
Future Enhancements + +1. **Drill-down capability**: Click on a prompt type to see individual observations +2. **Time-series analysis**: Track prompt costs over multiple pipeline runs +3. **Optimization suggestions**: Automatically identify prompts that could be shortened +4. **A/B testing support**: Compare costs between different prompt versions +5. **Export functionality**: Download cost data as CSV/JSON + +### 7. Configuration + +Add to pipeline configs: + +```yaml +cost_visualization: + group_by_prompt: true + show_token_breakdown: true + highlight_threshold: 0.1 # Highlight prompts > 10% of total cost +``` + +## Benefits + +1. **Cost Transparency**: Clear understanding of where money is being spent +2. **Optimization Targets**: Identify which prompts to optimize first +3. **Token Efficiency**: See which prompts generate the most output relative to input +4. **Budget Planning**: Better estimates for future research tasks +5. **Prompt Engineering**: Data-driven approach to prompt refinement + +## Testing Strategy + +1. Unit tests for prompt identification logic +2. Integration tests with sample Langfuse data +3. Visualization tests using snapshot testing +4. Performance tests with large traces (1000+ observations) + +## Migration Plan + +Since this is an additive feature: +1. No breaking changes to existing code +2. Gradual rollout: Start with basic prompt identification +3. Feature flag for new visualizations +4. Backward compatibility with traces lacking prompt data \ No newline at end of file diff --git a/deep_research/design/pydantic_migration.md b/deep_research/design/pydantic_migration.md new file mode 100644 index 00000000..8ccbd792 --- /dev/null +++ b/deep_research/design/pydantic_migration.md @@ -0,0 +1,351 @@ +# Design Document: Migrating from Dataclasses to Pydantic Models in ZenML Deep Research + +## Overview + +This document outlines a plan to migrate the current dataclass-based state objects to Pydantic models in the ZenML Deep Research project. The migration will improve type validation, simplify serialization, leverage ZenML's built-in Pydantic support, and enable explicit `HTMLString` artifacts at key pipeline steps. + +## Current State + +The project currently uses: +- Dataclasses for state objects (`ResearchState`, `SearchResult`, etc.) +- Custom serialization/deserialization with `_convert_to_dict()` and `_convert_from_dict()` +- A custom `ResearchStateMaterializer` with 350+ lines of code + +## Migration Goals + +- [x] Replace dataclasses with Pydantic models for better validation and error messages +- [x] Leverage ZenML's built-in `PydanticMaterializer` to simplify serialization +- [x] Make HTML generation a first-class artifact where appropriate +- [x] Maintain existing visualizations while reducing code complexity + +## Implementation Plan + +### 1. Create Pydantic Model Equivalents +- [x] Convert leaf models first (`SearchResult`, `ViewpointTension`, etc.) +- [x] Implement nested models (`SynthesizedInfo`, `ViewpointAnalysis`, etc.) 
+- [x] Finally convert the main `ResearchState` model + +For each model, we'll follow this pattern: + +```python +# Before (dataclass) +@dataclass +class SearchResult: + url: str = "" + content: str = "" + title: str = "" + snippet: str = "" + +# After (Pydantic) +from pydantic import BaseModel, Field + +class SearchResult(BaseModel): + """Represents a search result for a sub-question.""" + url: str = "" + content: str = "" + title: str = "" + snippet: str = "" + + model_config = { + "extra": "ignore", # Ignore extra fields during deserialization + "frozen": False, # Allow attribute updates + "validate_assignment": True, # Validate when attributes are set + } +``` + +The main `ResearchState` model will require special attention to handle the update methods correctly: + +```python +class ResearchState(BaseModel): + # Base fields + main_query: str = "" + sub_questions: List[str] = Field(default_factory=list) + search_results: Dict[str, List[SearchResult]] = Field(default_factory=dict) + # ...other fields... + + model_config = { + "validate_assignment": True, + "frozen": False, + } + + def get_current_stage(self) -> str: + """Determine the current stage of research based on filled data.""" + if self.final_report_html: + return "final_report" + # ...rest of implementation... + + def update_sub_questions(self, sub_questions: List[str]) -> None: + """Update the sub-questions list.""" + self.sub_questions = sub_questions +``` + +### 2. Create Extended PydanticMaterializer +- [x] Create a new materializer that extends ZenML's `PydanticMaterializer` +- [x] Keep only the visualization logic from the original materializer +- [x] Remove all manual JSON serialization/deserialization code + +We'll create a materializer that extends the built-in PydanticMaterializer: + +```python +from zenml.materializers import PydanticMaterializer +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +import os +from typing import Dict, Type, Any + +class ResearchStateMaterializer(PydanticMaterializer): + """Materializer for the ResearchState class with visualizations.""" + + ASSOCIATED_TYPES = (ResearchState,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: ResearchState + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the ResearchState. + + Args: + data: The ResearchState to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "research_state.html") + + # Create HTML content based on current stage + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, state: ResearchState) -> str: + """Generate HTML visualization for the research state. + + Args: + state: The ResearchState to visualize + + Returns: + HTML string + """ + # Copy the existing visualization generation logic + # ... +``` + +### 3. 
Update Step Signatures and Methods +- [x] Modify step signatures to return separate state and HTML artifacts where useful +- [x] Update docstrings and type hints +- [x] Register materializers with ZenML steps + +For key steps where HTML visualization is important (like final report): + +```python +from typing import Annotated, Tuple +from zenml.types import HTMLString + +@step( + output_materializers={ + "state": ResearchStateMaterializer, + "viz": None # Default HTML materializer + } +) +def final_report_step( + state: ResearchState, + llm_model: str = "gpt-4", +) -> Tuple[ + Annotated[ResearchState, "state"], + Annotated[HTMLString, "viz"] +]: + """Generate the final research report. + + Args: + state: The research state with synthesized information + llm_model: LLM model to use for report generation + + Returns: + Tuple of (updated research state, HTML visualization) + """ + # ... existing implementation ... + + # Generate HTML report + report_html = generate_report_from_template(state, llm_model) + + # Update state + state.final_report_html = report_html + + # Return both state and visualization + return state, HTMLString(report_html) +``` + +For most steps, we can keep the single ResearchState return type: + +```python +@step(output_materializers=ResearchStateMaterializer) +def process_sub_question_step( + state: ResearchState, + question_index: int, + # ... other parameters +) -> Annotated[ResearchState, "output"]: + """Process a single sub-question.""" + # ... implementation ... + return sub_state +``` + +### 4. Update Import References and Fix Pipeline Structure +- [x] Update imports in all pipeline files +- [x] Fix pipeline structure to handle new output types +- [x] Update step references + +Example of fixing a pipeline file: + +```python +from typing import List, Optional + +from zenml import pipeline +from zenml.types import HTMLString + +# Update imports for the Pydantic models +from utils.pydantic_models import ResearchState + +# Import your steps +from steps.query_decomposition_step import query_decomposition_step +from steps.process_sub_question_step import process_sub_question_step +from steps.merge_results_step import merge_results_step +from steps.final_report_step import final_report_step + +@pipeline +def research_pipeline( + query: str, + # ... other parameters +): + """Pipeline for deep research with enhanced capabilities.""" + + # Initialize research state + initial_state = query_decomposition_step(query=query) + + # Process each sub-question in parallel + sub_states = [] + for i in range(5): # Support up to 5 sub-questions + sub_state = process_sub_question_step( + state=initial_state, + question_index=i, + ) + sub_states.append(sub_state) + + # Merge results + merged_state = merge_results_step( + initial_state=initial_state, + sub_states=sub_states, + ) + + # Generate final report (returns tuple of state and HTML) + final_state, report_html = final_report_step( + state=merged_state, + ) + + return final_state, report_html +``` + +### 5. Clean Up Legacy Code and Testing +- [x] Remove old dataclass models once migration is complete +- [x] Remove manual serialization methods +- [x] Perform pipeline tests + +## Detailed Implementation Steps + +### Phase 1: Create Pydantic Models + +1. **Setup and Preparation** + - [x] Add Pydantic to requirements if not already there + - [x] Create a new file `utils/pydantic_models.py` for new models + +2. 
**Convert Simple Models** + - [x] Implement `SearchResult` model + - [x] Implement `ViewpointTension` model + - [x] Test serialization/deserialization + +3. **Implement Nested Models** + - [x] Implement `SynthesizedInfo` model + - [x] Implement `ViewpointAnalysis` model + - [x] Implement `ReflectionMetadata` model + - [x] Test nested serialization + +4. **Create Main ResearchState** + - [x] Implement `ResearchState` with all methods + - [x] Configure model settings for mutability + - [x] Test comprehensive serialization/deserialization + +### Phase 2: Implement New Materializer + +1. **Extract Visualization Logic** + - [x] Copy HTML generation from current materializer + - [x] Refactor as needed for Pydantic model access + +2. **Create Extended Materializer** + - [x] Create new `pydantic_materializer.py` file + - [x] Implement class extending PydanticMaterializer + - [x] Test basic saving/loading + +3. **Test Visualization Integration** + - [x] Test HTML generation with sample data + - [x] Ensure compatibility with ZenML UI + +### Phase 3: Update Step Signatures + +1. **Identify Key Visualization Steps** + - [x] Final report step + - [x] Viewpoint analysis step + - [x] Reflection step + +2. **Update Step Decorators** + - [x] Modify decorators to register materializers + - [x] Update return type annotations + +3. **Test Updated Steps** + - [x] Test step execution + - [x] Verify multiple outputs work correctly + +### Phase 4: Refactor Pipeline Code + +1. **Update Imports** + - [x] Change import statements in all files + - [x] Fix type annotations + +2. **Pipeline Integration** + - [x] Update pipeline to handle new return types + - [x] Test full pipeline execution + +3. **Final Cleanup** + - [x] Remove old dataclass implementation + - [x] Update documentation + +## Migration Testing Strategy + +1. **Unit Testing** + - Test each model conversion with sample data + - Verify serialization/deserialization works + +2. **Integration Testing** + - Test step execution with new models + - Verify HTML artifacts appear in ZenML UI + +3. **Pipeline Validation** + - Run complete pipeline with test query + - Compare results to pre-migration outputs + +## Timeline + +- Phase 1 (Models): 1-2 days +- Phase 2 (Materializer): 1 day +- Phase 3 (Step Signatures): 1-2 days +- Phase 4 (Cleanup): 1 day + +Total estimated time: 4-6 days + +## Conclusion + +This migration will modernize the codebase by leveraging Pydantic's validation capabilities and ZenML's built-in support for Pydantic models. The result will be a more maintainable, type-safe implementation with improved visualization capabilities through first-class HTML artifacts. 
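
As a small, hedged illustration of the Phase 1 serialization checks listed above, a round-trip test for the Pydantic `SearchResult` model could look like the following sketch (the test module name and placement are assumptions):

```python
from utils.pydantic_models import SearchResult


def test_search_result_round_trip():
    """SearchResult should survive a model_dump()/model_validate() round trip."""
    original = SearchResult(
        url="https://example.com",
        title="Example",
        snippet="A short snippet",
        content="Full page content",
    )

    payload = original.model_dump()  # plain dict, ready for JSON serialization
    restored = SearchResult.model_validate(payload)
    assert restored == original

    # extra="ignore" in model_config means unknown keys are dropped silently
    restored_extra = SearchResult.model_validate({**payload, "unexpected_field": 123})
    assert restored_extra == original
```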
\ No newline at end of file diff --git a/deep_research/design/sample_observation.md b/deep_research/design/sample_observation.md new file mode 100644 index 00000000..8aa1180f --- /dev/null +++ b/deep_research/design/sample_observation.md @@ -0,0 +1,220 @@ +``` +ObservationsView( + id='time-14-13-52-206331_gen-1748348032-7QuH7ONwLVwXpbZmZp9V', + trace_id='e8f5fee7-1b60-42ea-8f7f-fcfc286ec231', + type='GENERATION', + name='litellm-completion', + start_time=datetime.datetime(2025, 5, 27, 12, 13, 52, 206000, tzinfo=datetime.timezone.utc), + end_time=datetime.datetime(2025, 5, 27, 12, 13, 57, 603000, tzinfo=datetime.timezone.utc), + completion_start_time=datetime.datetime(2025, 5, 27, 12, 13, 57, 603000, tzinfo=datetime.timezone.utc), + model='google/gemini-2.5-flash-preview-05-20', + model_parameters={'temperature': '0.2', 'top_p': '0.9', 'max_tokens': 1500}, + input={ + 'messages': [ + { + 'role': 'system', + 'content': '\nYou are a Deep Research assistant specializing in synthesizing comprehensive research conclusions. Given all the research findings from a deep research study, your +task is to create a thoughtful, evidence-based conclusion that ties together the overall findings.\n\nYour conclusion should:\n\n1. **Synthesis and Integration (150-200 words):**\n - Connect +insights from different sub-questions to form a higher-level understanding\n - Identify overarching themes and patterns that emerge from the research\n - Highlight how different findings +relate to and support each other\n - Avoid simply summarizing each section separately\n\n2. **Direct Response to Main Query (100-150 words):**\n - Address the original research question +directly with evidence-based conclusions\n - State what the research definitively established vs. what remains uncertain\n - Provide a clear, actionable answer based on the synthesized +evidence\n\n3. **Limitations and Future Directions (100-120 words):**\n - Acknowledge remaining uncertainties and information gaps across all sections\n - Suggest specific areas where +additional research would be most valuable\n - Identify what types of evidence or perspectives would strengthen the findings\n\n4. **Implications and Applications (80-100 words):**\n - +Explain the practical significance of the research findings\n - Suggest how the insights might be applied or what they mean for stakeholders\n - Connect findings to broader contexts or +implications\n\nFormat your output as a well-structured conclusion section in HTML format with appropriate paragraph breaks and formatting. Use
<p>
tags for paragraphs and organize the content +logically with clear transitions between the different aspects outlined above.\n\nIMPORTANT: Do NOT include any headings like "Conclusion",

<h1>, or <h2>

tags - the section already has a heading. +Start directly with the conclusion content in paragraph form. Just create flowing, well-structured paragraphs that cover all four aspects naturally.\n\nEnsure the conclusion feels cohesive and +draws meaningful connections between findings rather than just listing them sequentially.\n' + }, + { + 'role': 'user', + 'content': '{\n "main_query": "Is LLMOps a subset of MLOps, or is it something completely different?",\n "sub_questions": [\n "What are the fundamental differences in the +lifecycle stages, tooling, and operational challenges between traditional MLOps practices and those specifically required for Large Language Models (LLMs)?",\n "To what extent do the unique +characteristics of LLMs, such as their scale, emergent behaviors, prompt engineering, and continuous pre-training/fine-tuning needs, necessitate a distinct operational framework beyond what MLOps +currently provides?"\n ],\n "enhanced_info": {\n "What are the fundamental differences in the lifecycle stages, tooling, and operational challenges between traditional MLOps practices and +those specifically required for Large Language Models (LLMs)?": {\n "synthesized_answer": "The fundamental differences between traditional MLOps practices and those required for Large +Language Models (LLMs) lie in their lifecycle stages, tooling, and operational challenges. Traditional MLOps focuses on the end-to-end lifecycle of machine learning models, including data +preparation, model training, evaluation, deployment, and monitoring. In contrast, LLMOps, a subset of GenAIOps, is specifically tailored for managing LLMs, which are characterized by their large +size, pre-training on vast datasets, and unique challenges such as prompt engineering, hallucinations, and the need for retrieval-augmented generation (RAG). Key differences include the emphasis +on adapting pre-trained foundation models rather than training from scratch, the use of specialized tools like prompt management and RAG workflows, and the focus on addressing LLM-specific +challenges such as cost, latency, and ethical considerations.",\n "confidence_level": "medium",\n "information_gaps": "The search results did not address the full range of operational +challenges specific to LLMs, such as model interpretability, scalability, and energy consumption. Additionally, there is limited discussion on the long-term maintenance and updating of LLMs in +production environments.",\n "key_sources": [\n "https://developer.nvidia.com/blog/mastering-llm-techniques-llmops/",\n +"https://wandb.ai/site/articles/understanding-llmops-large-language-model-operations/"\n ],\n "improvements": [\n "More detailed information on the operational challenges of +LLMs",\n "Discussion on the long-term maintenance and updating of LLMs",\n "Perspectives on model interpretability and scalability",\n "Failed to enhance synthesis"\n +]\n },\n "To what extent do the unique characteristics of LLMs, such as their scale, emergent behaviors, prompt engineering, and continuous pre-training/fine-tuning needs, necessitate a +distinct operational framework beyond what MLOps currently provides?": {\n "synthesized_answer": "Large Language Models (LLMs) necessitate a distinct operational framework beyond traditional +MLOps due to their unique characteristics such as scale, emergent behaviors, prompt engineering, and continuous pre-training/fine-tuning needs. 
While MLOps provides foundational practices for +machine learning model management, LLMs present specific challenges that require extensions and specialized approaches, leading to the emergence of LLMOps. LLMs\' large scale demands significant +computational resources and specialized deployment strategies, including efficient serving frameworks like vLLM, Ollama, and LocalAI, which optimize inference through techniques like +PagedAttention, continuous batching, and quantization, and offer user-friendly APIs for seamless integration. Their emergent behaviors, which can lead to unexpected outputs, necessitate advanced +monitoring and prompt engineering techniques, supported by observability tools such as WhyLabs LangKit, AgentOps, and Arize Phoenix, which provide insights into performance, error tracking, and +usage patterns, and help detect issues like malicious prompts, sensitive data leakage, and hallucinations. Additionally, the continuous need for pre-training and fine-tuning requires ongoing +updates and adaptations that are not typically addressed in standard MLOps workflows. This is facilitated by specialized fine-tuning tools and platforms. The LLMOps landscape is comprehensive, +encompassing integration frameworks, vector databases, RLHF services, LLM testing tools, LLM monitoring and observability tools, and fine-tuning tools. LLMOps platforms, whether designed +specifically for LLMs or MLOps platforms expanding their capabilities, offer features for finetuning, versioning, and deploying LLMs, with options ranging from no-code/low-code solutions for ease +of adoption to code-first platforms for greater flexibility. Data and cloud platforms are also increasingly offering LLMOps capabilities, allowing users to leverage their own data for building +and fine-tuning LLMs. Orchestration frameworks, including standard DevOps tools like Kubernetes and Docker Compose, as well as LLM-specific solutions like OpenLLM (BentoML), are crucial for +managing deployment, scaling, and automating workflows. API gateways, such as LiteLLM Proxy Server, manage data flow, handle routing and security, and simplify integration by providing a unified +interface for various LLM providers. The operationalization of LLMs also involves addressing critical aspects like model interpretability, scalability, and energy consumption, which are central +to large-scale AI deployments. Furthermore, long-term maintenance and updating of LLMs in production environments are vital for understanding total cost of ownership and ongoing operational +burden. The growing ecosystem of LLMOps tools and companies, categorized across various functionalities like model deployment, training, experiment tracking, monitoring, security, data +management, prompt engineering, and vector search, underscores the specialized and evolving nature of LLMOps. This tailored framework is essential for effectively managing and operationalizing +LLMs in production environments, while also considering ethical implications, bias mitigation, and the broader societal impact of widespread LLM adoption.",\n "confidence_level": "high",\n +"information_gaps": "The search results provide a comprehensive overview of LLMOps and its differences from MLOps, but there may be gaps in discussing the full range of use cases, particularly in +highly specialized or less common applications. 
Additionally, more detailed information on the latest tools and practices in LLMOps, as well as real-world case studies, would complement the +existing information.",\n "key_sources": [\n "https://medium.com/@sahin.samia/a-comprehensive-analysis-of-llmops-managing-large-language-models-in-production-649ae793353a",\n +"https://aws.amazon.com/blogs/machine-learning/fmops-llmops-operationalize-generative-ai-and-differences-with-mlops/",\n "https://cloud.google.com/discover/what-is-llmops"\n ],\n +"improvements": [\n "More detailed case studies on LLMOps implementations across various industries.",\n "Information on the latest tools and technologies specifically developed for +LLMOps.",\n "Discussion of emerging trends and future directions in LLMOps.",\n "Incorporated detailed information on specific LLMOps tools and frameworks for serving (vLLM, Ollama, +LocalAI), orchestration (OpenLLM/BentoML, AutoGen, standard DevOps tools), API gateways (LiteLLM Proxy Server), and observability (WhyLabs LangKit, AgentOps, Arize Phoenix), addressing the +critique regarding the lack of depth in tooling and practices.",\n "Expanded on the operational challenges by explicitly mentioning model interpretability, scalability, and energy +consumption as core concerns for LLMs, directly addressing the completeness of operational challenges critique.",\n "Included discussion on the long-term maintenance and updating of LLMs +by highlighting the role of LLMOps platforms and frameworks in facilitating continuous updates and adaptations, thereby addressing the critique on limited discussion of this aspect.",\n +"Provided a broader landscape of LLMOps categories and functionalities, including integration frameworks, vector databases, RLHF services, LLM testing tools, fine-tuning tools, and various +aspects like security, privacy, compliance, data storage, and prompt engineering, enhancing the comprehensiveness of the synthesis.",\n "Acknowledged the ethical and societal implications +by emphasizing the need to consider bias mitigation and the broader societal impact of LLM adoption, setting the stage for deeper dives into these areas in future analyses."\n ]\n }\n +},\n "viewpoint_analysis": {\n "main_points_of_agreement": [\n "LLMs require specialized operational practices beyond traditional MLOps due to their unique characteristics.",\n +"Scale is a significant factor, demanding substantial computational resources and specialized deployment strategies for LLMs.",\n "Emergent behaviors of LLMs necessitate advanced monitoring +and prompt engineering techniques.",\n "Continuous pre-training and fine-tuning are critical for LLMs, requiring ongoing updates and adaptations.",\n "LLMOps is emerging as a distinct +framework or extension to MLOps to address these specific challenges.",\n "Prompt engineering is a key operational aspect unique to LLMs.",\n "Cost and latency are significant +operational challenges for LLMs in production."\n ],\n "areas_of_tension": [\n {\n "topic": "Defining LLMOps: Subset or Distinct Framework?",\n "viewpoints": {\n +"scientific": "From a scientific perspective, LLMOps is largely seen as an extension or specialization of MLOps, building upon its foundational principles but adding new methodologies and tools +to address LLM-specific complexities. The core scientific principles of experimentation, data management, and model evaluation remain, but the scale and emergent properties of LLMs introduce +novel research questions and engineering challenges. 
It\'s about refining existing scientific methods for a new class of models.",\n "political": "Politically, the distinction between +MLOps and LLMOps might be framed in terms of resource allocation, funding priorities, and regulatory oversight. If LLMOps is seen as \'completely different,\' it might warrant separate funding +streams, new regulatory bodies, or distinct policy frameworks for AI governance. If it\'s a \'subset,\' existing MLOps policies might be adapted, potentially leading to less new legislation but +requiring significant updates to current guidelines. This impacts who controls the narrative and resources.",\n "economic": "Economically, the classification impacts investment +strategies and market segmentation. If LLMOps is a distinct domain, it creates new market opportunities for specialized tools, services, and expertise, potentially leading to new startups and +venture capital interest. If it\'s merely a subset, existing MLOps vendors might simply expand their offerings, leading to consolidation rather than new market creation. The economic incentive +for defining it as distinct is higher for new entrants."\n }\n },\n {\n "topic": "Prioritizing Operational Challenges: Technical vs. Societal",\n "viewpoints": {\n +"scientific": "The scientific perspective primarily focuses on technical operational challenges: scalability, computational efficiency, model stability, and the effectiveness of prompt +engineering. The emphasis is on developing robust algorithms, optimized architectures, and reliable monitoring systems to ensure technical performance and reproducibility. Societal impacts are +often considered secondary to technical feasibility and performance metrics.",\n "social": "From a social perspective, operational challenges extend beyond technical performance to +include societal impacts like bias, fairness, accessibility, and the potential for misuse. The \'emergent behaviors\' of LLMs are not just technical glitches but can manifest as harmful +stereotypes or misinformation, requiring operational frameworks that prioritize ethical considerations and community well-being over pure technical efficiency. The focus is on responsible +deployment and mitigating negative social consequences.",\n "ethical": "The ethical perspective views operational challenges through the lens of moral responsibility. Issues like +\'hallucinations\' are not just technical errors but raise questions about truthfulness and accountability. The continuous fine-tuning process must ethically manage data privacy and consent. The +operational framework must embed mechanisms for transparency, accountability, and human oversight to ensure LLMs are developed and deployed in a morally sound manner, even if it adds complexity +or cost."\n }\n },\n {\n "topic": "The Role of Human Intervention and Expertise",\n "viewpoints": {\n "scientific": "Scientifically, human intervention in +LLMOps is often viewed as a necessary but ideally reducible component. The goal is to automate as much as possible, from data pipelines to model deployment and monitoring, using sophisticated +algorithms and AI-driven tools. Human expertise is crucial in the initial design, problem-solving, and interpreting complex results, but the operational ideal is high autonomy.",\n +"economic": "Economically, the role of human intervention is a cost-benefit analysis. High levels of human expertise (e.g., prompt engineers, AI ethicists) are expensive. 
The economic drive is to +automate tasks to reduce labor costs, but also to invest in specialized human capital where it provides a significant competitive advantage or mitigates high-risk failures (e.g., preventing +costly ethical breaches or major system outages).",\n "historical": "Historically, technological advancements often reduce the need for manual labor, but also create new specialized +roles. The historical perspective would note that while automation is a trend, new technologies like LLMs often introduce unforeseen complexities that require new forms of human expertise (e.g., +prompt engineering, RAG specialists) that didn\'t exist before. This suggests a continuous evolution of human roles rather than outright replacement, echoing past industrial revolutions."\n +}\n }\n ],\n "perspective_gaps": "The current research, while strong on technical and operational aspects, has significant gaps in fully exploring the political, ethical, and +historical dimensions of LLMOps. There\'s limited discussion on the geopolitical implications of LLM development and deployment, particularly concerning data sovereignty, international standards, +and the concentration of power among a few tech giants. The ethical considerations are mentioned (e.g., \'ethical considerations\' as a challenge) but lack deep dives into specific frameworks for +accountability, bias mitigation beyond technical fixes, and the societal impact of widespread LLM adoption on employment, information integrity, and human agency. Historically, the research +doesn\'t contextualize LLMOps within the broader evolution of software engineering, AI development, or even the history of industrial automation, missing insights into recurring patterns of +technological adoption, resistance, and societal adaptation. Including these missing perspectives would enrich understanding by providing a more holistic view of LLMOps not just as a technical +discipline but as a socio-technical system embedded within complex political, economic, and moral landscapes. Specific questions that remain unexplored include: What regulatory frameworks are +emerging globally for LLMOps, and how do they differ? What are the long-term ethical guidelines for continuous fine-tuning, especially concerning user data and evolving societal norms? How does +the rapid pace of LLM development compare to previous technological revolutions in terms of societal disruption and adaptation? What historical precedents exist for managing technologies with +emergent and unpredictable behaviors? How do different political systems approach the governance and control of LLM infrastructure and data?",\n "integrative_insights": "LLMOps, while +technically an extension of MLOps, represents a significant evolutionary step driven by the unique scale and emergent properties of Large Language Models. A scientific understanding highlights +the need for specialized tools and methodologies for prompt engineering, continuous fine-tuning, and advanced monitoring, acknowledging that while foundational MLOps principles apply, the \'how\' +changes significantly. Economically, this specialization creates new market niches and demands for highly skilled labor, even as automation seeks to reduce costs. 
The tension between defining +LLMOps as a subset versus a distinct field can be reconciled by viewing it as a \'specialized domain within MLOps,\' akin to how \'DevSecOps\' is a specialized domain within \'DevOps.\' This +acknowledges the shared foundational principles while recognizing the unique challenges and expertise required. From a political and ethical standpoint, the emergent behaviors of LLMs necessitate +a shift from purely technical operational concerns to include robust frameworks for accountability, bias mitigation, and responsible deployment. This means integrating ethical guidelines and +regulatory compliance directly into the operational lifecycle, not as an afterthought. Historical context suggests that while automation is a continuous trend, new technologies invariably create +new human roles and challenges, implying that LLMOps will require a blend of advanced automation and specialized human expertise. Actionable takeaways include: developing modular LLMOps platforms +that can integrate both traditional MLOps components and LLM-specific tools; investing in interdisciplinary teams that combine technical expertise with ethical and social science perspectives; +and advocating for adaptive regulatory frameworks that can evolve with the rapid pace of LLM development while ensuring public trust and safety. The seemingly contradictory viewpoints on +automation versus human intervention can be harmonized by recognizing that automation handles routine tasks, while human expertise is crucial for navigating the unpredictable, emergent, and +ethically complex aspects of LLMs."\n },\n "reflection_metadata": {\n "critique_summary": [\n "The research acknowledges that \'the full range of operational challenges specific to +LLMs, such as model interpretability, scalability, and energy consumption\' were not fully addressed. This is a significant gap as these are core operational concerns for any large-scale AI +deployment.",\n "While mentioning \'specialized tools like prompt management and RAG workflows,\' the research lacks detailed discussion on specific tools, platforms, or best practices +within LLMOps. 
This limits the practical applicability of the findings.",\n "The \'information_gaps\' section explicitly states \'more detailed information on the latest tools and practices +in LLMOps, as well as real-world case studies, would complement the existing information.\' Without concrete examples, the theoretical differences can be hard to grasp fully.",\n "The +research notes \'limited discussion on the long-term maintenance and updating of LLMs in production environments.\' This is crucial for understanding the total cost of ownership and ongoing +operational burden of LLMs.",\n "While ethical and societal viewpoints are introduced in the viewpoint analysis, the \'perspective_gaps\' section highlights a lack of \'deep dives into +specific frameworks for accountability, bias mitigation beyond technical fixes, and the societal impact of widespread LLM adoption on employment, information integrity, and human agency.\' This +is a critical area for responsible AI development.",\n "The \'perspective_gaps\' section points out \'limited discussion on the geopolitical implications of LLM development and deployment, +particularly concerning data sovereignty, international standards, and the concentration of power among a few tech giants.\' This is a major external factor influencing LLMOps.",\n "The +research \'doesn\'t contextualize LLMOps within the broader evolution of software engineering, AI development, or even the history of industrial automation.\' Understanding historical precedents +can provide valuable insights into challenges and solutions."\n ],\n "additional_questions_identified": [\n "What specific technical frameworks and architectural patterns are emerging +for scalable LLM deployment and inference?",\n "How do organizations currently measure and optimize the cost-effectiveness (e.g., inference cost, fine-tuning cost) of LLMs in production?",\n +"What are the leading open-source and commercial tools specifically designed for LLMOps, and how do they address the unique challenges?",\n "What are the best practices for data governance +and privacy within LLMOps, especially concerning sensitive user data used for fine-tuning or RAG?",\n "How are organizations addressing the interpretability and explainability challenges of +LLMs in production, particularly for critical applications?",\n "What are the current and anticipated regulatory trends globally that will impact LLMOps practices?",\n "How do different +industry sectors (e.g., healthcare, finance, creative industries) adapt LLMOps practices to their specific needs and compliance requirements?",\n "What are the emerging roles and skill sets +required for LLMOps teams, and how do they differ from traditional MLOps roles?",\n "What strategies are being employed to mitigate \'hallucinations\' and ensure factual accuracy in LLM +outputs in production environments?",\n "How does the concept of \'model decay\' or \'concept drift\' apply to LLMs, and what operational strategies are used to manage it?"\n ],\n +"improvements_made": 6.0\n }\n}' + } + ] + }, + version=None, + metadata={ + 'project': 'deep-research', + 'hidden_params': { + 'model_id': None, + 'cache_key': None, + 'api_base': None, + 'response_cost': None, + 'additional_headers': {}, + 'litellm_overhead_time_ms': None, + 'batch_models': None, + 'litellm_model_name': 'openrouter/google/gemini-2.5-flash-preview-05-20', + 'usage_object': None + }, + 'litellm_response_cost': None, + 'api_base': 'https://openrouter.ai/api/v1/chat/completions', + 'cache_hit': False, + 
'requester_metadata': {} + }, + output={ + 'content': "

The research comprehensively demonstrates that while LLMOps shares foundational principles with MLOps, it is not merely a direct subset but rather a specialized and
+significantly extended operational framework necessitated by the unique characteristics of Large Language Models (LLMs). Overarching themes reveal that the sheer scale of LLMs, their emergent
+behaviors, the criticality of prompt engineering, and the continuous need for pre-training and fine-tuning fundamentally alter the traditional MLOps lifecycle, tooling requirements, and
+operational challenges. Findings consistently highlight that LLMs demand specialized deployment strategies, advanced monitoring for unpredictable outputs, and dedicated workflows for prompt
+management and RAG. These elements interrelate, with the scale driving the need for efficient serving frameworks, emergent behaviors necessitating sophisticated observability, and continuous
+adaptation requiring specialized fine-tuning tools, all converging to form a distinct operational paradigm.\n\nIn direct response to the main query, the research definitively establishes
+that LLMOps is a specialized domain within MLOps, rather than a completely different discipline. It builds upon MLOps' core tenets of data management, model deployment, and monitoring but
+introduces a new layer of complexity and specific requirements. The evidence strongly indicates that while MLOps provides the groundwork, the unique attributes of LLMs—such as their computational
+demands, the nuances of prompt engineering, and the imperative for continuous adaptation—mandate a distinct set of tools, practices, and expertise. What remains less certain is the precise
+boundary where MLOps ends and LLMOps begins, as many MLOps platforms are evolving to incorporate LLM-specific functionalities.\n\nDespite the comprehensive overview, several uncertainties
+and information gaps persist. The research could benefit from more detailed real-world case studies illustrating the practical implementation of LLMOps across diverse industries, particularly
+concerning cost optimization and long-term maintenance strategies. There's also a need for deeper exploration into specific frameworks for accountability and bias mitigation beyond technical
+fixes, addressing the broader societal implications. Furthermore, the geopolitical dimensions of LLM development and deployment, including data sovereignty and international regulatory trends,
+remain largely unexplored. Future research should focus on these areas to provide a more holistic understanding of LLMOps' practical, ethical, and global landscape.\n\nThe findings carry
+significant implications for organizations and practitioners. They underscore the necessity of investing in specialized LLMOps tools and expertise, recognizing that traditional MLOps approaches
+alone are insufficient for effectively managing LLMs in production. For stakeholders, this means adapting existing MLOps teams and infrastructure, or building new capabilities, to address the
+unique challenges of LLMs, including their computational intensity, the need for robust prompt engineering, and continuous model adaptation. The insights also highlight the critical importance of
+integrating ethical considerations and responsible AI practices directly into the operational framework from the outset.
", + 'role': 'assistant', + 'tool_calls': None, + 'function_call': None + }, + usage=Usage(input=4598, output=573, total=5171, unit=, input_cost=None, output_cost=None, total_cost=None), + level=, + status_message=None, + parent_observation_id=None, + prompt_id=None, + usage_details={'input': 4598, 'output': 573, 'total': 5171}, + cost_details={}, + environment='default', + prompt_name=None, + prompt_version=None, + model_id=None, + input_price=None, + output_price=None, + total_price=None, + calculated_input_cost=None, + calculated_output_cost=None, + calculated_total_cost=0.0, + latency=5397.0, + time_to_first_token=5.397, + promptTokens=4598, + createdAt='2025-05-27T12:13:59.000Z', + totalTokens=5171, + updatedAt='2025-05-27T12:14:06.712Z', + unit='TOKENS', + projectId='cmb52g8bz01zead07rrupy94y', + completionTokens=573 +) +``` diff --git a/deep_research/design/test_exa_cost_tracking.py b/deep_research/design/test_exa_cost_tracking.py new file mode 100644 index 00000000..c04f09ed --- /dev/null +++ b/deep_research/design/test_exa_cost_tracking.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +"""Test script to verify Exa cost tracking implementation.""" + +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from utils.pydantic_models import ResearchState +from utils.search_utils import ( + exa_search, + extract_search_results, + search_and_extract_results, +) + + +def test_exa_cost_extraction(): + """Test that Exa costs are properly extracted from API responses.""" + print("=== Testing Exa Cost Extraction ===") + + # Test with a simple query + query = "What is quantum computing?" + print(f"\nSearching for: {query}") + + # Test direct exa_search + results = exa_search(query, max_results=2) + print( + f"Direct exa_search returned exa_cost: ${results.get('exa_cost', 0.0):.4f}" + ) + + # Test extract_search_results + extracted, cost = extract_search_results(results, provider="exa") + print(f"extract_search_results returned cost: ${cost:.4f}") + print(f"Number of results extracted: {len(extracted)}") + + # Test search_and_extract_results + results2, cost2 = search_and_extract_results( + query, max_results=2, provider="exa" + ) + print(f"search_and_extract_results returned cost: ${cost2:.4f}") + print(f"Number of results: {len(results2)}") + + return cost2 > 0 + + +def test_research_state_cost_tracking(): + """Test that ResearchState properly tracks costs.""" + print("\n=== Testing ResearchState Cost Tracking ===") + + state = ResearchState(main_query="Test query") + + # Simulate adding search costs + state.search_costs["exa"] = 0.05 + state.search_cost_details.append( + { + "provider": "exa", + "query": "test query 1", + "cost": 0.02, + "timestamp": 1234567890.0, + "step": "test_step", + } + ) + state.search_cost_details.append( + { + "provider": "exa", + "query": "test query 2", + "cost": 0.03, + "timestamp": 1234567891.0, + "step": "test_step", + } + ) + + print(f"Total Exa cost: ${state.search_costs.get('exa', 0.0):.4f}") + print(f"Number of search details: {len(state.search_cost_details)}") + + return True + + +def test_cost_aggregation(): + """Test cost aggregation from multiple states.""" + print("\n=== Testing Cost Aggregation ===") + + # Create multiple sub-states + state1 = ResearchState(main_query="Test") + state1.search_costs["exa"] = 0.02 + state1.search_cost_details.append( + { + "provider": "exa", + "query": "query1", + "cost": 0.02, + "timestamp": 1234567890.0, + "step": "sub_step_1", + } + ) + + state2 = 
ResearchState(main_query="Test") + state2.search_costs["exa"] = 0.03 + state2.search_cost_details.append( + { + "provider": "exa", + "query": "query2", + "cost": 0.03, + "timestamp": 1234567891.0, + "step": "sub_step_2", + } + ) + + # Simulate merge + merged_state = ResearchState(main_query="Test") + merged_state.search_costs = {} + merged_state.search_cost_details = [] + + for state in [state1, state2]: + for provider, cost in state.search_costs.items(): + merged_state.search_costs[provider] = ( + merged_state.search_costs.get(provider, 0.0) + cost + ) + merged_state.search_cost_details.extend(state.search_cost_details) + + print( + f"Merged total cost: ${merged_state.search_costs.get('exa', 0.0):.4f}" + ) + print( + f"Merged search details count: {len(merged_state.search_cost_details)}" + ) + + return merged_state.search_costs.get("exa", 0.0) == 0.05 + + +def main(): + """Run all tests.""" + print("Testing Exa Cost Tracking Implementation\n") + + # Check if Exa API key is set + if not os.getenv("EXA_API_KEY"): + print("WARNING: EXA_API_KEY not set. Skipping real API tests.") + test_api = False + else: + test_api = True + + tests_passed = 0 + tests_total = 0 + + # Test 1: Exa cost extraction (only if API key is available) + if test_api: + tests_total += 1 + try: + if test_exa_cost_extraction(): + print("✓ Exa cost extraction test passed") + tests_passed += 1 + else: + print("✗ Exa cost extraction test failed") + except Exception as e: + print(f"✗ Exa cost extraction test failed with error: {e}") + + # Test 2: ResearchState cost tracking + tests_total += 1 + try: + if test_research_state_cost_tracking(): + print("✓ ResearchState cost tracking test passed") + tests_passed += 1 + else: + print("✗ ResearchState cost tracking test failed") + except Exception as e: + print(f"✗ ResearchState cost tracking test failed with error: {e}") + + # Test 3: Cost aggregation + tests_total += 1 + try: + if test_cost_aggregation(): + print("✓ Cost aggregation test passed") + tests_passed += 1 + else: + print("✗ Cost aggregation test failed") + except Exception as e: + print(f"✗ Cost aggregation test failed with error: {e}") + + print(f"\nTests passed: {tests_passed}/{tests_total}") + + if tests_passed == tests_total: + print("\n✅ All tests passed!") + return 0 + else: + print("\n❌ Some tests failed") + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/deep_research/design/test_prompt_cost_visualization.py b/deep_research/design/test_prompt_cost_visualization.py new file mode 100644 index 00000000..35b43b42 --- /dev/null +++ b/deep_research/design/test_prompt_cost_visualization.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +"""Test script for prompt cost visualization feature.""" + +import sys + +sys.path.append("..") + +from utils.pydantic_models import PromptTypeMetrics, TracingMetadata +from utils.tracing_metadata_utils import identify_prompt_type + + +# Mock observation data for testing +class MockUsage: + def __init__(self, input_tokens, output_tokens): + self.input = input_tokens + self.output = output_tokens + + +class MockObservation: + def __init__( + self, + prompt_type_content, + input_tokens=100, + output_tokens=50, + cost=0.01, + ): + self.input = { + "messages": [ + {"content": prompt_type_content}, + {"content": "user message"}, + ] + } + self.usage = MockUsage(input_tokens, output_tokens) + self.calculated_total_cost = cost + + +def test_identify_prompt_type(): + """Test the identify_prompt_type function.""" + print("Testing identify_prompt_type...") + + # Test each 
prompt type + test_cases = [ + ( + "You are a Deep Research assistant specializing in effective search query generation.", + "search_query", + ), + ( + "Given the MAIN RESEARCH QUERY and DIFFERENT DIMENSIONS sub-questions", + "query_decomposition", + ), + ( + "Your task is information synthesis with comprehensive answer and confidence level", + "synthesis", + ), + ("This is an unknown prompt type", "unknown"), + ] + + for content, expected in test_cases: + obs = MockObservation(content) + result = identify_prompt_type(obs) + status = "✓" if result == expected else "✗" + print( + f" {status} Content: '{content[:50]}...' => {result} (expected: {expected})" + ) + + +def test_prompt_metrics_creation(): + """Test creating PromptTypeMetrics objects.""" + print("\nTesting PromptTypeMetrics creation...") + + metrics = PromptTypeMetrics( + prompt_type="search_query", + total_cost=0.25, + input_tokens=5000, + output_tokens=2000, + call_count=10, + avg_cost_per_call=0.025, + percentage_of_total_cost=35.5, + ) + + print(f" ✓ Created metrics for {metrics.prompt_type}") + print(f" Total cost: ${metrics.total_cost:.4f}") + print(f" Calls: {metrics.call_count}") + print(f" Avg cost/call: ${metrics.avg_cost_per_call:.4f}") + print(f" % of total: {metrics.percentage_of_total_cost:.1f}%") + + +def test_tracing_metadata_with_prompts(): + """Test TracingMetadata with prompt metrics.""" + print("\nTesting TracingMetadata with prompt metrics...") + + # Create sample prompt metrics + prompt_metrics = [ + PromptTypeMetrics( + prompt_type="query_decomposition", + total_cost=0.45, + input_tokens=3000, + output_tokens=1500, + call_count=3, + avg_cost_per_call=0.15, + percentage_of_total_cost=28.0, + ), + PromptTypeMetrics( + prompt_type="synthesis", + total_cost=0.38, + input_tokens=8000, + output_tokens=4000, + call_count=12, + avg_cost_per_call=0.032, + percentage_of_total_cost=24.0, + ), + PromptTypeMetrics( + prompt_type="search_query", + total_cost=0.25, + input_tokens=5000, + output_tokens=2000, + call_count=45, + avg_cost_per_call=0.006, + percentage_of_total_cost=16.0, + ), + ] + + # Create TracingMetadata + metadata = TracingMetadata( + pipeline_run_name="test-pipeline-run", + pipeline_run_id="test-id-123", + total_cost=1.58, + total_input_tokens=20000, + total_output_tokens=10000, + prompt_metrics=prompt_metrics, + ) + + print( + f" ✓ Created TracingMetadata with {len(metadata.prompt_metrics)} prompt types" + ) + print(f" Total pipeline cost: ${metadata.total_cost:.4f}") + print("\n Prompt breakdown:") + for metric in metadata.prompt_metrics: + print( + f" - {metric.prompt_type}: ${metric.total_cost:.4f} ({metric.percentage_of_total_cost:.1f}%)" + ) + + +def test_visualization_html_generation(): + """Test that visualization HTML can be generated.""" + print("\nTesting HTML visualization generation...") + + from materializers.tracing_metadata_materializer import ( + TracingMetadataMaterializer, + ) + + # Create metadata with prompt metrics + metadata = TracingMetadata( + pipeline_run_name="test-visualization", + pipeline_run_id="test-viz-123", + total_cost=2.50, + total_input_tokens=30000, + total_output_tokens=15000, + total_tokens=45000, + formatted_latency="2m 30.5s", + models_used=["gpt-4", "claude-3"], + prompt_metrics=[ + PromptTypeMetrics( + prompt_type="synthesis", + total_cost=1.20, + input_tokens=15000, + output_tokens=8000, + call_count=10, + avg_cost_per_call=0.12, + percentage_of_total_cost=48.0, + ), + PromptTypeMetrics( + prompt_type="search_query", + total_cost=0.80, + input_tokens=10000, + 
output_tokens=5000, + call_count=50, + avg_cost_per_call=0.016, + percentage_of_total_cost=32.0, + ), + ], + ) + + materializer = TracingMetadataMaterializer(uri="/tmp/test") + html = materializer._generate_visualization_html(metadata) + + # Check that key elements are present + checks = [ + ("Cost Analysis by Prompt Type" in html, "Prompt cost section"), + ("promptCostChart" in html, "Cost chart canvas"), + ("promptTokenChart" in html, "Token chart canvas"), + ("Chart.js" in html, "Chart.js library"), + ("Prompt Type Efficiency" in html, "Efficiency table"), + ("Synthesis" in html, "Synthesis prompt type"), + ("Search Query" in html, "Search Query prompt type"), + ] + + for check, description in checks: + status = "✓" if check else "✗" + print(f" {status} {description}") + + # Save HTML for manual inspection if needed + with open("/tmp/test_visualization.html", "w") as f: + f.write(html) + print( + "\n ℹ️ HTML saved to /tmp/test_visualization.html for manual inspection" + ) + + +def main(): + """Run all tests.""" + print("=== Prompt Cost Visualization Tests ===\n") + + test_identify_prompt_type() + test_prompt_metrics_creation() + test_tracing_metadata_with_prompts() + test_visualization_html_generation() + + print("\n=== All tests completed! ===") + + +if __name__ == "__main__": + main() diff --git a/deep_research/logging_config.py b/deep_research/logging_config.py new file mode 100644 index 00000000..2b93c3e0 --- /dev/null +++ b/deep_research/logging_config.py @@ -0,0 +1,42 @@ +import logging +import sys +from typing import Optional + + +def configure_logging( + level: int = logging.INFO, log_file: Optional[str] = None +): + """Configure logging for the application. + + Args: + level: The log level (default: INFO) + log_file: Optional path to a log file + """ + # Create formatter + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + # Configure root logger + root_logger = logging.getLogger() + root_logger.setLevel(level) + + # Remove existing handlers to avoid duplicate logs + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + root_logger.addHandler(console_handler) + + # File handler if log_file is provided + if log_file: + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + root_logger.addHandler(file_handler) + + # Reduce verbosity for noisy third-party libraries + logging.getLogger("LiteLLM").setLevel(logging.WARNING) + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("urllib3").setLevel(logging.WARNING) diff --git a/deep_research/materializers/__init__.py b/deep_research/materializers/__init__.py new file mode 100644 index 00000000..1479f72b --- /dev/null +++ b/deep_research/materializers/__init__.py @@ -0,0 +1,21 @@ +""" +Materializers package for the ZenML Deep Research project. + +This package contains custom ZenML materializers that handle serialization and +deserialization of complex data types used in the research pipeline, particularly +the ResearchState object that tracks the state of the research process. 
+""" + +from .approval_decision_materializer import ApprovalDecisionMaterializer +from .prompts_materializer import PromptsBundleMaterializer +from .pydantic_materializer import ResearchStateMaterializer +from .reflection_output_materializer import ReflectionOutputMaterializer +from .tracing_metadata_materializer import TracingMetadataMaterializer + +__all__ = [ + "ApprovalDecisionMaterializer", + "PromptsBundleMaterializer", + "ReflectionOutputMaterializer", + "ResearchStateMaterializer", + "TracingMetadataMaterializer", +] diff --git a/deep_research/materializers/approval_decision_materializer.py b/deep_research/materializers/approval_decision_materializer.py new file mode 100644 index 00000000..0a4ed2c3 --- /dev/null +++ b/deep_research/materializers/approval_decision_materializer.py @@ -0,0 +1,281 @@ +"""Materializer for ApprovalDecision with custom visualization.""" + +import os +from datetime import datetime +from typing import Dict + +from utils.pydantic_models import ApprovalDecision +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class ApprovalDecisionMaterializer(PydanticMaterializer): + """Materializer for the ApprovalDecision class with visualizations.""" + + ASSOCIATED_TYPES = (ApprovalDecision,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: ApprovalDecision + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the ApprovalDecision. + + Args: + data: The ApprovalDecision to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "approval_decision.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, decision: ApprovalDecision) -> str: + """Generate HTML visualization for the approval decision. + + Args: + decision: The ApprovalDecision to visualize + + Returns: + HTML string + """ + # Format timestamp + decision_time = datetime.fromtimestamp(decision.timestamp).strftime( + "%Y-%m-%d %H:%M:%S" + ) + + # Determine status color and icon + if decision.approved: + status_color = "#27ae60" + status_icon = "✅" + status_text = "APPROVED" + else: + status_color = "#e74c3c" + status_icon = "❌" + status_text = "NOT APPROVED" + + # Format approval method + method_display = { + "APPROVE_ALL": "Approve All Queries", + "SKIP": "Skip Additional Research", + "SELECT_SPECIFIC": "Select Specific Queries", + }.get(decision.approval_method, decision.approval_method or "Unknown") + + html = f""" + + + + Approval Decision + + + +
+

+ 🔒 Approval Decision +
+ {status_icon} + {status_text} +
+

+ +
+
+
Approval Method
+
{method_display}
+
+
+
Decision Time
+
{decision_time}
+
+
+
Queries Selected
+
{len(decision.selected_queries)}
+
+
+ """ + + # Add selected queries section if any + if decision.selected_queries: + html += """ +
+

📋Selected Queries

+
+ """ + + for i, query in enumerate(decision.selected_queries, 1): + html += f""" +
+
{i}
+
{query}
+
+ """ + + html += """ +
+
+ """ + else: + html += """ +
+

📋Selected Queries

+
+ No queries were selected for additional research +
+
+ """ + + # Add reviewer notes if any + if decision.reviewer_notes: + html += f""" +
+

📝Reviewer Notes

+
+ {decision.reviewer_notes} +
+
+ """ + + # Add timestamp footer + html += f""" +
+ Decision recorded at: {decision_time} +
+
+ + + """ + + return html diff --git a/deep_research/materializers/prompts_materializer.py b/deep_research/materializers/prompts_materializer.py new file mode 100644 index 00000000..96e35dde --- /dev/null +++ b/deep_research/materializers/prompts_materializer.py @@ -0,0 +1,509 @@ +"""Materializer for PromptsBundle with custom HTML visualization. + +This module provides a materializer that creates beautiful HTML visualizations +for prompt bundles in the ZenML dashboard. +""" + +import os +from typing import Dict + +from utils.prompt_models import PromptsBundle +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class PromptsBundleMaterializer(PydanticMaterializer): + """Materializer for PromptsBundle with custom visualization.""" + + ASSOCIATED_TYPES = (PromptsBundle,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: PromptsBundle + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the PromptsBundle. + + Args: + data: The PromptsBundle to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "prompts_bundle.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, bundle: PromptsBundle) -> str: + """Generate HTML visualization for the prompts bundle. + + Args: + bundle: The PromptsBundle to visualize + + Returns: + HTML string + """ + # Create HTML content + html = f""" + + + + Prompts Bundle - {bundle.created_at} + + + +
+
+

🎯 Prompts Bundle

+ +
+
+ {len(bundle.list_all_prompts())} + Total Prompts +
+
+ {len([p for p in bundle.list_all_prompts().values() if p.tags])} + Tagged Prompts +
+
+ {len(bundle.custom_prompts)} + Custom Prompts +
+
+
+ + + +
+ """ + + # Add each prompt + prompts = [ + ( + "search_query_prompt", + bundle.search_query_prompt, + ["search", "query"], + ), + ( + "query_decomposition_prompt", + bundle.query_decomposition_prompt, + ["analysis", "decomposition"], + ), + ( + "synthesis_prompt", + bundle.synthesis_prompt, + ["synthesis", "integration"], + ), + ( + "viewpoint_analysis_prompt", + bundle.viewpoint_analysis_prompt, + ["analysis", "viewpoint"], + ), + ( + "reflection_prompt", + bundle.reflection_prompt, + ["reflection", "critique"], + ), + ( + "additional_synthesis_prompt", + bundle.additional_synthesis_prompt, + ["synthesis", "enhancement"], + ), + ( + "conclusion_generation_prompt", + bundle.conclusion_generation_prompt, + ["report", "conclusion"], + ), + ] + + for prompt_type, prompt, default_tags in prompts: + # Use provided tags or default tags + tags = prompt.tags if prompt.tags else default_tags + + html += f""" +
+
+

{prompt.name}

+ v{prompt.version} +
+ {f'

{prompt.description}

' if prompt.description else ""} +
+ {"".join([f'{tag}' for tag in tags])} +
+
+ + {self._escape_html(prompt.content)} +
+ +
+ """ + + # Add custom prompts if any + for name, prompt in bundle.custom_prompts.items(): + tags = prompt.tags if prompt.tags else ["custom"] + html += f""" +
+
+

{prompt.name}

+ v{prompt.version} +
+ {f'

{prompt.description}

' if prompt.description else ""} +
+ {"".join([f'{tag}' for tag in tags])} +
+
+ + {self._escape_html(prompt.content)} +
+ +
+ """ + + html += """ +
+ +
+ + + + + """ + + return html + + def _escape_html(self, text: str) -> str: + """Escape HTML special characters. + + Args: + text: Text to escape + + Returns: + Escaped text + """ + return ( + text.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'") + ) diff --git a/deep_research/materializers/pydantic_materializer.py b/deep_research/materializers/pydantic_materializer.py new file mode 100644 index 00000000..ee01281b --- /dev/null +++ b/deep_research/materializers/pydantic_materializer.py @@ -0,0 +1,764 @@ +"""Pydantic materializer for research state objects. + +This module contains an extended version of ZenML's PydanticMaterializer +that adds visualization capabilities for the ResearchState model. +""" + +import os +from typing import Dict + +from utils.pydantic_models import ResearchState +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class ResearchStateMaterializer(PydanticMaterializer): + """Materializer for the ResearchState class with visualizations.""" + + ASSOCIATED_TYPES = (ResearchState,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: ResearchState + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the ResearchState. + + Args: + data: The ResearchState to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "research_state.html") + + # Create HTML content based on current stage + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, state: ResearchState) -> str: + """Generate HTML visualization for the research state. + + Args: + state: The ResearchState to visualize + + Returns: + HTML string + """ + # Base structure for the HTML + html = f""" + + + + Research State: {state.main_query} + + + + +
+

Research State

+ + +
+
Initial Query
+
Query Decomposition
+
Information Gathering
+
Information Synthesis
+
Viewpoint Analysis
+
Reflection & Enhancement
+
Final Report
+
+ + +
+
+
+ + +
    + """ + + # Determine which tab should be active based on current stage + current_stage = state.get_current_stage() + + # Map stages to tabs + stage_to_tab = { + "empty": "overview", + "initial": "overview", + "after_query_decomposition": "sub-questions", + "after_search": "search-results", + "after_synthesis": "synthesis", + "after_viewpoint_analysis": "viewpoints", + "after_reflection": "reflection", + "final_report": "final-report", + } + + # Get the default active tab based on stage + default_active_tab = stage_to_tab.get(current_stage, "overview") + + # Create tab headers dynamically based on available data + tabs_created = [] + + # Overview tab is always shown + is_active = default_active_tab == "overview" + html += f'
  • Overview
  • ' + tabs_created.append("overview") + + if state.sub_questions: + is_active = default_active_tab == "sub-questions" + html += f'
  • Sub-Questions
  • ' + tabs_created.append("sub-questions") + + if state.search_results: + is_active = default_active_tab == "search-results" + html += f'
  • Search Results
  • ' + tabs_created.append("search-results") + + if state.synthesized_info: + is_active = default_active_tab == "synthesis" + html += f'
  • Synthesis
  • ' + tabs_created.append("synthesis") + + if state.viewpoint_analysis: + is_active = default_active_tab == "viewpoints" + html += f'
  • Viewpoints
  • ' + tabs_created.append("viewpoints") + + if state.enhanced_info or state.reflection_metadata: + is_active = default_active_tab == "reflection" + html += f'
  • Reflection
  • ' + tabs_created.append("reflection") + + if state.final_report_html: + is_active = default_active_tab == "final-report" + html += f'
  • Final Report
  • ' + tabs_created.append("final-report") + + # Ensure the active tab actually exists in the created tabs + # If not, fallback to the first available tab + if default_active_tab not in tabs_created and tabs_created: + default_active_tab = tabs_created[0] + + html += """ +
+ + + """ + + # Overview tab content (always shown) + is_active = default_active_tab == "overview" + html += f""" +
+
+

Main Query

+
+ """ + + if state.main_query: + html += f"

{state.main_query}

" + else: + html += "

No main query specified

" + + html += """ +
+
+
+ """ + + # Sub-questions tab content + if state.sub_questions: + is_active = default_active_tab == "sub-questions" + html += f""" +
+
+

Sub-Questions ({len(state.sub_questions)})

+
+ """ + + for i, question in enumerate(state.sub_questions): + html += f""" +
+ {i + 1}. {question} +
+ """ + + html += """ +
+
+
+ """ + + # Search results tab content + if state.search_results: + is_active = default_active_tab == "search-results" + html += f""" +
+
+

Search Results

+ """ + + for question, results in state.search_results.items(): + html += f""" +

{question}

+

Found {len(results)} results

+
    + """ + + for result in results: + # Extract domain from URL or use special handling for generated content + if result.url == "tavily-generated-answer": + domain = "Tavily" + else: + domain = "" + try: + from urllib.parse import urlparse + + parsed_url = urlparse(result.url) + domain = parsed_url.netloc + # Strip www. prefix to save space + if domain.startswith("www."): + domain = domain[4:] + except: + domain = ( + result.url.split("/")[2] + if len(result.url.split("/")) > 2 + else "" + ) + # Strip www. prefix to save space + if domain.startswith("www."): + domain = domain[4:] + + html += f""" +
  • + {result.title} ({domain}) +
  • + """ + + html += """ +
+ """ + + html += """ +
+
+ """ + + # Synthesized information tab content + if state.synthesized_info: + is_active = default_active_tab == "synthesis" + html += f""" +
+
+

Synthesized Information

+ """ + + for question, info in state.synthesized_info.items(): + html += f""" +

{question} {info.confidence_level}

+
+

{info.synthesized_answer}

+ """ + + if info.key_sources: + html += """ +
+

Key Sources:

+
    + """ + + for source in info.key_sources[:3]: + html += f""" +
  • {source[:50]}...
  • + """ + + if len(info.key_sources) > 3: + html += f"
  • ...and {len(info.key_sources) - 3} more sources
  • " + + html += """ +
+
+ """ + + if info.information_gaps: + html += f""" + + """ + + html += """ +
+ """ + + html += """ +
+
+ """ + + # Viewpoint analysis tab content + if state.viewpoint_analysis: + is_active = default_active_tab == "viewpoints" + html += f""" +
+
+

Viewpoint Analysis

+
+ """ + + # Points of agreement + if state.viewpoint_analysis.main_points_of_agreement: + html += """ +

Points of Agreement

+
    + """ + + for point in state.viewpoint_analysis.main_points_of_agreement: + html += f""" +
  • {point}
  • + """ + + html += """ +
+ """ + + # Areas of tension + if state.viewpoint_analysis.areas_of_tension: + html += """ +

Areas of Tension

+ """ + + for tension in state.viewpoint_analysis.areas_of_tension: + html += f""" +
+

{tension.topic}

+
    + """ + + for viewpoint, description in tension.viewpoints.items(): + html += f""" +
  • {viewpoint}: {description}
  • + """ + + html += """ +
+
+ """ + + # Perspective gaps and integrative insights + if state.viewpoint_analysis.perspective_gaps: + html += f""" +

Perspective Gaps

+

{state.viewpoint_analysis.perspective_gaps}

+ """ + + if state.viewpoint_analysis.integrative_insights: + html += f""" +

Integrative Insights

+

{state.viewpoint_analysis.integrative_insights}

+ """ + + html += """ +
+
+
+ """ + + # Reflection & Enhancement tab content + if state.enhanced_info or state.reflection_metadata: + is_active = default_active_tab == "reflection" + html += f""" +
+
+

Reflection & Enhancement

+ """ + + # Reflection metadata + if state.reflection_metadata: + html += """ +
+ """ + + if state.reflection_metadata.critique_summary: + html += """ +

Critique Summary

+
    + """ + + for critique in state.reflection_metadata.critique_summary: + html += f""" +
  • {critique}
  • + """ + + html += """ +
+ """ + + if state.reflection_metadata.additional_questions_identified: + html += """ +

Additional Questions Identified

+
    + """ + + for question in state.reflection_metadata.additional_questions_identified: + html += f""" +
  • {question}
  • + """ + + html += """ +
+ """ + + html += f""" + +
+ """ + + # Enhanced information + if state.enhanced_info: + html += """ +

Enhanced Information

+ """ + + for question, info in state.enhanced_info.items(): + # Show only for questions with improvements + if info.improvements: + html += f""" +
+

{question} {info.confidence_level}

+ +
+

Improvements Made:

+
    + """ + + for improvement in info.improvements: + html += f""" +
  • {improvement}
  • + """ + + html += """ +
+
+
+ """ + + html += """ +
+
+ """ + + # Final report tab + if state.final_report_html: + is_active = default_active_tab == "final-report" + html += f""" +
+
+

Final Report

+

Final HTML report is available but not displayed here. View the HTML artifact to see the complete report.

+
+
+ """ + + # Close HTML tags + html += """ +
+ + + """ + + return html + + def _get_stage_class(self, state: ResearchState, stage: str) -> str: + """Get CSS class for a stage based on current progress. + + Args: + state: ResearchState object + stage: Stage name + + Returns: + CSS class string + """ + current_stage = state.get_current_stage() + + # These are the stages in order + stages = [ + "empty", + "initial", + "after_query_decomposition", + "after_search", + "after_synthesis", + "after_viewpoint_analysis", + "after_reflection", + "final_report", + ] + + current_index = ( + stages.index(current_stage) if current_stage in stages else 0 + ) + stage_index = stages.index(stage) if stage in stages else 0 + + if stage_index == current_index: + return "active" + elif stage_index < current_index: + return "completed" + else: + return "" + + def _calculate_progress(self, state: ResearchState) -> int: + """Calculate overall progress percentage. + + Args: + state: ResearchState object + + Returns: + Progress percentage (0-100) + """ + # Map stages to progress percentages + stage_percentages = { + "empty": 0, + "initial": 5, + "after_query_decomposition": 20, + "after_search": 40, + "after_synthesis": 60, + "after_viewpoint_analysis": 75, + "after_reflection": 90, + "final_report": 100, + } + + current_stage = state.get_current_stage() + return stage_percentages.get(current_stage, 0) diff --git a/deep_research/materializers/reflection_output_materializer.py b/deep_research/materializers/reflection_output_materializer.py new file mode 100644 index 00000000..1e8b37ae --- /dev/null +++ b/deep_research/materializers/reflection_output_materializer.py @@ -0,0 +1,279 @@ +"""Materializer for ReflectionOutput with custom visualization.""" + +import os +from typing import Dict + +from utils.pydantic_models import ReflectionOutput +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class ReflectionOutputMaterializer(PydanticMaterializer): + """Materializer for the ReflectionOutput class with visualizations.""" + + ASSOCIATED_TYPES = (ReflectionOutput,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: ReflectionOutput + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the ReflectionOutput. + + Args: + data: The ReflectionOutput to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "reflection_output.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, output: ReflectionOutput) -> str: + """Generate HTML visualization for the reflection output. + + Args: + output: The ReflectionOutput to visualize + + Returns: + HTML string + """ + html = f""" + + + + Reflection Output + + + +
+

🔍 Reflection & Analysis Output

+ + + +
+

+ 📝Critique Summary + {} +

+ """.format(len(output.critique_summary)) + + if output.critique_summary: + for critique in output.critique_summary: + html += """ +
+ """ + + # Handle different critique formats + if isinstance(critique, dict): + for key, value in critique.items(): + html += f""" +
{key}:
+
{value}
+ """ + else: + html += f""" +
{critique}
+ """ + + html += """ +
+ """ + else: + html += """ +

No critique summary available

+ """ + + html += """ +
+ +
+

+ Additional Questions Identified + {} +

+ """.format(len(output.additional_questions)) + + if output.additional_questions: + for question in output.additional_questions: + html += f""" +
+ {question} +
+ """ + else: + html += """ +

No additional questions identified

+ """ + + html += """ +
+ +
+

📊Research State Summary

+
+

Main Query: {}

+

Current Stage: {}

+

Sub-questions: {}

+

Search Results: {} queries with results

+

Synthesized Info: {} topics synthesized

+
+
+ """.format( + output.state.main_query, + output.state.get_current_stage().replace("_", " ").title(), + len(output.state.sub_questions), + len(output.state.search_results), + len(output.state.synthesized_info), + ) + + # Add metadata section + html += """ + +
+ + + """ + + return html diff --git a/deep_research/materializers/tracing_metadata_materializer.py b/deep_research/materializers/tracing_metadata_materializer.py new file mode 100644 index 00000000..7cf7b51b --- /dev/null +++ b/deep_research/materializers/tracing_metadata_materializer.py @@ -0,0 +1,603 @@ +"""Materializer for TracingMetadata with custom visualization.""" + +import os +from typing import Dict + +from utils.pydantic_models import TracingMetadata +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class TracingMetadataMaterializer(PydanticMaterializer): + """Materializer for the TracingMetadata class with visualizations.""" + + ASSOCIATED_TYPES = (TracingMetadata,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: TracingMetadata + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the TracingMetadata. + + Args: + data: The TracingMetadata to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "tracing_metadata.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, metadata: TracingMetadata) -> str: + """Generate HTML visualization for the tracing metadata. + + Args: + metadata: The TracingMetadata to visualize + + Returns: + HTML string + """ + # Calculate some derived values + avg_cost_per_token = metadata.total_cost / max( + metadata.total_tokens, 1 + ) + + # Base structure for the HTML + html = f""" + + + + Pipeline Tracing Metadata + + + +
+

Pipeline Tracing Metadata

+ +
+
+
Pipeline Run
+
{metadata.pipeline_run_name}
+
+ +
+
LLM Cost
+
${metadata.total_cost:.4f}
+
+ +
+
Total Tokens
+
{metadata.total_tokens:,}
+
+ +
+
Duration
+
{metadata.formatted_latency}
+
+
+ +

Token Usage

+
+
+
Input Tokens
+
{metadata.total_input_tokens:,}
+
+ +
+
Output Tokens
+
{metadata.total_output_tokens:,}
+
+ +
+
Observations
+
{metadata.observation_count}
+
+ +
+
Avg Cost per Token
+
${avg_cost_per_token:.6f}
+
+
+ +

Model Usage Breakdown

+ + + + + + + + + + + + """ + + # Add model breakdown + for model in metadata.models_used: + tokens = metadata.model_token_breakdown.get(model, {}) + cost = metadata.cost_breakdown_by_model.get(model, 0.0) + html += f""" + + + + + + + + """ + + html += """ + +
ModelInput TokensOutput TokensTotal TokensCost
{model}{tokens.get("input_tokens", 0):,}{tokens.get("output_tokens", 0):,}{tokens.get("total_tokens", 0):,}${cost:.4f}
+ """ + + # Add prompt-level metrics visualization + if metadata.prompt_metrics: + html += """ +

Cost Analysis by Prompt Type

+ + + + + +
+ +
+ + +
+ +
+ + +

Prompt Type Efficiency

+ + + + + + + + + + + + + + """ + + # Add prompt metrics rows + for metric in metadata.prompt_metrics: + # Format prompt type name nicely + prompt_type_display = metric.prompt_type.replace( + "_", " " + ).title() + html += f""" + + + + + + + + + + """ + + html += """ + +
Prompt TypeTotal CostCallsAvg $/Call% of TotalInput TokensOutput Tokens
{prompt_type_display}${metric.total_cost:.4f}{metric.call_count}${metric.avg_cost_per_call:.4f}{metric.percentage_of_total_cost:.1f}%{metric.input_tokens:,}{metric.output_tokens:,}
+ + + """ + + # Add search cost visualization if available + if metadata.search_costs and any(metadata.search_costs.values()): + total_search_cost = sum(metadata.search_costs.values()) + total_combined_cost = metadata.total_cost + total_search_cost + + html += """ +

Search Provider Costs

+
+ """ + + for provider, cost in metadata.search_costs.items(): + if cost > 0: + query_count = metadata.search_queries_count.get( + provider, 0 + ) + avg_cost_per_query = ( + cost / query_count if query_count > 0 else 0 + ) + html += f""" +
+
{provider.upper()} Search
+
${cost:.4f}
+
+ {query_count} queries • ${avg_cost_per_query:.4f}/query +
+
+ """ + + html += ( + f""" +
+
Total Search Cost
+
${total_search_cost:.4f}
+
+ {sum(metadata.search_queries_count.values())} total queries +
+
+
+ +

Combined Cost Summary

+
+
+
LLM Cost
+
${metadata.total_cost:.4f}
+
+ {(metadata.total_cost / total_combined_cost * 100):.1f}% of total +
+
+
+
Search Cost
+
${total_search_cost:.4f}
+
+ {(total_search_cost / total_combined_cost * 100):.1f}% of total +
+
+
+
Total Pipeline Cost
+
${total_combined_cost:.4f}
+
+
+ +

Cost Breakdown Chart

+ + + """ + ) + + # Add trace metadata + if metadata.trace_tags or metadata.trace_metadata: + html += """ +

Trace Information

+
+ """ + + if metadata.trace_tags: + html += """ +

Tags

+
+ """ + for tag in metadata.trace_tags: + html += f'{tag}' + html += """ +
+ """ + + if metadata.trace_metadata: + html += """ +

Metadata

+
+                """
+                import json
+
+                html += json.dumps(metadata.trace_metadata, indent=2)
+                html += """
+                    
+ """ + + html += """ +
+ """ + + # Add footer with collection info + from datetime import datetime + + collection_time = datetime.fromtimestamp( + metadata.collected_at + ).strftime("%Y-%m-%d %H:%M:%S") + + html += f""" + +
+ + + """ + + return html diff --git a/deep_research/pipelines/__init__.py b/deep_research/pipelines/__init__.py new file mode 100644 index 00000000..7f4ea5eb --- /dev/null +++ b/deep_research/pipelines/__init__.py @@ -0,0 +1,11 @@ +""" +Pipelines package for the ZenML Deep Research project. + +This package contains the ZenML pipeline definitions for running deep research +workflows. Each pipeline orchestrates a sequence of steps for comprehensive +research on a given query topic. +""" + +from .parallel_research_pipeline import parallelized_deep_research_pipeline + +__all__ = ["parallelized_deep_research_pipeline"] diff --git a/deep_research/pipelines/parallel_research_pipeline.py b/deep_research/pipelines/parallel_research_pipeline.py new file mode 100644 index 00000000..bd7afe14 --- /dev/null +++ b/deep_research/pipelines/parallel_research_pipeline.py @@ -0,0 +1,133 @@ +from steps.approval_step import get_research_approval_step +from steps.collect_tracing_metadata_step import collect_tracing_metadata_step +from steps.cross_viewpoint_step import cross_viewpoint_analysis_step +from steps.execute_approved_searches_step import execute_approved_searches_step +from steps.generate_reflection_step import generate_reflection_step +from steps.initialize_prompts_step import initialize_prompts_step +from steps.merge_results_step import merge_sub_question_results_step +from steps.process_sub_question_step import process_sub_question_step +from steps.pydantic_final_report_step import pydantic_final_report_step +from steps.query_decomposition_step import initial_query_decomposition_step +from utils.pydantic_models import ResearchState +from zenml import pipeline + + +@pipeline(enable_cache=False) +def parallelized_deep_research_pipeline( + query: str = "What is ZenML?", + max_sub_questions: int = 10, + require_approval: bool = False, + approval_timeout: int = 3600, + max_additional_searches: int = 2, + search_provider: str = "tavily", + search_mode: str = "auto", + num_results_per_search: int = 3, + langfuse_project_name: str = "deep-research", +) -> None: + """Parallelized ZenML pipeline for deep research on a given query. + + This pipeline uses the fan-out/fan-in pattern for parallel processing of sub-questions, + potentially improving execution time when using distributed orchestrators. 
+ + Args: + query: The research query/topic + max_sub_questions: Maximum number of sub-questions to process in parallel + require_approval: Whether to require human approval for additional searches + approval_timeout: Timeout in seconds for human approval + max_additional_searches: Maximum number of additional searches to perform + search_provider: Search provider to use (tavily, exa, or both) + search_mode: Search mode for Exa provider (neural, keyword, or auto) + num_results_per_search: Number of search results to return per query + langfuse_project_name: Langfuse project name for LLM tracking + + Returns: + Formatted research report as HTML + """ + # Initialize prompts bundle for tracking + prompts_bundle = initialize_prompts_step(pipeline_version="1.0.0") + + # Initialize the research state with the main query + state = ResearchState(main_query=query) + + # Step 1: Decompose the query into sub-questions, limiting to max_sub_questions + decomposed_state = initial_query_decomposition_step( + state=state, + prompts_bundle=prompts_bundle, + max_sub_questions=max_sub_questions, + langfuse_project_name=langfuse_project_name, + ) + + # Fan out: Process each sub-question in parallel + # Collect artifacts to establish dependencies for the merge step + after = [] + for i in range(max_sub_questions): + # Process the i-th sub-question (if it exists) + sub_state = process_sub_question_step( + state=decomposed_state, + prompts_bundle=prompts_bundle, + question_index=i, + search_provider=search_provider, + search_mode=search_mode, + num_results_per_search=num_results_per_search, + langfuse_project_name=langfuse_project_name, + id=f"process_question_{i + 1}", + ) + after.append(sub_state) + + # Fan in: Merge results from all parallel processing + # The 'after' parameter ensures this step runs after all processing steps + # It doesn't directly use the processed_states input + merged_state = merge_sub_question_results_step( + original_state=decomposed_state, + step_prefix="process_question_", + output_name="output", + after=after, # This creates the dependency + ) + + # Continue with subsequent steps + analyzed_state = cross_viewpoint_analysis_step( + state=merged_state, + prompts_bundle=prompts_bundle, + langfuse_project_name=langfuse_project_name, + ) + + # New 3-step reflection flow with optional human approval + # Step 1: Generate reflection and recommendations (no searches yet) + reflection_output = generate_reflection_step( + state=analyzed_state, + prompts_bundle=prompts_bundle, + langfuse_project_name=langfuse_project_name, + ) + + # Step 2: Get approval for recommended searches + approval_decision = get_research_approval_step( + reflection_output=reflection_output, + require_approval=require_approval, + timeout=approval_timeout, + max_queries=max_additional_searches, + ) + + # Step 3: Execute approved searches (if any) + reflected_state = execute_approved_searches_step( + reflection_output=reflection_output, + approval_decision=approval_decision, + prompts_bundle=prompts_bundle, + search_provider=search_provider, + search_mode=search_mode, + num_results_per_search=num_results_per_search, + langfuse_project_name=langfuse_project_name, + ) + + # Use our new Pydantic-based final report step + # This returns a tuple (state, html_report) + final_state, final_report = pydantic_final_report_step( + state=reflected_state, + prompts_bundle=prompts_bundle, + langfuse_project_name=langfuse_project_name, + ) + + # Collect tracing metadata for the entire pipeline run + _, tracing_metadata = 
collect_tracing_metadata_step( + state=final_state, + langfuse_project_name=langfuse_project_name, + ) diff --git a/deep_research/requirements.txt b/deep_research/requirements.txt new file mode 100644 index 00000000..6d2f9166 --- /dev/null +++ b/deep_research/requirements.txt @@ -0,0 +1,10 @@ +zenml>=0.82.0 +litellm>=1.70.0,<2.0.0 +tavily-python>=0.2.8 +exa-py>=1.0.0 +PyYAML>=6.0 +click>=8.0.0 +pydantic>=2.0.0 +typing_extensions>=4.0.0 +requests +langfuse>=2.0.0 diff --git a/deep_research/run.py b/deep_research/run.py new file mode 100644 index 00000000..aa1745c3 --- /dev/null +++ b/deep_research/run.py @@ -0,0 +1,330 @@ +import logging +import os + +import click +import yaml +from logging_config import configure_logging +from pipelines.parallel_research_pipeline import ( + parallelized_deep_research_pipeline, +) +from utils.helper_functions import check_required_env_vars + +logger = logging.getLogger(__name__) + + +# Research mode presets for easy configuration +RESEARCH_MODES = { + "rapid": { + "max_sub_questions": 5, + "num_results_per_search": 2, + "max_additional_searches": 0, + "description": "Quick research with minimal depth - great for getting a fast overview", + }, + "balanced": { + "max_sub_questions": 10, + "num_results_per_search": 3, + "max_additional_searches": 2, + "description": "Balanced research with moderate depth - ideal for most use cases", + }, + "deep": { + "max_sub_questions": 15, + "num_results_per_search": 5, + "max_additional_searches": 4, + "description": "Comprehensive research with maximum depth - for thorough analysis", + "suggest_approval": True, # Suggest using approval for deep mode + }, +} + + +@click.command( + help=""" +Deep Research Agent - ZenML Pipeline for Comprehensive Research + +Run a deep research pipeline that: +1. Generates a structured report outline +2. Researches each topic with web searches and LLM analysis +3. Refines content through multiple reflection cycles +4. 
Produces a formatted, comprehensive research report + +Examples: + + \b + # Run with default configuration + python run.py + + \b + # Use a research mode preset for easy configuration + python run.py --mode rapid # Quick overview + python run.py --mode balanced # Standard research (default) + python run.py --mode deep # Comprehensive analysis + + \b + # Run with a custom pipeline configuration file + python run.py --config configs/custom_pipeline.yaml + + \b + # Override the research query + python run.py --query "My research topic" + + \b + # Combine mode with other options + python run.py --mode deep --query "Complex topic" --require-approval + + \b + # Run with a custom number of sub-questions + python run.py --max-sub-questions 15 +""" +) +@click.option( + "--mode", + type=click.Choice(["rapid", "balanced", "deep"], case_sensitive=False), + default=None, + help="Research mode preset: rapid (fast overview), balanced (standard), or deep (comprehensive)", +) +@click.option( + "--config", + type=str, + default="configs/enhanced_research.yaml", + help="Path to the pipeline configuration YAML file", +) +@click.option( + "--no-cache", + is_flag=True, + default=False, + help="Disable caching for the pipeline run", +) +@click.option( + "--log-file", + type=str, + default=None, + help="Path to log file (if not provided, logs only go to console)", +) +@click.option( + "--debug", + is_flag=True, + default=False, + help="Enable debug logging", +) +@click.option( + "--query", + type=str, + default=None, + help="Research query (overrides the query in the config file)", +) +@click.option( + "--max-sub-questions", + type=int, + default=10, + help="Maximum number of sub-questions to process in parallel", +) +@click.option( + "--require-approval", + is_flag=True, + default=False, + help="Enable human-in-the-loop approval for additional searches", +) +@click.option( + "--approval-timeout", + type=int, + default=3600, + help="Timeout in seconds for human approval (default: 3600)", +) +@click.option( + "--search-provider", + type=click.Choice(["tavily", "exa", "both"], case_sensitive=False), + default=None, + help="Search provider to use: tavily (default), exa, or both", +) +@click.option( + "--search-mode", + type=click.Choice(["neural", "keyword", "auto"], case_sensitive=False), + default="auto", + help="Search mode for Exa provider: neural, keyword, or auto (default: auto)", +) +@click.option( + "--num-results", + type=int, + default=3, + help="Number of search results to return per query (default: 3)", +) +def main( + mode: str = None, + config: str = "configs/enhanced_research.yaml", + no_cache: bool = False, + log_file: str = None, + debug: bool = False, + query: str = None, + max_sub_questions: int = 10, + require_approval: bool = False, + approval_timeout: int = 3600, + search_provider: str = None, + search_mode: str = "auto", + num_results: int = 3, +): + """Run the deep research pipeline. 
+ + Args: + mode: Research mode preset (rapid, balanced, or deep) + config: Path to the pipeline configuration YAML file + no_cache: Disable caching for the pipeline run + log_file: Path to log file + debug: Enable debug logging + query: Research query (overrides the query in the config file) + max_sub_questions: Maximum number of sub-questions to process in parallel + require_approval: Enable human-in-the-loop approval for additional searches + approval_timeout: Timeout in seconds for human approval + search_provider: Search provider to use (tavily, exa, or both) + search_mode: Search mode for Exa provider (neural, keyword, or auto) + num_results: Number of search results to return per query + """ + # Configure logging + log_level = logging.DEBUG if debug else logging.INFO + configure_logging(level=log_level, log_file=log_file) + + # Apply mode presets if specified + if mode: + mode_config = RESEARCH_MODES[mode.lower()] + logger.info(f"\n{'=' * 80}") + logger.info(f"Using research mode: {mode.upper()}") + logger.info(f"Description: {mode_config['description']}") + + # Apply mode parameters (can be overridden by explicit arguments) + if max_sub_questions == 10: # Default value - apply mode preset + max_sub_questions = mode_config["max_sub_questions"] + logger.info(f" - Max sub-questions: {max_sub_questions}") + + # Store mode config for later use + mode_max_additional_searches = mode_config["max_additional_searches"] + + # Use mode's num_results_per_search only if user didn't override with --num-results + if num_results == 3: # Default value - apply mode preset + num_results = mode_config["num_results_per_search"] + + logger.info( + f" - Max additional searches: {mode_max_additional_searches}" + ) + logger.info(f" - Results per search: {num_results}") + + # Check if a mode-specific config exists and user didn't override config + if config == "configs/enhanced_research.yaml": # Default config + mode_specific_config = f"configs/{mode.lower()}_research.yaml" + if os.path.exists(mode_specific_config): + config = mode_specific_config + logger.info(f" - Using mode-specific config: {config}") + + # Suggest approval for deep mode if not already enabled + if mode_config.get("suggest_approval") and not require_approval: + logger.info(f"\n{'!' * 60}") + logger.info( + f"! TIP: Consider using --require-approval with {mode} mode" + ) + logger.info(f"! for better control over comprehensive research") + logger.info(f"{'!' 
* 60}") + + logger.info(f"{'=' * 80}\n") + else: + # Default values if no mode specified + mode_max_additional_searches = 2 + + # Check that required environment variables are present using the helper function + required_vars = ["SAMBANOVA_API_KEY"] + + # Add provider-specific API key requirements + if search_provider in {"exa", "both"}: + required_vars.append("EXA_API_KEY") + if search_provider in {"tavily", "both", None}: # Default is tavily + required_vars.append("TAVILY_API_KEY") + + if missing_vars := check_required_env_vars(required_vars): + logger.error( + f"The following required environment variables are not set: {', '.join(missing_vars)}" + ) + logger.info("Please set them with:") + for var in missing_vars: + logger.info(f" export {var}=your_{var.lower()}_here") + return + + # Set pipeline options + pipeline_options = {"config_path": config} + + if no_cache: + pipeline_options["enable_cache"] = False + + logger.info("\n" + "=" * 80) + logger.info("Starting Deep Research") + logger.info("Using parallel pipeline for efficient execution") + + # Log search provider settings + if search_provider: + logger.info(f"Search provider: {search_provider.upper()}") + if search_provider == "exa": + logger.info(f" - Search mode: {search_mode}") + elif search_provider == "both": + logger.info(f" - Running both Tavily and Exa searches") + logger.info(f" - Exa search mode: {search_mode}") + else: + logger.info("Search provider: TAVILY (default)") + + # Log num_results if custom value or no mode preset + if num_results != 3 or not mode: + logger.info(f"Results per search: {num_results}") + + langfuse_project_name = "deep-research" # default + try: + with open(config, "r") as f: + config_data = yaml.safe_load(f) + langfuse_project_name = config_data.get( + "langfuse_project_name", "deep-research" + ) + except Exception as e: + logger.warning( + f"Could not load langfuse_project_name from config: {e}" + ) + + # Set up the pipeline with the parallelized version as default + pipeline = parallelized_deep_research_pipeline.with_options( + **pipeline_options + ) + + # Execute the pipeline + if query: + logger.info( + f"Using query: {query} with max {max_sub_questions} parallel sub-questions" + ) + if require_approval: + logger.info( + f"Human approval enabled with {approval_timeout}s timeout" + ) + pipeline( + query=query, + max_sub_questions=max_sub_questions, + require_approval=require_approval, + approval_timeout=approval_timeout, + max_additional_searches=mode_max_additional_searches, + search_provider=search_provider or "tavily", + search_mode=search_mode, + num_results_per_search=num_results, + langfuse_project_name=langfuse_project_name, + ) + else: + logger.info( + f"Using query from config file with max {max_sub_questions} parallel sub-questions" + ) + if require_approval: + logger.info( + f"Human approval enabled with {approval_timeout}s timeout" + ) + pipeline( + max_sub_questions=max_sub_questions, + require_approval=require_approval, + approval_timeout=approval_timeout, + max_additional_searches=mode_max_additional_searches, + search_provider=search_provider or "tavily", + search_mode=search_mode, + num_results_per_search=num_results, + langfuse_project_name=langfuse_project_name, + ) + + +if __name__ == "__main__": + main() diff --git a/deep_research/steps/__init__.py b/deep_research/steps/__init__.py new file mode 100644 index 00000000..1d454e49 --- /dev/null +++ b/deep_research/steps/__init__.py @@ -0,0 +1,7 @@ +""" +Steps package for the ZenML Deep Research project. 
+ +This package contains individual ZenML steps used in the research pipelines. +Each step is responsible for a specific part of the research process, such as +query decomposition, searching, synthesis, and report generation. +""" diff --git a/deep_research/steps/approval_step.py b/deep_research/steps/approval_step.py new file mode 100644 index 00000000..97a283a1 --- /dev/null +++ b/deep_research/steps/approval_step.py @@ -0,0 +1,308 @@ +import logging +import time +from typing import Annotated + +from materializers.approval_decision_materializer import ( + ApprovalDecisionMaterializer, +) +from utils.approval_utils import ( + format_approval_request, + summarize_research_progress, +) +from utils.pydantic_models import ApprovalDecision, ReflectionOutput +from zenml import log_metadata, step +from zenml.client import Client + +logger = logging.getLogger(__name__) + + +@step( + enable_cache=False, output_materializers=ApprovalDecisionMaterializer +) # Never cache approval decisions +def get_research_approval_step( + reflection_output: ReflectionOutput, + require_approval: bool = True, + alerter_type: str = "slack", + timeout: int = 3600, + max_queries: int = 2, +) -> Annotated[ApprovalDecision, "approval_decision"]: + """ + Get human approval for additional research queries. + + Always returns an ApprovalDecision object. If require_approval is False, + automatically approves all queries. + + Args: + reflection_output: Output from the reflection generation step + require_approval: Whether to require human approval + alerter_type: Type of alerter to use (slack, email, etc.) + timeout: Timeout in seconds for approval response + max_queries: Maximum number of queries to approve + + Returns: + ApprovalDecision object with approval status and selected queries + """ + start_time = time.time() + + # Limit queries to max_queries + limited_queries = reflection_output.recommended_queries[:max_queries] + + # If approval not required, auto-approve all + if not require_approval: + logger.info( + f"Auto-approving {len(limited_queries)} recommended queries (approval not required)" + ) + + # Log metadata for auto-approval + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": False, + "approval_method": "AUTO_APPROVED", + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": len(limited_queries), + "max_queries_allowed": max_queries, + "approval_status": "approved", + "wait_time_seconds": 0, + } + } + ) + + return ApprovalDecision( + approved=True, + selected_queries=limited_queries, + approval_method="AUTO_APPROVED", + reviewer_notes="Approval not required by configuration", + ) + + # If no queries to approve, skip + if not limited_queries: + logger.info("No additional queries recommended") + + # Log metadata for no queries + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "NO_QUERIES", + "num_queries_recommended": 0, + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "skipped", + "wait_time_seconds": 0, + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="NO_QUERIES", + reviewer_notes="No additional queries recommended", + ) + + # Prepare approval request + progress_summary = 
summarize_research_progress(reflection_output.state) + message = format_approval_request( + main_query=reflection_output.state.main_query, + progress_summary=progress_summary, + critique_points=reflection_output.critique_summary, + proposed_queries=limited_queries, + timeout=timeout, + ) + + # Log the approval request for visibility + logger.info("=" * 80) + logger.info("APPROVAL REQUEST:") + logger.info(message) + logger.info("=" * 80) + + try: + # Get the alerter from the active stack + client = Client() + alerter = client.active_stack.alerter + + if not alerter: + logger.warning("No alerter configured in stack, auto-approving") + + # Log metadata for no alerter scenario + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "NO_ALERTER_AUTO_APPROVED", + "alerter_type": "none", + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": len(limited_queries), + "max_queries_allowed": max_queries, + "approval_status": "auto_approved", + "wait_time_seconds": 0, + } + } + ) + + return ApprovalDecision( + approved=True, + selected_queries=limited_queries, + approval_method="NO_ALERTER_AUTO_APPROVED", + reviewer_notes="No alerter configured - auto-approved", + ) + + # Use the alerter's ask method for interactive approval + try: + # Send the message to Discord and wait for response + logger.info( + f"Sending approval request to {alerter.flavor} alerter" + ) + + # Format message for Discord (Discord has message length limits) + discord_message = ( + f"**Research Approval Request**\n\n{message[:1900]}" + ) + if len(message) > 1900: + discord_message += ( + "\n\n*(Message truncated due to Discord limits)*" + ) + + # Add instructions for Discord responses + discord_message += "\n\n**How to respond:**\n" + discord_message += "✅ Type `yes`, `approve`, `ok`, or `LGTM` to approve ALL queries\n" + discord_message += "❌ Type `no`, `skip`, `reject`, or `decline` to skip additional research\n" + discord_message += f"⏱️ Response timeout: {timeout} seconds" + + # Use the ask method to get user response + logger.info("Waiting for approval response from Discord...") + wait_start_time = time.time() + approved = alerter.ask(discord_message) + wait_end_time = time.time() + wait_time = wait_end_time - wait_start_time + + logger.info( + f"Received Discord response: {'approved' if approved else 'rejected'}" + ) + + if approved: + # Log metadata for approved decision + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "DISCORD_APPROVED", + "alerter_type": alerter_type, + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": len(limited_queries), + "max_queries_allowed": max_queries, + "approval_status": "approved", + "wait_time_seconds": wait_time, + "timeout_configured": timeout, + } + } + ) + + return ApprovalDecision( + approved=True, + selected_queries=limited_queries, + approval_method="DISCORD_APPROVED", + reviewer_notes="Approved via Discord", + ) + else: + # Log metadata for rejected decision + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "DISCORD_REJECTED", + 
"alerter_type": alerter_type, + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "rejected", + "wait_time_seconds": wait_time, + "timeout_configured": timeout, + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="DISCORD_REJECTED", + reviewer_notes="Rejected via Discord", + ) + + except Exception as e: + logger.error(f"Failed to get approval from alerter: {e}") + + # Log metadata for alerter error + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "ALERTER_ERROR", + "alerter_type": alerter_type, + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "error", + "error_message": str(e), + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="ALERTER_ERROR", + reviewer_notes=f"Failed to get approval: {str(e)}", + ) + + except Exception as e: + logger.error(f"Approval step failed: {e}") + + # Log metadata for general error + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "ERROR", + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "error", + "error_message": str(e), + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="ERROR", + reviewer_notes=f"Approval failed: {str(e)}", + ) diff --git a/deep_research/steps/collect_tracing_metadata_step.py b/deep_research/steps/collect_tracing_metadata_step.py new file mode 100644 index 00000000..4f01505e --- /dev/null +++ b/deep_research/steps/collect_tracing_metadata_step.py @@ -0,0 +1,234 @@ +"""Step to collect tracing metadata from Langfuse for the pipeline run.""" + +import logging +from typing import Annotated, Tuple + +from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.tracing_metadata_materializer import ( + TracingMetadataMaterializer, +) +from utils.pydantic_models import ( + PromptTypeMetrics, + ResearchState, + TracingMetadata, +) +from utils.tracing_metadata_utils import ( + get_observations_for_trace, + get_prompt_type_statistics, + get_trace_stats, + get_traces_by_name, +) +from zenml import get_step_context, step + +logger = logging.getLogger(__name__) + + +@step( + enable_cache=False, + output_materializers={ + "state": ResearchStateMaterializer, + "tracing_metadata": TracingMetadataMaterializer, + }, +) +def collect_tracing_metadata_step( + state: ResearchState, + langfuse_project_name: str, +) -> Tuple[ + Annotated[ResearchState, "state"], + Annotated[TracingMetadata, "tracing_metadata"], +]: + """Collect tracing metadata from Langfuse for the current pipeline run. + + This step gathers comprehensive metrics about token usage, costs, and performance + for the entire pipeline run, providing insights into resource consumption. 
+ + Args: + state: The final research state + langfuse_project_name: Langfuse project name for accessing traces + + Returns: + Tuple of (ResearchState, TracingMetadata) - the state is passed through unchanged + """ + ctx = get_step_context() + pipeline_run_name = ctx.pipeline_run.name + pipeline_run_id = str(ctx.pipeline_run.id) + + logger.info( + f"Collecting tracing metadata for pipeline run: {pipeline_run_name} (ID: {pipeline_run_id})" + ) + + # Initialize the metadata object + metadata = TracingMetadata( + pipeline_run_name=pipeline_run_name, + pipeline_run_id=pipeline_run_id, + trace_name=pipeline_run_name, + trace_id=pipeline_run_id, + ) + + try: + # Fetch the trace for this pipeline run + # The trace_name is the pipeline run name + traces = get_traces_by_name(name=pipeline_run_name, limit=1) + + if not traces: + logger.warning( + f"No trace found for pipeline run: {pipeline_run_name}" + ) + return state, metadata + + trace = traces[0] + + # Get comprehensive trace stats + trace_stats = get_trace_stats(trace) + + # Update metadata with trace stats + metadata.trace_id = trace.id + metadata.total_cost = trace_stats["total_cost"] + metadata.total_input_tokens = trace_stats["input_tokens"] + metadata.total_output_tokens = trace_stats["output_tokens"] + metadata.total_tokens = ( + trace_stats["input_tokens"] + trace_stats["output_tokens"] + ) + metadata.total_latency_seconds = trace_stats["latency_seconds"] + metadata.formatted_latency = trace_stats["latency_formatted"] + metadata.observation_count = trace_stats["observation_count"] + metadata.models_used = trace_stats["models_used"] + metadata.trace_tags = trace_stats.get("tags", []) + metadata.trace_metadata = trace_stats.get("metadata", {}) + + # Get model-specific breakdown + observations = get_observations_for_trace(trace_id=trace.id) + model_costs = {} + model_tokens = {} + step_costs = {} + step_tokens = {} + + for obs in observations: + if obs.model: + # Track by model + if obs.model not in model_costs: + model_costs[obs.model] = 0.0 + model_tokens[obs.model] = { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + + if obs.calculated_total_cost: + model_costs[obs.model] += obs.calculated_total_cost + + if obs.usage: + input_tokens = obs.usage.input or 0 + output_tokens = obs.usage.output or 0 + model_tokens[obs.model]["input_tokens"] += input_tokens + model_tokens[obs.model]["output_tokens"] += output_tokens + model_tokens[obs.model]["total_tokens"] += ( + input_tokens + output_tokens + ) + + # Track by step (using observation name as step indicator) + if obs.name: + step_name = obs.name + + if step_name not in step_costs: + step_costs[step_name] = 0.0 + step_tokens[step_name] = { + "input_tokens": 0, + "output_tokens": 0, + } + + if obs.calculated_total_cost: + step_costs[step_name] += obs.calculated_total_cost + + if obs.usage: + input_tokens = obs.usage.input or 0 + output_tokens = obs.usage.output or 0 + step_tokens[step_name]["input_tokens"] += input_tokens + step_tokens[step_name]["output_tokens"] += output_tokens + + metadata.cost_breakdown_by_model = model_costs + metadata.model_token_breakdown = model_tokens + metadata.step_costs = step_costs + metadata.step_tokens = step_tokens + + # Collect prompt-level metrics + try: + prompt_stats = get_prompt_type_statistics(trace_id=trace.id) + + # Convert to PromptTypeMetrics objects + prompt_metrics_list = [] + for prompt_type, stats in prompt_stats.items(): + prompt_metrics = PromptTypeMetrics( + prompt_type=prompt_type, + total_cost=stats["cost"], + 
input_tokens=stats["input_tokens"], + output_tokens=stats["output_tokens"], + call_count=stats["count"], + avg_cost_per_call=stats["avg_cost_per_call"], + percentage_of_total_cost=stats["percentage_of_total_cost"], + ) + prompt_metrics_list.append(prompt_metrics) + + # Sort by total cost descending + prompt_metrics_list.sort(key=lambda x: x.total_cost, reverse=True) + metadata.prompt_metrics = prompt_metrics_list + + logger.info( + f"Collected prompt-level metrics for {len(prompt_metrics_list)} prompt types" + ) + except Exception as e: + logger.warning(f"Failed to collect prompt-level metrics: {str(e)}") + + # Add search costs from the state + if hasattr(state, "search_costs") and state.search_costs: + metadata.search_costs = state.search_costs.copy() + logger.info(f"Added search costs: {metadata.search_costs}") + + if hasattr(state, "search_cost_details") and state.search_cost_details: + metadata.search_cost_details = state.search_cost_details.copy() + + # Count queries by provider + search_queries_count = {} + for detail in state.search_cost_details: + provider = detail.get("provider", "unknown") + search_queries_count[provider] = ( + search_queries_count.get(provider, 0) + 1 + ) + metadata.search_queries_count = search_queries_count + + logger.info( + f"Added {len(metadata.search_cost_details)} search cost detail entries" + ) + + total_search_cost = sum(metadata.search_costs.values()) + logger.info( + f"Successfully collected tracing metadata - " + f"LLM Cost: ${metadata.total_cost:.4f}, " + f"Search Cost: ${total_search_cost:.4f}, " + f"Total Cost: ${metadata.total_cost + total_search_cost:.4f}, " + f"Tokens: {metadata.total_tokens:,}, " + f"Models: {metadata.models_used}, " + f"Duration: {metadata.formatted_latency}" + ) + + except Exception as e: + logger.error( + f"Failed to collect tracing metadata for pipeline run {pipeline_run_name}: {str(e)}" + ) + # Return metadata with whatever we could collect + + # Still try to get search costs even if Langfuse failed + if hasattr(state, "search_costs") and state.search_costs: + metadata.search_costs = state.search_costs.copy() + if hasattr(state, "search_cost_details") and state.search_cost_details: + metadata.search_cost_details = state.search_cost_details.copy() + # Count queries by provider + search_queries_count = {} + for detail in state.search_cost_details: + provider = detail.get("provider", "unknown") + search_queries_count[provider] = ( + search_queries_count.get(provider, 0) + 1 + ) + metadata.search_queries_count = search_queries_count + + return state, metadata diff --git a/deep_research/steps/cross_viewpoint_step.py b/deep_research/steps/cross_viewpoint_step.py new file mode 100644 index 00000000..ad9c5ddb --- /dev/null +++ b/deep_research/steps/cross_viewpoint_step.py @@ -0,0 +1,228 @@ +import json +import logging +import time +from typing import Annotated, List + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.helper_functions import ( + safe_json_loads, +) +from utils.llm_utils import run_llm_completion +from utils.prompt_models import PromptsBundle +from utils.pydantic_models import ( + ResearchState, + ViewpointAnalysis, + ViewpointTension, +) +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ResearchStateMaterializer) +def cross_viewpoint_analysis_step( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + viewpoint_categories: List[str] = [ + "scientific", 
+ "political", + "economic", + "social", + "ethical", + "historical", + ], + langfuse_project_name: str = "deep-research", +) -> Annotated[ResearchState, "analyzed_state"]: + """Analyze synthesized information across different viewpoints. + + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for viewpoint analysis + viewpoint_categories: Categories of viewpoints to analyze + + Returns: + Updated research state with viewpoint analysis + """ + start_time = time.time() + logger.info( + f"Performing cross-viewpoint analysis on {len(state.synthesized_info)} sub-questions" + ) + + # Prepare input for viewpoint analysis + analysis_input = { + "main_query": state.main_query, + "sub_questions": state.sub_questions, + "synthesized_information": { + question: { + "synthesized_answer": info.synthesized_answer, + "key_sources": info.key_sources, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + } + for question, info in state.synthesized_info.items() + }, + "viewpoint_categories": viewpoint_categories, + } + + # Perform viewpoint analysis + try: + logger.info(f"Calling {llm_model} for viewpoint analysis") + # Get the prompt from the bundle + system_prompt = prompts_bundle.get_prompt_content( + "viewpoint_analysis_prompt" + ) + + # Use the run_llm_completion function from llm_utils + content = run_llm_completion( + prompt=json.dumps(analysis_input), + system_prompt=system_prompt, + model=llm_model, # Model name will be prefixed in the function + max_tokens=3000, # Further increased for more comprehensive viewpoint analysis + project=langfuse_project_name, + ) + + result = safe_json_loads(content) + + if not result: + logger.warning("Failed to parse viewpoint analysis result") + # Create a default viewpoint analysis + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Analysis failed to identify points of agreement." 
+ ], + perspective_gaps="Analysis failed to identify perspective gaps.", + integrative_insights="Analysis failed to provide integrative insights.", + ) + else: + # Create tension objects + tensions = [] + for tension_data in result.get("areas_of_tension", []): + tensions.append( + ViewpointTension( + topic=tension_data.get("topic", ""), + viewpoints=tension_data.get("viewpoints", {}), + ) + ) + + # Create the viewpoint analysis object + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=result.get( + "main_points_of_agreement", [] + ), + areas_of_tension=tensions, + perspective_gaps=result.get("perspective_gaps", ""), + integrative_insights=result.get("integrative_insights", ""), + ) + + logger.info("Completed viewpoint analysis") + + # Update the state with the viewpoint analysis + state.update_viewpoint_analysis(viewpoint_analysis) + + # Calculate execution time + execution_time = time.time() - start_time + + # Count viewpoint tensions by category + tension_categories = {} + for tension in viewpoint_analysis.areas_of_tension: + for category in tension.viewpoints.keys(): + tension_categories[category] = ( + tension_categories.get(category, 0) + 1 + ) + + # Log metadata + log_metadata( + metadata={ + "viewpoint_analysis": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "num_sub_questions_analyzed": len(state.synthesized_info), + "viewpoint_categories_requested": viewpoint_categories, + "num_agreement_points": len( + viewpoint_analysis.main_points_of_agreement + ), + "num_tension_areas": len( + viewpoint_analysis.areas_of_tension + ), + "tension_categories_distribution": tension_categories, + "has_perspective_gaps": bool( + viewpoint_analysis.perspective_gaps + and viewpoint_analysis.perspective_gaps != "" + ), + "has_integrative_insights": bool( + viewpoint_analysis.integrative_insights + and viewpoint_analysis.integrative_insights != "" + ), + "analysis_success": not viewpoint_analysis.main_points_of_agreement[ + 0 + ].startswith("Analysis failed"), + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_scope": { + "num_tension_areas": len( + viewpoint_analysis.areas_of_tension + ), + } + }, + infer_model=True, + ) + + # Log artifact metadata + log_metadata( + metadata={ + "state_with_viewpoint_analysis": { + "has_viewpoint_analysis": True, + "total_viewpoints_analyzed": sum( + tension_categories.values() + ), + "most_common_tension_category": max( + tension_categories, key=tension_categories.get + ) + if tension_categories + else None, + } + }, + infer_artifact=True, + ) + + return state + + except Exception as e: + logger.error(f"Error performing viewpoint analysis: {e}") + + # Create a fallback viewpoint analysis + fallback_analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Analysis failed due to technical error." 
+ ], + perspective_gaps=f"Analysis failed: {str(e)}", + integrative_insights="No insights available due to analysis failure.", + ) + + # Update the state with the fallback analysis + state.update_viewpoint_analysis(fallback_analysis) + + # Log error metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "viewpoint_analysis": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "num_sub_questions_analyzed": len(state.synthesized_info), + "viewpoint_categories_requested": viewpoint_categories, + "analysis_success": False, + "error_message": str(e), + "fallback_used": True, + } + } + ) + + return state diff --git a/deep_research/steps/execute_approved_searches_step.py b/deep_research/steps/execute_approved_searches_step.py new file mode 100644 index 00000000..90a15718 --- /dev/null +++ b/deep_research/steps/execute_approved_searches_step.py @@ -0,0 +1,423 @@ +import json +import logging +import time +from typing import Annotated + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.llm_utils import ( + find_most_relevant_string, + get_structured_llm_output, + is_text_relevant, +) +from utils.prompt_models import PromptsBundle +from utils.pydantic_models import ( + ApprovalDecision, + ReflectionMetadata, + ReflectionOutput, + ResearchState, + SynthesizedInfo, +) +from utils.search_utils import search_and_extract_results +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +def create_enhanced_info_copy(synthesized_info): + """Create a deep copy of synthesized info for enhancement.""" + return { + k: SynthesizedInfo( + synthesized_answer=v.synthesized_answer, + key_sources=v.key_sources.copy(), + confidence_level=v.confidence_level, + information_gaps=v.information_gaps, + improvements=v.improvements.copy() + if hasattr(v, "improvements") + else [], + ) + for k, v in synthesized_info.items() + } + + +@step(output_materializers=ResearchStateMaterializer) +def execute_approved_searches_step( + reflection_output: ReflectionOutput, + approval_decision: ApprovalDecision, + prompts_bundle: PromptsBundle, + num_results_per_search: int = 3, + cap_search_length: int = 20000, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + search_provider: str = "tavily", + search_mode: str = "auto", + langfuse_project_name: str = "deep-research", +) -> Annotated[ResearchState, "updated_state"]: + """Execute approved searches and enhance the research state. + + This step receives the approval decision and only executes + searches that were approved by the human reviewer (or auto-approved). 
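+ + Example (illustrative; follows the wiring used in the parallel pipeline, with the default Tavily provider): + + reflected_state = execute_approved_searches_step( + reflection_output=reflection_output, + approval_decision=approval_decision, + prompts_bundle=prompts_bundle, + search_provider="tavily", + )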
+ + Args: + reflection_output: Output from the reflection generation step + approval_decision: Human approval decision + num_results_per_search: Number of results to fetch per search + cap_search_length: Maximum length of content to process from search results + llm_model: The model to use for synthesis enhancement + prompts_bundle: Bundle containing all prompts for the pipeline + search_provider: Search provider to use + search_mode: Search mode for the provider + + Returns: + Updated research state with enhanced information and reflection metadata + """ + start_time = time.time() + logger.info( + f"Processing approval decision: {approval_decision.approval_method}" + ) + + state = reflection_output.state + enhanced_info = create_enhanced_info_copy(state.synthesized_info) + + # Track improvements count + improvements_count = 0 + + # Check if we should execute searches + if ( + not approval_decision.approved + or not approval_decision.selected_queries + ): + logger.info("No additional searches approved") + + # Add any additional questions as new synthesized entries (from reflection) + for new_question in reflection_output.additional_questions: + if ( + new_question not in state.sub_questions + and new_question not in enhanced_info + ): + enhanced_info[new_question] = SynthesizedInfo( + synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", + key_sources=[], + confidence_level="low", + information_gaps="This question requires additional research.", + ) + + # Create metadata indicating no additional research + reflection_metadata = ReflectionMetadata( + critique_summary=[ + c.get("issue", "") for c in reflection_output.critique_summary + ], + additional_questions_identified=reflection_output.additional_questions, + searches_performed=[], + improvements_made=improvements_count, + ) + + # Add approval decision info to metadata + if hasattr(reflection_metadata, "__dict__"): + reflection_metadata.__dict__["user_decision"] = ( + approval_decision.approval_method + ) + reflection_metadata.__dict__["reviewer_notes"] = ( + approval_decision.reviewer_notes + ) + + state.update_after_reflection(enhanced_info, reflection_metadata) + + # Log metadata for no approved searches + execution_time = time.time() - start_time + log_metadata( + metadata={ + "execute_approved_searches": { + "execution_time_seconds": execution_time, + "approval_method": approval_decision.approval_method, + "approval_status": "not_approved" + if not approval_decision.approved + else "no_queries", + "num_queries_approved": 0, + "num_searches_executed": 0, + "num_additional_questions": len( + reflection_output.additional_questions + ), + "improvements_made": improvements_count, + "search_provider": search_provider, + "llm_model": llm_model, + } + } + ) + + return state + + # Execute approved searches + logger.info( + f"Executing {len(approval_decision.selected_queries)} approved searches" + ) + + try: + search_enhancements = [] # Track search results for metadata + + for query in approval_decision.selected_queries: + logger.info(f"Performing approved search: {query}") + + # Execute search using the utility function + search_results, search_cost = search_and_extract_results( + query=query, + max_results=num_results_per_search, + cap_content_length=cap_search_length, + provider=search_provider, + search_mode=search_mode, + ) + + # Track search costs if using Exa + if ( + search_provider + and search_provider.lower() in ["exa", "both"] + and search_cost > 0 + ): + # Update 
total costs + state.search_costs["exa"] = ( + state.search_costs.get("exa", 0.0) + search_cost + ) + + # Add detailed cost entry + state.search_cost_details.append( + { + "provider": "exa", + "query": query, + "cost": search_cost, + "timestamp": time.time(), + "step": "execute_approved_searches", + "purpose": "reflection_enhancement", + } + ) + logger.info( + f"Exa search cost for approved query: ${search_cost:.4f}" + ) + + # Extract raw contents + raw_contents = [result.content for result in search_results] + + # Find the most relevant sub-question for this query + most_relevant_question = find_most_relevant_string( + query, + state.sub_questions, + llm_model, + project=langfuse_project_name, + ) + + if ( + most_relevant_question + and most_relevant_question in enhanced_info + ): + # Enhance the synthesis with new information + enhancement_input = { + "original_synthesis": enhanced_info[ + most_relevant_question + ].synthesized_answer, + "new_information": raw_contents, + "critique": [ + item + for item in reflection_output.critique_summary + if is_text_relevant( + item.get("issue", ""), most_relevant_question + ) + ], + } + + # Get the prompt from the bundle + additional_synthesis_prompt = ( + prompts_bundle.get_prompt_content( + "additional_synthesis_prompt" + ) + ) + + # Use the utility function for enhancement + enhanced_synthesis = get_structured_llm_output( + prompt=json.dumps(enhancement_input), + system_prompt=additional_synthesis_prompt, + model=llm_model, + fallback_response={ + "enhanced_synthesis": enhanced_info[ + most_relevant_question + ].synthesized_answer, + "improvements_made": ["Failed to enhance synthesis"], + "remaining_limitations": "Enhancement process failed.", + }, + project=langfuse_project_name, + ) + + if ( + enhanced_synthesis + and "enhanced_synthesis" in enhanced_synthesis + ): + # Update the synthesized answer + enhanced_info[ + most_relevant_question + ].synthesized_answer = enhanced_synthesis[ + "enhanced_synthesis" + ] + + # Add improvements + improvements = enhanced_synthesis.get( + "improvements_made", [] + ) + enhanced_info[most_relevant_question].improvements.extend( + improvements + ) + improvements_count += len(improvements) + + # Track enhancement for metadata + search_enhancements.append( + { + "query": query, + "relevant_question": most_relevant_question, + "num_results": len(search_results), + "improvements": len(improvements), + "enhanced": True, + "search_cost": search_cost + if search_provider + and search_provider.lower() in ["exa", "both"] + else 0.0, + } + ) + + # Add any additional questions as new synthesized entries + for new_question in reflection_output.additional_questions: + if ( + new_question not in state.sub_questions + and new_question not in enhanced_info + ): + enhanced_info[new_question] = SynthesizedInfo( + synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", + key_sources=[], + confidence_level="low", + information_gaps="This question requires additional research.", + ) + + # Create final metadata with approval info + reflection_metadata = ReflectionMetadata( + critique_summary=[ + c.get("issue", "") for c in reflection_output.critique_summary + ], + additional_questions_identified=reflection_output.additional_questions, + searches_performed=approval_decision.selected_queries, + improvements_made=improvements_count, + ) + + # Add approval decision info to metadata + if hasattr(reflection_metadata, "__dict__"): + reflection_metadata.__dict__["user_decision"] = 
( + approval_decision.approval_method + ) + reflection_metadata.__dict__["reviewer_notes"] = ( + approval_decision.reviewer_notes + ) + + logger.info( + f"Completed approved searches with {improvements_count} improvements" + ) + + state.update_after_reflection(enhanced_info, reflection_metadata) + + # Calculate metrics for metadata + execution_time = time.time() - start_time + total_results = sum( + e.get("num_results", 0) for e in search_enhancements + ) + questions_enhanced = len( + set( + e.get("relevant_question") + for e in search_enhancements + if e.get("enhanced") + ) + ) + + # Log successful execution metadata + log_metadata( + metadata={ + "execute_approved_searches": { + "execution_time_seconds": execution_time, + "approval_method": approval_decision.approval_method, + "approval_status": "approved", + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": len( + approval_decision.selected_queries + ), + "num_searches_executed": len( + approval_decision.selected_queries + ), + "total_search_results": total_results, + "questions_enhanced": questions_enhanced, + "improvements_made": improvements_count, + "num_additional_questions": len( + reflection_output.additional_questions + ), + "search_provider": search_provider, + "search_mode": search_mode, + "llm_model": llm_model, + "success": True, + "total_search_cost": state.search_costs.get("exa", 0.0), + } + } + ) + + # Log artifact metadata + log_metadata( + metadata={ + "enhanced_state_after_approval": { + "total_questions": len(enhanced_info), + "questions_with_improvements": sum( + 1 + for info in enhanced_info.values() + if info.improvements + ), + "total_improvements": sum( + len(info.improvements) + for info in enhanced_info.values() + ), + "approval_method": approval_decision.approval_method, + } + }, + infer_artifact=True, + ) + + return state + + except Exception as e: + logger.error(f"Error during approved search execution: {e}") + + # Create error metadata + error_metadata = ReflectionMetadata( + error=f"Approved search execution failed: {str(e)}", + critique_summary=[ + c.get("issue", "") for c in reflection_output.critique_summary + ], + additional_questions_identified=reflection_output.additional_questions, + searches_performed=[], + improvements_made=0, + ) + + # Update the state with the original synthesized info as enhanced info + state.update_after_reflection(state.synthesized_info, error_metadata) + + # Log error metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "execute_approved_searches": { + "execution_time_seconds": execution_time, + "approval_method": approval_decision.approval_method, + "approval_status": "approved", + "num_queries_approved": len( + approval_decision.selected_queries + ), + "num_searches_executed": 0, + "improvements_made": 0, + "search_provider": search_provider, + "llm_model": llm_model, + "success": False, + "error_message": str(e), + } + } + ) + + return state diff --git a/deep_research/steps/generate_reflection_step.py b/deep_research/steps/generate_reflection_step.py new file mode 100644 index 00000000..e8547af6 --- /dev/null +++ b/deep_research/steps/generate_reflection_step.py @@ -0,0 +1,167 @@ +import json +import logging +import time +from typing import Annotated + +from materializers.reflection_output_materializer import ( + ReflectionOutputMaterializer, +) +from utils.llm_utils import get_structured_llm_output +from utils.prompt_models import PromptsBundle +from utils.pydantic_models import 
ReflectionOutput, ResearchState +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ReflectionOutputMaterializer) +def generate_reflection_step( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> Annotated[ReflectionOutput, "reflection_output"]: + """ + Generate reflection and recommendations WITHOUT executing searches. + + This step only analyzes the current state and produces recommendations + for additional research that could improve the quality of the results. + + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for reflection + + Returns: + ReflectionOutput containing the state, recommendations, and critique + """ + start_time = time.time() + logger.info("Generating reflection on research") + + # Prepare input for reflection + synthesized_info_dict = { + question: { + "synthesized_answer": info.synthesized_answer, + "key_sources": info.key_sources, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + } + for question, info in state.synthesized_info.items() + } + + viewpoint_analysis_dict = None + if state.viewpoint_analysis: + # Convert the viewpoint analysis to a dict for the LLM + tension_list = [] + for tension in state.viewpoint_analysis.areas_of_tension: + tension_list.append( + {"topic": tension.topic, "viewpoints": tension.viewpoints} + ) + + viewpoint_analysis_dict = { + "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, + "areas_of_tension": tension_list, + "perspective_gaps": state.viewpoint_analysis.perspective_gaps, + "integrative_insights": state.viewpoint_analysis.integrative_insights, + } + + reflection_input = { + "main_query": state.main_query, + "sub_questions": state.sub_questions, + "synthesized_information": synthesized_info_dict, + } + + if viewpoint_analysis_dict: + reflection_input["viewpoint_analysis"] = viewpoint_analysis_dict + + # Get reflection critique + logger.info(f"Generating self-critique via {llm_model}") + + # Get the prompt from the bundle + reflection_prompt = prompts_bundle.get_prompt_content("reflection_prompt") + + # Define fallback for reflection result + fallback_reflection = { + "critique": [], + "additional_questions": [], + "recommended_search_queries": [], + } + + # Use utility function to get structured output + reflection_result = get_structured_llm_output( + prompt=json.dumps(reflection_input), + system_prompt=reflection_prompt, + model=llm_model, + fallback_response=fallback_reflection, + project=langfuse_project_name, + ) + + # Prepare return value + reflection_output = ReflectionOutput( + state=state, + recommended_queries=reflection_result.get( + "recommended_search_queries", [] + ), + critique_summary=reflection_result.get("critique", []), + additional_questions=reflection_result.get("additional_questions", []), + ) + + # Calculate execution time + execution_time = time.time() - start_time + + # Count confidence levels in synthesized info + confidence_levels = [ + info.confidence_level for info in state.synthesized_info.values() + ] + confidence_distribution = { + "high": confidence_levels.count("high"), + "medium": confidence_levels.count("medium"), + "low": confidence_levels.count("low"), + } + + # Log step metadata + log_metadata( + metadata={ + "reflection_generation": { + 
"execution_time_seconds": execution_time, + "llm_model": llm_model, + "num_sub_questions_analyzed": len(state.sub_questions), + "num_synthesized_answers": len(state.synthesized_info), + "viewpoint_analysis_included": bool(viewpoint_analysis_dict), + "num_critique_points": len(reflection_output.critique_summary), + "num_additional_questions": len( + reflection_output.additional_questions + ), + "num_recommended_queries": len( + reflection_output.recommended_queries + ), + "confidence_distribution": confidence_distribution, + "has_information_gaps": any( + info.information_gaps + for info in state.synthesized_info.values() + ), + } + } + ) + + # Log artifact metadata + log_metadata( + metadata={ + "reflection_output_characteristics": { + "has_recommendations": bool( + reflection_output.recommended_queries + ), + "has_critique": bool(reflection_output.critique_summary), + "has_additional_questions": bool( + reflection_output.additional_questions + ), + "total_recommendations": len( + reflection_output.recommended_queries + ) + + len(reflection_output.additional_questions), + } + }, + infer_artifact=True, + ) + + return reflection_output diff --git a/deep_research/steps/initialize_prompts_step.py b/deep_research/steps/initialize_prompts_step.py new file mode 100644 index 00000000..377c6df4 --- /dev/null +++ b/deep_research/steps/initialize_prompts_step.py @@ -0,0 +1,45 @@ +"""Step to initialize and track prompts as artifacts. + +This step creates a PromptsBundle artifact at the beginning of the pipeline, +making all prompts trackable and versioned in ZenML. +""" + +import logging +from typing import Annotated + +from materializers.prompts_materializer import PromptsBundleMaterializer +from utils.prompt_loader import load_prompts_bundle +from utils.prompt_models import PromptsBundle +from zenml import step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=PromptsBundleMaterializer) +def initialize_prompts_step( + pipeline_version: str = "1.1.0", +) -> Annotated[PromptsBundle, "prompts_bundle"]: + """Initialize the prompts bundle for the pipeline. + + This step loads all prompts from the prompts.py module and creates + a PromptsBundle artifact that can be tracked and visualized in ZenML. 
+ + Args: + pipeline_version: Version of the pipeline using these prompts + + Returns: + PromptsBundle containing all prompts used in the pipeline + """ + logger.info( + f"Initializing prompts bundle for pipeline version {pipeline_version}" + ) + + # Load all prompts into a bundle + prompts_bundle = load_prompts_bundle(pipeline_version=pipeline_version) + + # Log some statistics + all_prompts = prompts_bundle.list_all_prompts() + logger.info(f"Loaded {len(all_prompts)} prompts into bundle") + logger.info(f"Prompts: {', '.join(all_prompts.keys())}") + + return prompts_bundle diff --git a/deep_research/steps/iterative_reflection_step.py b/deep_research/steps/iterative_reflection_step.py new file mode 100644 index 00000000..0ed38800 --- /dev/null +++ b/deep_research/steps/iterative_reflection_step.py @@ -0,0 +1,385 @@ +import json +import logging +import time +from typing import Annotated + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.llm_utils import ( + find_most_relevant_string, + get_structured_llm_output, + is_text_relevant, +) +from utils.prompts import ADDITIONAL_SYNTHESIS_PROMPT, REFLECTION_PROMPT +from utils.pydantic_models import ( + ReflectionMetadata, + ResearchState, + SynthesizedInfo, +) +from utils.search_utils import search_and_extract_results +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ResearchStateMaterializer) +def iterative_reflection_step( + state: ResearchState, + max_additional_searches: int = 2, + num_results_per_search: int = 3, + cap_search_length: int = 20000, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + reflection_prompt: str = REFLECTION_PROMPT, + additional_synthesis_prompt: str = ADDITIONAL_SYNTHESIS_PROMPT, +) -> Annotated[ResearchState, "reflected_state"]: + """Perform iterative reflection on the research, identifying gaps and improving it. 
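+ + Example (illustrative; unlike the split generate/approve/execute flow, this step performs reflection and any follow-up searches in a single pass): + + reflected_state = iterative_reflection_step( + state=analyzed_state, + max_additional_searches=2, + num_results_per_search=3, + )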
+ + Args: + state: The current research state + max_additional_searches: Maximum number of additional searches to perform + num_results_per_search: Number of results to fetch per search + cap_search_length: Maximum length of content to process from search results + llm_model: The model to use for reflection + reflection_prompt: System prompt for the reflection + additional_synthesis_prompt: System prompt for incorporating additional information + + Returns: + Updated research state with enhanced information and reflection metadata + """ + start_time = time.time() + logger.info("Starting iterative reflection on research") + + # Prepare input for reflection + synthesized_info_dict = { + question: { + "synthesized_answer": info.synthesized_answer, + "key_sources": info.key_sources, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + } + for question, info in state.synthesized_info.items() + } + + viewpoint_analysis_dict = None + if state.viewpoint_analysis: + # Convert the viewpoint analysis to a dict for the LLM + tension_list = [] + for tension in state.viewpoint_analysis.areas_of_tension: + tension_list.append( + {"topic": tension.topic, "viewpoints": tension.viewpoints} + ) + + viewpoint_analysis_dict = { + "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, + "areas_of_tension": tension_list, + "perspective_gaps": state.viewpoint_analysis.perspective_gaps, + "integrative_insights": state.viewpoint_analysis.integrative_insights, + } + + reflection_input = { + "main_query": state.main_query, + "sub_questions": state.sub_questions, + "synthesized_information": synthesized_info_dict, + } + + if viewpoint_analysis_dict: + reflection_input["viewpoint_analysis"] = viewpoint_analysis_dict + + # Get reflection critique + try: + logger.info(f"Generating self-critique via {llm_model}") + + # Define fallback for reflection result + fallback_reflection = { + "critique": [], + "additional_questions": [], + "recommended_search_queries": [], + } + + # Use utility function to get structured output + reflection_result = get_structured_llm_output( + prompt=json.dumps(reflection_input), + system_prompt=reflection_prompt, + model=llm_model, + fallback_response=fallback_reflection, + ) + + # Make a deep copy of the synthesized info to create enhanced_info + enhanced_info = { + k: SynthesizedInfo( + synthesized_answer=v.synthesized_answer, + key_sources=v.key_sources.copy(), + confidence_level=v.confidence_level, + information_gaps=v.information_gaps, + improvements=v.improvements.copy() + if hasattr(v, "improvements") + else [], + ) + for k, v in state.synthesized_info.items() + } + + # Perform additional searches based on recommendations + search_queries = reflection_result.get( + "recommended_search_queries", [] + ) + if max_additional_searches > 0 and search_queries: + # Limit to max_additional_searches + search_queries = search_queries[:max_additional_searches] + + for query in search_queries: + logger.info(f"Performing additional search: {query}") + # Execute the search using the utility function + search_results, search_cost = search_and_extract_results( + query=query, + max_results=num_results_per_search, + cap_content_length=cap_search_length, + ) + + # Extract raw contents + raw_contents = [result.content for result in search_results] + + # Find the most relevant sub-question for this query + most_relevant_question = find_most_relevant_string( + query, state.sub_questions, llm_model + ) + + # Track search costs if using Exa (default 
provider) + # Note: This step doesn't have a search_provider parameter, so we check the default + from utils.search_utils import SearchEngineConfig + + config = SearchEngineConfig() + if ( + config.default_provider.lower() in ["exa", "both"] + and search_cost > 0 + ): + # Update total costs + state.search_costs["exa"] = ( + state.search_costs.get("exa", 0.0) + search_cost + ) + + # Add detailed cost entry + state.search_cost_details.append( + { + "provider": "exa", + "query": query, + "cost": search_cost, + "timestamp": time.time(), + "step": "iterative_reflection", + "purpose": "gap_filling", + "relevant_question": most_relevant_question, + } + ) + logger.info( + f"Exa search cost for reflection query: ${search_cost:.4f}" + ) + + if ( + most_relevant_question + and most_relevant_question in enhanced_info + ): + # Enhance the synthesis with new information + enhancement_input = { + "original_synthesis": enhanced_info[ + most_relevant_question + ].synthesized_answer, + "new_information": raw_contents, + "critique": [ + item + for item in reflection_result.get("critique", []) + if is_text_relevant( + item.get("issue", ""), most_relevant_question + ) + ], + } + + # Use the utility function for enhancement + enhanced_synthesis = get_structured_llm_output( + prompt=json.dumps(enhancement_input), + system_prompt=additional_synthesis_prompt, + model=llm_model, + fallback_response={ + "enhanced_synthesis": enhanced_info[ + most_relevant_question + ].synthesized_answer, + "improvements_made": [ + "Failed to enhance synthesis" + ], + "remaining_limitations": "Enhancement process failed.", + }, + ) + + if ( + enhanced_synthesis + and "enhanced_synthesis" in enhanced_synthesis + ): + # Update the synthesized answer + enhanced_info[ + most_relevant_question + ].synthesized_answer = enhanced_synthesis[ + "enhanced_synthesis" + ] + + # Add improvements + improvements = enhanced_synthesis.get( + "improvements_made", [] + ) + enhanced_info[ + most_relevant_question + ].improvements.extend(improvements) + + # Add any additional questions as new synthesized entries + for new_question in reflection_result.get("additional_questions", []): + if ( + new_question not in state.sub_questions + and new_question not in enhanced_info + ): + enhanced_info[new_question] = SynthesizedInfo( + synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", + key_sources=[], + confidence_level="low", + information_gaps="This question requires additional research.", + ) + + # Prepare metadata about the reflection process + reflection_metadata = ReflectionMetadata( + critique_summary=[ + item.get("issue", "") + for item in reflection_result.get("critique", []) + ], + additional_questions_identified=reflection_result.get( + "additional_questions", [] + ), + searches_performed=search_queries, + improvements_made=sum( + [len(info.improvements) for info in enhanced_info.values()] + ), + ) + + logger.info( + f"Completed iterative reflection with {reflection_metadata.improvements_made} improvements" + ) + + # Update the state with enhanced info and metadata + state.update_after_reflection(enhanced_info, reflection_metadata) + + # Calculate execution time + execution_time = time.time() - start_time + + # Count questions that were enhanced + questions_enhanced = 0 + for question, enhanced in enhanced_info.items(): + if question in state.synthesized_info: + original = state.synthesized_info[question] + if enhanced.synthesized_answer != original.synthesized_answer: + 
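+                # A question counts as enhanced only if reflection changed its answer text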
questions_enhanced += 1 + + # Calculate confidence level changes + confidence_improvements = {"improved": 0, "unchanged": 0, "new": 0} + for question, enhanced in enhanced_info.items(): + if question in state.synthesized_info: + original = state.synthesized_info[question] + original_level = original.confidence_level.lower() + enhanced_level = enhanced.confidence_level.lower() + + level_map = {"low": 0, "medium": 1, "high": 2} + if enhanced_level in level_map and original_level in level_map: + if level_map[enhanced_level] > level_map[original_level]: + confidence_improvements["improved"] += 1 + else: + confidence_improvements["unchanged"] += 1 + else: + confidence_improvements["new"] += 1 + + # Log metadata + log_metadata( + metadata={ + "iterative_reflection": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "max_additional_searches": max_additional_searches, + "searches_performed": len(search_queries), + "num_critique_points": len( + reflection_result.get("critique", []) + ), + "num_additional_questions": len( + reflection_result.get("additional_questions", []) + ), + "questions_enhanced": questions_enhanced, + "total_improvements": reflection_metadata.improvements_made, + "confidence_improvements": confidence_improvements, + "has_viewpoint_analysis": bool(viewpoint_analysis_dict), + "total_search_cost": state.search_costs.get("exa", 0.0), + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "improvement_metrics": { + "confidence_improvements": confidence_improvements, + "total_improvements": reflection_metadata.improvements_made, + } + }, + infer_model=True, + ) + + # Log artifact metadata + log_metadata( + metadata={ + "enhanced_state_characteristics": { + "total_questions": len(enhanced_info), + "questions_with_improvements": sum( + 1 + for info in enhanced_info.values() + if info.improvements + ), + "high_confidence_count": sum( + 1 + for info in enhanced_info.values() + if info.confidence_level.lower() == "high" + ), + "medium_confidence_count": sum( + 1 + for info in enhanced_info.values() + if info.confidence_level.lower() == "medium" + ), + "low_confidence_count": sum( + 1 + for info in enhanced_info.values() + if info.confidence_level.lower() == "low" + ), + } + }, + infer_artifact=True, + ) + + return state + + except Exception as e: + logger.error(f"Error during iterative reflection: {e}") + + # Create error metadata + error_metadata = ReflectionMetadata( + error=f"Reflection failed: {str(e)}" + ) + + # Update the state with the original synthesized info as enhanced info + # and the error metadata + state.update_after_reflection(state.synthesized_info, error_metadata) + + # Log error metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "iterative_reflection": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "max_additional_searches": max_additional_searches, + "searches_performed": 0, + "status": "failed", + "error_message": str(e), + } + } + ) + + return state diff --git a/deep_research/steps/merge_results_step.py b/deep_research/steps/merge_results_step.py new file mode 100644 index 00000000..a8c8cd53 --- /dev/null +++ b/deep_research/steps/merge_results_step.py @@ -0,0 +1,265 @@ +import copy +import logging +import time +from typing import Annotated + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.pydantic_models import ResearchState +from zenml import get_step_context, log_metadata, step +from zenml.client import 
Client + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ResearchStateMaterializer) +def merge_sub_question_results_step( + original_state: ResearchState, + step_prefix: str = "process_question_", + output_name: str = "output", +) -> Annotated[ResearchState, "merged_state"]: + """Merge results from individual sub-question processing steps. + + This step collects the results from the parallel sub-question processing steps + and combines them into a single, comprehensive state object. + + Args: + original_state: The original research state with all sub-questions + step_prefix: The prefix used in step IDs for the parallel processing steps + output_name: The name of the output artifact from the processing steps + + Returns: + Annotated[ResearchState, "merged_state"]: A merged ResearchState with combined + results from all sub-questions + + Note: + This step is typically configured with the 'after' parameter in the pipeline + definition to ensure it runs after all parallel sub-question processing steps + have completed. + """ + start_time = time.time() + + # Start with the original state that has all sub-questions + merged_state = copy.deepcopy(original_state) + + # Initialize empty dictionaries for the results + merged_state.search_results = {} + merged_state.synthesized_info = {} + + # Initialize search cost tracking + merged_state.search_costs = {} + merged_state.search_cost_details = [] + + # Get pipeline run information to access outputs + try: + ctx = get_step_context() + if not ctx or not ctx.pipeline_run: + logger.error("Could not get pipeline run context") + return merged_state + + run_name = ctx.pipeline_run.name + client = Client() + run = client.get_pipeline_run(run_name) + + logger.info( + f"Merging results from parallel sub-question processing steps in run: {run_name}" + ) + + # Track which sub-questions were successfully processed + processed_questions = set() + parallel_steps_processed = 0 + + # Process each step in the run + for step_name, step_info in run.steps.items(): + # Only process steps with the specified prefix + if step_name.startswith(step_prefix): + try: + # Extract the sub-question index from the step name + if "_" in step_name: + index = int(step_name.split("_")[-1]) + logger.info( + f"Processing results from step: {step_name} (index: {index})" + ) + + # Get the output artifact + if output_name in step_info.outputs: + output_artifacts = step_info.outputs[output_name] + if output_artifacts: + output_artifact = output_artifacts[0] + sub_state = output_artifact.load() + + # Check if the sub-state has valid data + if ( + hasattr(sub_state, "sub_questions") + and sub_state.sub_questions + ): + sub_question = sub_state.sub_questions[0] + logger.info( + f"Found results for sub-question: {sub_question}" + ) + parallel_steps_processed += 1 + processed_questions.add(sub_question) + + # Merge search results + if ( + hasattr(sub_state, "search_results") + and sub_question + in sub_state.search_results + ): + merged_state.search_results[ + sub_question + ] = sub_state.search_results[ + sub_question + ] + logger.info( + f"Added search results for: {sub_question}" + ) + + # Merge synthesized info + if ( + hasattr(sub_state, "synthesized_info") + and sub_question + in sub_state.synthesized_info + ): + merged_state.synthesized_info[ + sub_question + ] = sub_state.synthesized_info[ + sub_question + ] + logger.info( + f"Added synthesized info for: {sub_question}" + ) + + # Merge search costs + if hasattr(sub_state, "search_costs"): + for ( + provider, + cost, + 
) in sub_state.search_costs.items(): + merged_state.search_costs[ + provider + ] = ( + merged_state.search_costs.get( + provider, 0.0 + ) + + cost + ) + + # Merge search cost details + if hasattr( + sub_state, "search_cost_details" + ): + merged_state.search_cost_details.extend( + sub_state.search_cost_details + ) + except (ValueError, IndexError, KeyError, AttributeError) as e: + logger.warning(f"Error processing step {step_name}: {e}") + continue + + # Log summary + logger.info( + f"Merged results from {parallel_steps_processed} parallel steps" + ) + logger.info( + f"Successfully processed {len(processed_questions)} sub-questions" + ) + + # Log search cost summary + if merged_state.search_costs: + total_cost = sum(merged_state.search_costs.values()) + logger.info( + f"Total search costs merged: ${total_cost:.4f} across {len(merged_state.search_cost_details)} queries" + ) + for provider, cost in merged_state.search_costs.items(): + logger.info(f" {provider}: ${cost:.4f}") + + # Check for any missing sub-questions + for sub_q in merged_state.sub_questions: + if sub_q not in processed_questions: + logger.warning(f"Missing results for sub-question: {sub_q}") + + except Exception as e: + logger.error(f"Error during merge step: {e}") + + # Final check for empty results + if not merged_state.search_results or not merged_state.synthesized_info: + logger.warning( + "No results were found or merged from parallel processing steps!" + ) + + # Calculate execution time + execution_time = time.time() - start_time + + # Calculate metrics + missing_questions = [ + q for q in merged_state.sub_questions if q not in processed_questions + ] + + # Count total search results across all questions + total_search_results = sum( + len(results) for results in merged_state.search_results.values() + ) + + # Get confidence distribution for merged results + confidence_distribution = {"high": 0, "medium": 0, "low": 0} + for info in merged_state.synthesized_info.values(): + level = info.confidence_level.lower() + if level in confidence_distribution: + confidence_distribution[level] += 1 + + # Calculate completeness ratio + completeness_ratio = ( + len(processed_questions) / len(merged_state.sub_questions) + if merged_state.sub_questions + else 0 + ) + + # Log metadata + log_metadata( + metadata={ + "merge_results": { + "execution_time_seconds": execution_time, + "total_sub_questions": len(merged_state.sub_questions), + "parallel_steps_processed": parallel_steps_processed, + "questions_successfully_merged": len(processed_questions), + "missing_questions_count": len(missing_questions), + "missing_questions": missing_questions[:5] + if missing_questions + else [], # Limit to 5 for metadata + "total_search_results": total_search_results, + "confidence_distribution": confidence_distribution, + "merge_success": bool( + merged_state.search_results + and merged_state.synthesized_info + ), + "total_search_costs": merged_state.search_costs, + "total_search_queries": len(merged_state.search_cost_details), + "total_exa_cost": merged_state.search_costs.get("exa", 0.0), + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_quality": { + "completeness_ratio": completeness_ratio, + } + }, + infer_model=True, + ) + + # Log artifact metadata + log_metadata( + metadata={ + "merged_state_characteristics": { + "has_search_results": bool(merged_state.search_results), + "has_synthesized_info": bool(merged_state.synthesized_info), + "search_results_count": len(merged_state.search_results), + 
"synthesized_info_count": len(merged_state.synthesized_info), + "completeness_ratio": completeness_ratio, + } + }, + infer_artifact=True, + ) + + return merged_state diff --git a/deep_research/steps/process_sub_question_step.py b/deep_research/steps/process_sub_question_step.py new file mode 100644 index 00000000..f6b077c7 --- /dev/null +++ b/deep_research/steps/process_sub_question_step.py @@ -0,0 +1,289 @@ +import copy +import logging +import time +import warnings +from typing import Annotated + +# Suppress Pydantic serialization warnings from ZenML artifact metadata +# These occur when ZenML stores timestamp metadata as floats but models expect ints +warnings.filterwarnings( + "ignore", message=".*PydanticSerializationUnexpectedValue.*" +) + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.llm_utils import synthesize_information +from utils.prompt_models import PromptsBundle +from utils.pydantic_models import ResearchState, SynthesizedInfo +from utils.search_utils import ( + generate_search_query, + search_and_extract_results, +) +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ResearchStateMaterializer) +def process_sub_question_step( + state: ResearchState, + prompts_bundle: PromptsBundle, + question_index: int, + llm_model_search: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + llm_model_synthesis: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + num_results_per_search: int = 3, + cap_search_length: int = 20000, + search_provider: str = "tavily", + search_mode: str = "auto", + langfuse_project_name: str = "deep-research", +) -> Annotated[ResearchState, "output"]: + """Process a single sub-question if it exists at the given index. + + This step combines the gathering and synthesis steps for a single sub-question. + It's designed to be run in parallel for each sub-question. 
+ + Args: + state: The original research state with all sub-questions + prompts_bundle: Bundle containing all prompts for the pipeline + question_index: The index of the sub-question to process + llm_model_search: Model to use for search query generation + llm_model_synthesis: Model to use for synthesis + num_results_per_search: Number of results to fetch per search + cap_search_length: Maximum length of content to process from search results + search_provider: Search provider to use (tavily, exa, or both) + search_mode: Search mode for Exa provider (neural, keyword, or auto) + + Returns: + A new ResearchState containing only the processed sub-question's results + """ + start_time = time.time() + + # Create a copy of the state to avoid modifying the original + sub_state = copy.deepcopy(state) + + # Clear all existing data except the main query + sub_state.search_results = {} + sub_state.synthesized_info = {} + sub_state.enhanced_info = {} + sub_state.viewpoint_analysis = None + sub_state.reflection_metadata = None + sub_state.final_report_html = "" + + # Check if this index exists in sub-questions + if question_index >= len(state.sub_questions): + logger.info( + f"No sub-question at index {question_index}, skipping processing" + ) + # Log metadata for skipped processing + log_metadata( + metadata={ + "sub_question_processing": { + "question_index": question_index, + "status": "skipped", + "reason": "index_out_of_range", + "total_sub_questions": len(state.sub_questions), + } + } + ) + # Return an empty state since there's no question to process + sub_state.sub_questions = [] + return sub_state + + # Get the target sub-question + sub_question = state.sub_questions[question_index] + logger.info( + f"Processing sub-question {question_index + 1}: {sub_question}" + ) + + # Store only this sub-question in the sub-state + sub_state.sub_questions = [sub_question] + + # === INFORMATION GATHERING === + search_phase_start = time.time() + + # Generate search query with prompt from bundle + search_query_prompt = prompts_bundle.get_prompt_content( + "search_query_prompt" + ) + search_query_data = generate_search_query( + sub_question=sub_question, + model=llm_model_search, + system_prompt=search_query_prompt, + project=langfuse_project_name, + ) + search_query = search_query_data.get( + "search_query", f"research about {sub_question}" + ) + + # Perform search + logger.info(f"Performing search with query: {search_query}") + if search_provider: + logger.info(f"Using search provider: {search_provider}") + results_list, search_cost = search_and_extract_results( + query=search_query, + max_results=num_results_per_search, + cap_content_length=cap_search_length, + provider=search_provider, + search_mode=search_mode, + ) + + # Track search costs if using Exa + if ( + search_provider + and search_provider.lower() in ["exa", "both"] + and search_cost > 0 + ): + # Update total costs + sub_state.search_costs["exa"] = ( + sub_state.search_costs.get("exa", 0.0) + search_cost + ) + + # Add detailed cost entry + sub_state.search_cost_details.append( + { + "provider": "exa", + "query": search_query, + "cost": search_cost, + "timestamp": time.time(), + "step": "process_sub_question", + "sub_question": sub_question, + "question_index": question_index, + } + ) + logger.info( + f"Exa search cost for sub-question {question_index}: ${search_cost:.4f}" + ) + + search_results = {sub_question: results_list} + sub_state.update_search_results(search_results) + + search_phase_time = time.time() - search_phase_start + + # === 
INFORMATION SYNTHESIS === + synthesis_phase_start = time.time() + + # Extract raw contents and URLs + raw_contents = [] + sources = [] + for result in results_list: + raw_contents.append(result.content) + sources.append(result.url) + + # Prepare input for synthesis + synthesis_input = { + "sub_question": sub_question, + "search_results": raw_contents, + "sources": sources, + } + + # Synthesize information with prompt from bundle + synthesis_prompt = prompts_bundle.get_prompt_content("synthesis_prompt") + synthesis_result = synthesize_information( + synthesis_input=synthesis_input, + model=llm_model_synthesis, + system_prompt=synthesis_prompt, + project=langfuse_project_name, + ) + + # Create SynthesizedInfo object + synthesized_info = { + sub_question: SynthesizedInfo( + synthesized_answer=synthesis_result.get( + "synthesized_answer", f"Synthesis for '{sub_question}' failed." + ), + key_sources=synthesis_result.get("key_sources", sources[:1]), + confidence_level=synthesis_result.get("confidence_level", "low"), + information_gaps=synthesis_result.get( + "information_gaps", + "Synthesis process encountered technical difficulties.", + ), + improvements=synthesis_result.get("improvements", []), + ) + } + + # Update the state with synthesized information + sub_state.update_synthesized_info(synthesized_info) + + synthesis_phase_time = time.time() - synthesis_phase_start + total_execution_time = time.time() - start_time + + # Calculate total content length processed + total_content_length = sum(len(content) for content in raw_contents) + + # Get unique domains from sources + unique_domains = set() + for url in sources: + try: + from urllib.parse import urlparse + + domain = urlparse(url).netloc + unique_domains.add(domain) + except: + pass + + # Log comprehensive metadata + log_metadata( + metadata={ + "sub_question_processing": { + "question_index": question_index, + "status": "completed", + "sub_question": sub_question, + "execution_time_seconds": total_execution_time, + "search_phase_time_seconds": search_phase_time, + "synthesis_phase_time_seconds": synthesis_phase_time, + "search_query": search_query, + "search_provider": search_provider, + "search_mode": search_mode, + "num_results_requested": num_results_per_search, + "num_results_retrieved": len(results_list), + "total_content_length": total_content_length, + "cap_search_length": cap_search_length, + "unique_domains": list(unique_domains), + "llm_model_search": llm_model_search, + "llm_model_synthesis": llm_model_synthesis, + "confidence_level": synthesis_result.get( + "confidence_level", "low" + ), + "information_gaps": synthesis_result.get( + "information_gaps", "" + ), + "key_sources_count": len( + synthesis_result.get("key_sources", []) + ), + "search_cost": search_cost, + "search_cost_provider": "exa" + if search_provider + and search_provider.lower() in ["exa", "both"] + else None, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "search_metrics": { + "confidence_level": synthesis_result.get( + "confidence_level", "low" + ), + "search_provider": search_provider, + } + }, + infer_model=True, + ) + + # Log artifact metadata for the output state + log_metadata( + metadata={ + "sub_state_characteristics": { + "has_search_results": bool(sub_state.search_results), + "has_synthesized_info": bool(sub_state.synthesized_info), + "sub_question_processed": sub_question, + "confidence_level": synthesis_result.get( + "confidence_level", "low" + ), + } + }, + infer_artifact=True, + ) + + return 
sub_state diff --git a/deep_research/steps/pydantic_final_report_step.py b/deep_research/steps/pydantic_final_report_step.py new file mode 100644 index 00000000..dc05b1bb --- /dev/null +++ b/deep_research/steps/pydantic_final_report_step.py @@ -0,0 +1,1250 @@ +"""Final report generation step using Pydantic models and materializers. + +This module provides a ZenML pipeline step for generating the final HTML research report +using Pydantic models and improved materializers. +""" + +import html +import json +import logging +import re +import time +from typing import Annotated, Tuple + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.helper_functions import ( + extract_html_from_content, + remove_reasoning_from_output, +) +from utils.llm_utils import run_llm_completion +from utils.prompt_models import PromptsBundle +from utils.prompts import ( + STATIC_HTML_TEMPLATE, + SUB_QUESTION_TEMPLATE, + VIEWPOINT_ANALYSIS_TEMPLATE, +) +from utils.pydantic_models import ResearchState +from zenml import log_metadata, step +from zenml.types import HTMLString + +logger = logging.getLogger(__name__) + + +def clean_html_output(html_content: str) -> str: + """Clean HTML output from LLM to ensure proper rendering. + + This function removes markdown code blocks, fixes common issues with LLM HTML output, + and ensures we have proper HTML structure for rendering. + + Args: + html_content: Raw HTML content from LLM + + Returns: + Cleaned HTML content ready for rendering + """ + # Remove markdown code block markers (```html and ```) + html_content = re.sub(r"```html\s*", "", html_content) + html_content = re.sub(r"```\s*$", "", html_content) + html_content = re.sub(r"```", "", html_content) + + # Remove any CSS code block markers + html_content = re.sub(r"```css\s*", "", html_content) + + # Ensure HTML content is properly wrapped in HTML tags if not already + if not html_content.strip().startswith( + "{html_content}' + + html_content = re.sub(r"\[CSS STYLESHEET GOES HERE\]", "", html_content) + html_content = re.sub(r"\[SUB-QUESTIONS LINKS\]", "", html_content) + html_content = re.sub(r"\[ADDITIONAL SECTIONS LINKS\]", "", html_content) + html_content = re.sub(r"\[FOR EACH SUB-QUESTION\]:", "", html_content) + html_content = re.sub(r"\[FOR EACH TENSION\]:", "", html_content) + + # Replace content placeholders with appropriate defaults + html_content = re.sub( + r"\[CONCISE SUMMARY OF KEY FINDINGS\]", + "Summary of findings from the research query.", + html_content, + ) + html_content = re.sub( + r"\[INTRODUCTION TO THE RESEARCH QUERY\]", + "Introduction to the research topic.", + html_content, + ) + html_content = re.sub( + r"\[OVERVIEW OF THE APPROACH AND SUB-QUESTIONS\]", + "Overview of the research approach.", + html_content, + ) + html_content = re.sub( + r"\[CONCLUSION TEXT\]", + "Conclusion of the research findings.", + html_content, + ) + + return html_content + + +def format_text_with_code_blocks(text: str) -> str: + """Format text with proper handling of code blocks and markdown formatting. + + Args: + text: The raw text to format + + Returns: + str: HTML-formatted text + """ + if not text: + return "" + + # First escape HTML + escaped_text = html.escape(text) + + # Handle code blocks (wrap content in ``` or ```) + pattern = r"```(?:\w*\n)?(.*?)```" + + def code_block_replace(match): + code_content = match.group(1) + # Strip extra newlines at beginning and end + code_content = code_content.strip("\n") + return f"
<pre><code>{code_content}</code></pre>"
+
+    # Replace code blocks
+    formatted_text = re.sub(
+        pattern, code_block_replace, escaped_text, flags=re.DOTALL
+    )
+
+    # Convert regular newlines to <br/> tags (but not inside <pre> blocks)
+    parts = []
+    in_pre = False
+    for line in formatted_text.split("\n"):
+        if "<pre>" in line:
+            in_pre = True
+            parts.append(line)
+        elif "</pre>" in line:
+            in_pre = False
+            parts.append(line)
+        elif in_pre:
+            # Inside a code block, preserve newlines
+            parts.append(line)
+        else:
+            # Outside code blocks, convert newlines to <br/>
+            parts.append(line + "<br/>
") + + return "".join(parts) + + +def generate_executive_summary( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate an executive summary using LLM based on research findings. + + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for generation + langfuse_project_name: Name of the Langfuse project for tracking + + Returns: + HTML formatted executive summary + """ + logger.info("Generating executive summary using LLM") + + # Prepare the context with all research findings + context = f"Main Research Query: {state.main_query}\n\n" + + # Add synthesized findings for each sub-question + for i, sub_question in enumerate(state.sub_questions, 1): + info = state.enhanced_info.get( + sub_question + ) or state.synthesized_info.get(sub_question) + if info: + context += f"Sub-question {i}: {sub_question}\n" + context += f"Answer Summary: {info.synthesized_answer[:500]}...\n" + context += f"Confidence: {info.confidence_level}\n" + context += f"Key Sources: {', '.join(info.key_sources[:3]) if info.key_sources else 'N/A'}\n\n" + + # Add viewpoint analysis insights if available + if state.viewpoint_analysis: + context += "Key Areas of Agreement:\n" + for agreement in state.viewpoint_analysis.main_points_of_agreement[:3]: + context += f"- {agreement}\n" + context += "\nKey Tensions:\n" + for tension in state.viewpoint_analysis.areas_of_tension[:2]: + context += f"- {tension.topic}\n" + + # Get the executive summary prompt + try: + executive_summary_prompt = prompts_bundle.get_prompt_content( + "executive_summary_prompt" + ) + logger.info("Successfully retrieved executive_summary_prompt") + except Exception as e: + logger.error(f"Failed to get executive_summary_prompt: {e}") + logger.info( + f"Available prompts: {list(prompts_bundle.list_all_prompts().keys())}" + ) + return generate_fallback_executive_summary(state) + + try: + # Call LLM to generate executive summary + result = run_llm_completion( + prompt=context, + system_prompt=executive_summary_prompt, + model=llm_model, + temperature=0.7, + max_tokens=800, + project=langfuse_project_name, + tags=["executive_summary_generation"], + ) + + if result: + content = remove_reasoning_from_output(result) + # Clean up the HTML + content = extract_html_from_content(content) + logger.info("Successfully generated LLM-based executive summary") + return content + else: + logger.warning("Failed to generate executive summary via LLM") + return generate_fallback_executive_summary(state) + + except Exception as e: + logger.error(f"Error generating executive summary: {e}") + return generate_fallback_executive_summary(state) + + +def generate_introduction( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate an introduction using LLM based on research query and sub-questions. 
+ + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for generation + langfuse_project_name: Name of the Langfuse project for tracking + + Returns: + HTML formatted introduction + """ + logger.info("Generating introduction using LLM") + + # Prepare the context + context = f"Main Research Query: {state.main_query}\n\n" + context += "Sub-questions being explored:\n" + for i, sub_question in enumerate(state.sub_questions, 1): + context += f"{i}. {sub_question}\n" + + # Get the introduction prompt + try: + introduction_prompt = prompts_bundle.get_prompt_content( + "introduction_prompt" + ) + logger.info("Successfully retrieved introduction_prompt") + except Exception as e: + logger.error(f"Failed to get introduction_prompt: {e}") + logger.info( + f"Available prompts: {list(prompts_bundle.list_all_prompts().keys())}" + ) + return generate_fallback_introduction(state) + + try: + # Call LLM to generate introduction + result = run_llm_completion( + prompt=context, + system_prompt=introduction_prompt, + model=llm_model, + temperature=0.7, + max_tokens=600, + project=langfuse_project_name, + tags=["introduction_generation"], + ) + + if result: + content = remove_reasoning_from_output(result) + # Clean up the HTML + content = extract_html_from_content(content) + logger.info("Successfully generated LLM-based introduction") + return content + else: + logger.warning("Failed to generate introduction via LLM") + return generate_fallback_introduction(state) + + except Exception as e: + logger.error(f"Error generating introduction: {e}") + return generate_fallback_introduction(state) + + +def generate_fallback_executive_summary(state: ResearchState) -> str: + """Generate a fallback executive summary when LLM fails.""" + summary = f"

<p>This report examines the question: {html.escape(state.main_query)}</p>"
+    summary += f"<p>The research explored {len(state.sub_questions)} key dimensions of this topic, "
+    summary += "synthesizing findings from multiple sources to provide a comprehensive analysis.</p>"
+
+    # Add confidence overview
+    confidence_counts = {"high": 0, "medium": 0, "low": 0}
+    for info in state.enhanced_info.values():
+        level = info.confidence_level.lower()
+        if level in confidence_counts:
+            confidence_counts[level] += 1
+
+    summary += f"<p>Overall confidence in findings: {confidence_counts['high']} high, "
+    summary += f"{confidence_counts['medium']} medium, {confidence_counts['low']} low.</p>"
+
+    return summary
+
+
+def generate_fallback_introduction(state: ResearchState) -> str:
+    """Generate a fallback introduction when LLM fails."""
+    intro = f"<p>This report addresses the research query: {html.escape(state.main_query)}</p>"
+    intro += f"<p>The research was conducted by breaking down the main query into {len(state.sub_questions)} "
+    intro += (
+        "sub-questions to explore different aspects of the topic in depth. "
+    )
+    intro += "Each sub-question was researched independently, with findings synthesized from various sources.</p>
" + return intro + + +def generate_conclusion( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate a comprehensive conclusion using LLM based on all research findings. + + Args: + state: The ResearchState containing all research findings + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for conclusion generation + + Returns: + str: HTML-formatted conclusion content + """ + logger.info("Generating comprehensive conclusion using LLM") + + # Prepare input data for conclusion generation + conclusion_input = { + "main_query": state.main_query, + "sub_questions": state.sub_questions, + "enhanced_info": {}, + } + + # Include enhanced information for each sub-question + for question in state.sub_questions: + if question in state.enhanced_info: + info = state.enhanced_info[question] + conclusion_input["enhanced_info"][question] = { + "synthesized_answer": info.synthesized_answer, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + "key_sources": info.key_sources, + "improvements": getattr(info, "improvements", []), + } + elif question in state.synthesized_info: + # Fallback to synthesized info if enhanced info not available + info = state.synthesized_info[question] + conclusion_input["enhanced_info"][question] = { + "synthesized_answer": info.synthesized_answer, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + "key_sources": info.key_sources, + "improvements": [], + } + + # Include viewpoint analysis if available + if state.viewpoint_analysis: + conclusion_input["viewpoint_analysis"] = { + "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, + "areas_of_tension": [ + {"topic": tension.topic, "viewpoints": tension.viewpoints} + for tension in state.viewpoint_analysis.areas_of_tension + ], + "perspective_gaps": state.viewpoint_analysis.perspective_gaps, + "integrative_insights": state.viewpoint_analysis.integrative_insights, + } + + # Include reflection metadata if available + if state.reflection_metadata: + conclusion_input["reflection_metadata"] = { + "critique_summary": state.reflection_metadata.critique_summary, + "additional_questions_identified": state.reflection_metadata.additional_questions_identified, + "improvements_made": state.reflection_metadata.improvements_made, + } + + try: + # Get the prompt from the bundle + conclusion_prompt = prompts_bundle.get_prompt_content( + "conclusion_generation_prompt" + ) + + # Generate conclusion using LLM + conclusion_html = run_llm_completion( + prompt=json.dumps(conclusion_input, indent=2), + system_prompt=conclusion_prompt, + model=llm_model, + clean_output=True, + max_tokens=1500, # Sufficient for comprehensive conclusion + project=langfuse_project_name, + ) + + # Clean up any formatting issues + conclusion_html = conclusion_html.strip() + + # Remove any h2 tags with "Conclusion" text that LLM might have added + # Since we already have a Conclusion header in the template + conclusion_html = re.sub( + r"]*>\s*Conclusion\s*

</h2>\s*",
+            "",
+            conclusion_html,
+            flags=re.IGNORECASE,
+        )
+        conclusion_html = re.sub(
+            r"<h3[^>]*>\s*Conclusion\s*</h3>\s*",
+            "",
+            conclusion_html,
+            flags=re.IGNORECASE,
+        )
+
+        # Also remove plain text "Conclusion" at the start if it exists
+        conclusion_html = re.sub(
+            r"^Conclusion\s*\n*",
+            "",
+            conclusion_html.strip(),
+            flags=re.IGNORECASE,
+        )
+
+        if not conclusion_html.startswith("<p>"):
+            # Wrap in paragraph tags if not already formatted
+            conclusion_html = f"<p>{conclusion_html}</p>"
+
+        logger.info("Successfully generated LLM-based conclusion")
+        return conclusion_html
+
+    except Exception as e:
+        logger.warning(f"Failed to generate LLM conclusion: {e}")
+        # Return a basic fallback conclusion
+        return f"""<p>This report has explored {html.escape(state.main_query)} through a structured research approach, examining {len(state.sub_questions)} focused sub-questions and synthesizing information from diverse sources. The findings provide a comprehensive understanding of the topic, highlighting key aspects, perspectives, and current knowledge.</p>
+<p>While some information gaps remain, as noted in the respective sections, this research provides a solid foundation for understanding the topic and its implications.</p>
""" + + +def generate_report_from_template( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate a final HTML report from a static template. + + Instead of using an LLM to generate HTML, this function uses predefined HTML + templates and populates them with data from the research state. + + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for conclusion generation + + Returns: + str: The HTML content of the report + """ + logger.info( + f"Generating templated HTML report for query: {state.main_query}" + ) + + # Generate table of contents for sub-questions + sub_questions_toc = "" + for i, question in enumerate(state.sub_questions, 1): + safe_id = f"question-{i}" + sub_questions_toc += ( + f'
<li><a href="#{safe_id}">{html.escape(question)}</a></li>\n'
+        )
+
+    # Add viewpoint analysis to TOC if available
+    additional_sections_toc = ""
+    if state.viewpoint_analysis:
+        additional_sections_toc += (
+            '<li><a href="#viewpoint-analysis">Viewpoint Analysis</a></li>\n'
+        )
+
+    # Generate HTML for sub-questions
+    sub_questions_html = ""
+    all_sources = set()
+
+    for i, question in enumerate(state.sub_questions, 1):
+        info = state.enhanced_info.get(question, None)
+
+        # Skip if no information is available
+        if not info:
+            continue
+
+        # Process confidence level
+        confidence = info.confidence_level.lower()
+        confidence_upper = info.confidence_level.upper()
+
+        # Process key sources
+        key_sources_html = ""
+        if info.key_sources:
+            all_sources.update(info.key_sources)
+            sources_list = "\n".join(
+                [
+                    f'<li><a href="{source}">{html.escape(source)}</a></li>'
+                    if source.startswith(("http://", "https://"))
+                    else f"<li>{html.escape(source)}</li>"
+                    for source in info.key_sources
+                ]
+            )
+            key_sources_html = f"""
+            <div>
+                <h4>📚 Key Sources</h4>
+                <ul>
+                    {sources_list}
+                </ul>
+            </div>
+            """
+
+        # Process information gaps
+        info_gaps_html = ""
+        if info.information_gaps:
+            info_gaps_html = f"""
+            <div>
+                <h4>🧩 Information Gaps</h4>
+                <p>{format_text_with_code_blocks(info.information_gaps)}</p>
+            </div>
    + """ + + # Determine confidence icon based on level + confidence_icon = "🔴" # Default (low) + if confidence_upper == "HIGH": + confidence_icon = "🟢" + elif confidence_upper == "MEDIUM": + confidence_icon = "🟡" + + # Format the subquestion section using the template + sub_question_html = SUB_QUESTION_TEMPLATE.format( + index=i, + question=html.escape(question), + confidence=confidence, + confidence_upper=confidence_upper, + confidence_icon=confidence_icon, + answer=format_text_with_code_blocks(info.synthesized_answer), + info_gaps_html=info_gaps_html, + key_sources_html=key_sources_html, + ) + + sub_questions_html += sub_question_html + + # Generate viewpoint analysis HTML if available + viewpoint_analysis_html = "" + if state.viewpoint_analysis: + # Format points of agreement + agreements_html = "" + for point in state.viewpoint_analysis.main_points_of_agreement: + agreements_html += f"
<li>{html.escape(point)}</li>\n"
+
+        # Format areas of tension
+        tensions_html = ""
+        for tension in state.viewpoint_analysis.areas_of_tension:
+            viewpoints_html = ""
+            for title, content in tension.viewpoints.items():
+                # Create category-specific styling
+                category_class = f"category-{title.lower()}"
+                category_title = title.capitalize()
+
+                viewpoints_html += f"""
+                <div class="{category_class}">
+                    <strong>{category_title}</strong>
+                    <p>{html.escape(content)}</p>
+                </div>
+                """
+
+            tensions_html += f"""
+            <div>
+                <h4>{html.escape(tension.topic)}</h4>
+                <div>
+                    {viewpoints_html}
+                </div>
+            </div>
    + """ + + # Format the viewpoint analysis section using the template + viewpoint_analysis_html = VIEWPOINT_ANALYSIS_TEMPLATE.format( + agreements_html=agreements_html, + tensions_html=tensions_html, + perspective_gaps=format_text_with_code_blocks( + state.viewpoint_analysis.perspective_gaps + ), + integrative_insights=format_text_with_code_blocks( + state.viewpoint_analysis.integrative_insights + ), + ) + + # Generate references HTML + references_html = "" + + # Generate dynamic executive summary using LLM + logger.info("Generating dynamic executive summary...") + executive_summary = generate_executive_summary( + state, prompts_bundle, llm_model, langfuse_project_name + ) + logger.info( + f"Executive summary generated: {len(executive_summary)} characters" + ) + + # Generate dynamic introduction using LLM + logger.info("Generating dynamic introduction...") + introduction_html = generate_introduction( + state, prompts_bundle, llm_model, langfuse_project_name + ) + logger.info(f"Introduction generated: {len(introduction_html)} characters") + + # Generate comprehensive conclusion using LLM + conclusion_html = generate_conclusion( + state, prompts_bundle, llm_model, langfuse_project_name + ) + + # Generate complete HTML report + html_content = STATIC_HTML_TEMPLATE.format( + main_query=html.escape(state.main_query), + sub_questions_toc=sub_questions_toc, + additional_sections_toc=additional_sections_toc, + executive_summary=executive_summary, + introduction_html=introduction_html, + num_sub_questions=len(state.sub_questions), + sub_questions_html=sub_questions_html, + viewpoint_analysis_html=viewpoint_analysis_html, + conclusion_html=conclusion_html, + references_html=references_html, + ) + + return html_content + + +def _generate_fallback_report(state: ResearchState) -> str: + """Generate a minimal fallback report when the main report generation fails. + + This function creates a simplified HTML report with a consistent structure when + the main report generation process encounters an error. The HTML includes: + - A header section with the main research query + - An error notice + - Introduction section + - Individual sections for each sub-question with available answers + - A references section if sources are available + + Args: + state: The current research state containing query and answer information + + Returns: + str: A basic HTML report with a standard research report structure + """ + # Create a simple HTML structure with embedded CSS for styling + html = f""" + + + + + + + +
+    <h1>Research Report: {state.main_query}</h1>
+
+    <p>Note: This is a fallback report generated due to an error in the report generation process.</p>
+
+    <div>
+        <h2>Table of Contents</h2>
+        <ul>
+            <li><a href="#introduction">Introduction</a></li>
+"""
+
+    # Add TOC entries for each sub-question
+    for i, sub_question in enumerate(state.sub_questions):
+        safe_id = f"question-{i + 1}"
+        html += f'<li><a href="#{safe_id}">{sub_question}</a></li>\n'
+
+    html += """
+            <li><a href="#references">References</a></li>
+        </ul>
+    </div>
+
+    <div>
+        <h2>Executive Summary</h2>
+        <p>This report presents findings related to the main research query. It explores multiple aspects of the topic through structured sub-questions and synthesizes information from various sources.</p>
+    </div>
+
+    <div id="introduction">
+        <h2>Introduction</h2>
+        <p>This report addresses the research query: "{state.main_query}"</p>
+        <p>The analysis is structured around {len(state.sub_questions)} sub-questions that explore different dimensions of this topic.</p>
+    </div>
+"""
+
+    # Add each sub-question and its synthesized information
+    for i, sub_question in enumerate(state.sub_questions):
+        safe_id = f"question-{i + 1}"
+        info = state.enhanced_info.get(sub_question, None)
+
+        if not info:
+            # Try to get from synthesized info if not in enhanced info
+            info = state.synthesized_info.get(sub_question, None)
+
+        if info:
+            answer = info.synthesized_answer
+            confidence = info.confidence_level
+
+            # Add appropriate confidence class
+            confidence_class = ""
+            if confidence == "high":
+                confidence_class = "confidence-high"
+            elif confidence == "medium":
+                confidence_class = "confidence-medium"
+            elif confidence == "low":
+                confidence_class = "confidence-low"
+
+            html += f"""
+    <div id="{safe_id}">
+        <h2>{i + 1}. {sub_question}</h2>
+        <p class="{confidence_class}">Confidence Level: {confidence.upper()}</p>
+        <p>{answer}</p>
+    """
+
+            # Add information gaps if available
+            if hasattr(info, "information_gaps") and info.information_gaps:
+                html += f"""
+        <h3>Information Gaps</h3>
+        <p>{info.information_gaps}</p>
+    """
+
+            # Add improvements if available
+            if hasattr(info, "improvements") and info.improvements:
+                html += """
+        <h3>Improvements Made</h3>
+        <ul>
+    """
+
+                for improvement in info.improvements:
+                    html += f"<li>{improvement}</li>\n"
+
+                html += """
+        </ul>
+    """
+
+            # Add key sources if available
+            if hasattr(info, "key_sources") and info.key_sources:
+                html += """
+        <h3>Key Sources</h3>
+        <ul>
+    """
+
+                for source in info.key_sources:
+                    html += f"<li>{source}</li>\n"
+
+                html += """
+        </ul>
+    """
+
+            html += """
+    </div>
+    """
+        else:
+            html += f"""
+    <div id="{safe_id}">
+        <h2>{i + 1}. {sub_question}</h2>
+        <p>No information available for this question.</p>
+    </div>
+    """
+
+    # Add conclusion section
+    html += """
+    <div>
+        <h2>Conclusion</h2>
+        <p>This report has explored the research query through multiple sub-questions, providing synthesized information based on available sources. While limitations exist in some areas, the report provides a structured analysis of the topic.</p>
+    </div>
+    """
+
+    # Add sources if available
+    sources_set = set()
+    for info in state.enhanced_info.values():
+        if info.key_sources:
+            sources_set.update(info.key_sources)
+
+    if sources_set:
+        html += """
+    <div id="references">
+        <h2>References</h2>
+        <ul>
+    """
+
+        for source in sorted(sources_set):
+            html += f"<li>{source}</li>\n"
+
+        html += """
+        </ul>
+    </div>
+    """
+    else:
+        html += """
+    <div id="references">
+        <h2>References</h2>
+        <p>No references available.</p>
+    </div>
+    """
+
+    # Close the HTML structure
+    html += """
    + + + """ + + return html + + +@step( + output_materializers={ + "state": ResearchStateMaterializer, + } +) +def pydantic_final_report_step( + state: ResearchState, + prompts_bundle: PromptsBundle, + use_static_template: bool = True, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> Tuple[ + Annotated[ResearchState, "state"], + Annotated[HTMLString, "report_html"], +]: + """Generate the final research report in HTML format using Pydantic models. + + This step uses the Pydantic models and materializers to generate a final + HTML report and return both the updated state and the HTML report as + separate artifacts. + + Args: + state: The current research state (Pydantic model) + prompts_bundle: Bundle containing all prompts for the pipeline + use_static_template: Whether to use a static template instead of LLM generation + llm_model: The model to use for report generation with provider prefix + + Returns: + A tuple containing the updated research state and the HTML report + """ + start_time = time.time() + logger.info("Generating final research report using Pydantic models") + + if use_static_template: + # Use the static HTML template approach + logger.info("Using static HTML template for report generation") + html_content = generate_report_from_template( + state, prompts_bundle, llm_model, langfuse_project_name + ) + + # Update the state with the final report HTML + state.set_final_report(html_content) + + # Collect metadata about the report + execution_time = time.time() - start_time + + # Count sources + all_sources = set() + for info in state.enhanced_info.values(): + if info.key_sources: + all_sources.update(info.key_sources) + + # Count confidence levels + confidence_distribution = {"high": 0, "medium": 0, "low": 0} + for info in state.enhanced_info.values(): + level = info.confidence_level.lower() + if level in confidence_distribution: + confidence_distribution[level] += 1 + + # Log metadata + log_metadata( + metadata={ + "report_generation": { + "execution_time_seconds": execution_time, + "generation_method": "static_template", + "llm_model": llm_model, + "report_length_chars": len(html_content), + "num_sub_questions": len(state.sub_questions), + "num_sources": len(all_sources), + "has_viewpoint_analysis": bool(state.viewpoint_analysis), + "has_reflection": bool(state.reflection_metadata), + "confidence_distribution": confidence_distribution, + "fallback_report": False, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_quality": { + "confidence_distribution": confidence_distribution, + } + }, + infer_model=True, + ) + + # Log artifact metadata for the HTML report + log_metadata( + metadata={ + "html_report_characteristics": { + "size_bytes": len(html_content.encode("utf-8")), + "has_toc": "toc" in html_content.lower(), + "has_executive_summary": "executive summary" + in html_content.lower(), + "has_conclusion": "conclusion" in html_content.lower(), + "has_references": "references" in html_content.lower(), + } + }, + infer_artifact=True, + artifact_name="report_html", + ) + + logger.info( + "Final research report generated successfully with static template" + ) + return state, HTMLString(html_content) + + # Otherwise use the LLM-generated approach + # Convert Pydantic model to dict for LLM input + report_input = { + "main_query": state.main_query, + "sub_questions": state.sub_questions, + "synthesized_information": state.enhanced_info, + } + + if 
state.viewpoint_analysis: + report_input["viewpoint_analysis"] = state.viewpoint_analysis + + if state.reflection_metadata: + report_input["reflection_metadata"] = state.reflection_metadata + + # Generate the report + try: + logger.info(f"Calling {llm_model} to generate final report") + + # Get the prompt from the bundle + report_prompt = prompts_bundle.get_prompt_content( + "report_generation_prompt" + ) + + # Use the utility function to run LLM completion + html_content = run_llm_completion( + prompt=json.dumps(report_input), + system_prompt=report_prompt, + model=llm_model, + clean_output=False, # Don't clean in case of breaking HTML formatting + max_tokens=4000, # Increased token limit for detailed report generation + project=langfuse_project_name, + ) + + # Clean up any JSON wrapper or other artifacts + html_content = remove_reasoning_from_output(html_content) + + # Process the HTML content to remove code block markers and fix common issues + html_content = clean_html_output(html_content) + + # Basic validation of HTML content + if not html_content.strip().startswith("<"): + logger.warning( + "Generated content does not appear to be valid HTML" + ) + # Try to extract HTML if it might be wrapped in code blocks or JSON + html_content = extract_html_from_content(html_content) + + # Update the state with the final report HTML + state.set_final_report(html_content) + + # Collect metadata about the report + execution_time = time.time() - start_time + + # Count sources + all_sources = set() + for info in state.enhanced_info.values(): + if info.key_sources: + all_sources.update(info.key_sources) + + # Count confidence levels + confidence_distribution = {"high": 0, "medium": 0, "low": 0} + for info in state.enhanced_info.values(): + level = info.confidence_level.lower() + if level in confidence_distribution: + confidence_distribution[level] += 1 + + # Log metadata + log_metadata( + metadata={ + "report_generation": { + "execution_time_seconds": execution_time, + "generation_method": "llm_generated", + "llm_model": llm_model, + "report_length_chars": len(html_content), + "num_sub_questions": len(state.sub_questions), + "num_sources": len(all_sources), + "has_viewpoint_analysis": bool(state.viewpoint_analysis), + "has_reflection": bool(state.reflection_metadata), + "confidence_distribution": confidence_distribution, + "fallback_report": False, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_quality": { + "confidence_distribution": confidence_distribution, + } + }, + infer_model=True, + ) + + logger.info("Final research report generated successfully") + return state, HTMLString(html_content) + + except Exception as e: + logger.error(f"Error generating final report: {e}") + # Generate a minimal fallback report + fallback_html = _generate_fallback_report(state) + + # Process the fallback HTML to ensure it's clean + fallback_html = clean_html_output(fallback_html) + + # Update the state with the fallback report + state.set_final_report(fallback_html) + + # Collect metadata about the fallback report + execution_time = time.time() - start_time + + # Count sources + all_sources = set() + for info in state.enhanced_info.values(): + if info.key_sources: + all_sources.update(info.key_sources) + + # Count confidence levels + confidence_distribution = {"high": 0, "medium": 0, "low": 0} + for info in state.enhanced_info.values(): + level = info.confidence_level.lower() + if level in confidence_distribution: + confidence_distribution[level] += 1 + + # 
Log metadata for fallback report + log_metadata( + metadata={ + "report_generation": { + "execution_time_seconds": execution_time, + "generation_method": "fallback", + "llm_model": llm_model, + "report_length_chars": len(fallback_html), + "num_sub_questions": len(state.sub_questions), + "num_sources": len(all_sources), + "has_viewpoint_analysis": bool(state.viewpoint_analysis), + "has_reflection": bool(state.reflection_metadata), + "confidence_distribution": confidence_distribution, + "fallback_report": True, + "error_message": str(e), + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_quality": { + "confidence_distribution": confidence_distribution, + } + }, + infer_model=True, + ) + + return state, HTMLString(fallback_html) diff --git a/deep_research/steps/query_decomposition_step.py b/deep_research/steps/query_decomposition_step.py new file mode 100644 index 00000000..78e50b9f --- /dev/null +++ b/deep_research/steps/query_decomposition_step.py @@ -0,0 +1,174 @@ +import logging +import time +from typing import Annotated + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.llm_utils import get_structured_llm_output +from utils.prompt_models import PromptsBundle +from utils.pydantic_models import ResearchState +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ResearchStateMaterializer) +def initial_query_decomposition_step( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + max_sub_questions: int = 8, + langfuse_project_name: str = "deep-research", +) -> Annotated[ResearchState, "updated_state"]: + """Break down a complex research query into specific sub-questions. + + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The reasoning model to use with provider prefix + max_sub_questions: Maximum number of sub-questions to generate + + Returns: + Updated research state with sub-questions + """ + start_time = time.time() + logger.info(f"Decomposing research query: {state.main_query}") + + # Get the prompt from the bundle + system_prompt = prompts_bundle.get_prompt_content( + "query_decomposition_prompt" + ) + + try: + # Call OpenAI API to decompose the query + updated_system_prompt = ( + system_prompt + + f"\nPlease generate at most {max_sub_questions} sub-questions." 
+ ) + logger.info( + f"Calling {llm_model} to decompose query into max {max_sub_questions} sub-questions" + ) + + # Define fallback questions + fallback_questions = [ + { + "sub_question": f"What is {state.main_query}?", + "reasoning": "Basic understanding of the topic", + }, + { + "sub_question": f"What are the key aspects of {state.main_query}?", + "reasoning": "Exploring important dimensions", + }, + { + "sub_question": f"What are the implications of {state.main_query}?", + "reasoning": "Understanding broader impact", + }, + ] + + # Use utility function to get structured output + decomposed_questions = get_structured_llm_output( + prompt=state.main_query, + system_prompt=updated_system_prompt, + model=llm_model, + fallback_response=fallback_questions, + project=langfuse_project_name, + ) + + # Extract just the sub-questions + sub_questions = [ + item.get("sub_question") + for item in decomposed_questions + if "sub_question" in item + ] + + # Limit to max_sub_questions + sub_questions = sub_questions[:max_sub_questions] + + logger.info(f"Generated {len(sub_questions)} sub-questions") + for i, question in enumerate(sub_questions, 1): + logger.info(f" {i}. {question}") + + # Update the state with the new sub-questions + state.update_sub_questions(sub_questions) + + # Log step metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "query_decomposition": { + "execution_time_seconds": execution_time, + "num_sub_questions": len(sub_questions), + "llm_model": llm_model, + "max_sub_questions_requested": max_sub_questions, + "fallback_used": False, + "main_query_length": len(state.main_query), + "sub_questions": sub_questions, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_scope": { + "num_sub_questions": len(sub_questions), + } + }, + infer_model=True, + ) + + # Log artifact metadata for the output state + log_metadata( + metadata={ + "state_characteristics": { + "total_sub_questions": len(state.sub_questions), + "has_search_results": bool(state.search_results), + "has_synthesized_info": bool(state.synthesized_info), + } + }, + infer_artifact=True, + ) + + return state + + except Exception as e: + logger.error(f"Error decomposing query: {e}") + # Return fallback questions in the state + fallback_questions = [ + f"What is {state.main_query}?", + f"What are the key aspects of {state.main_query}?", + f"What are the implications of {state.main_query}?", + ] + fallback_questions = fallback_questions[:max_sub_questions] + logger.info(f"Using {len(fallback_questions)} fallback questions:") + for i, question in enumerate(fallback_questions, 1): + logger.info(f" {i}. 
{question}") + state.update_sub_questions(fallback_questions) + + # Log metadata for fallback scenario + execution_time = time.time() - start_time + log_metadata( + metadata={ + "query_decomposition": { + "execution_time_seconds": execution_time, + "num_sub_questions": len(fallback_questions), + "llm_model": llm_model, + "max_sub_questions_requested": max_sub_questions, + "fallback_used": True, + "error_message": str(e), + "main_query_length": len(state.main_query), + "sub_questions": fallback_questions, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_scope": { + "num_sub_questions": len(fallback_questions), + } + }, + infer_model=True, + ) + + return state diff --git a/deep_research/tests/__init__.py b/deep_research/tests/__init__.py new file mode 100644 index 00000000..6206856b --- /dev/null +++ b/deep_research/tests/__init__.py @@ -0,0 +1 @@ +"""Test package for ZenML Deep Research project.""" diff --git a/deep_research/tests/conftest.py b/deep_research/tests/conftest.py new file mode 100644 index 00000000..b972a5e1 --- /dev/null +++ b/deep_research/tests/conftest.py @@ -0,0 +1,11 @@ +"""Test configuration for pytest. + +This file sets up the proper Python path for importing modules in tests. +""" + +import os +import sys + +# Add the project root directory to the Python path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, project_root) diff --git a/deep_research/tests/test_approval_utils.py b/deep_research/tests/test_approval_utils.py new file mode 100644 index 00000000..fe859b0f --- /dev/null +++ b/deep_research/tests/test_approval_utils.py @@ -0,0 +1,150 @@ +"""Unit tests for approval utility functions.""" + +from utils.approval_utils import ( + calculate_estimated_cost, + format_approval_request, + format_critique_summary, + format_query_list, + parse_approval_response, + summarize_research_progress, +) +from utils.pydantic_models import ResearchState, SynthesizedInfo + + +def test_parse_approval_responses(): + """Test parsing different approval responses.""" + queries = ["query1", "query2", "query3"] + + # Test approve all + decision = parse_approval_response("APPROVE ALL", queries) + assert decision.approved == True + assert decision.selected_queries == queries + assert decision.approval_method == "APPROVE_ALL" + + # Test skip + decision = parse_approval_response( + "skip", queries + ) # Test case insensitive + assert decision.approved == False + assert decision.selected_queries == [] + assert decision.approval_method == "SKIP" + + # Test selection + decision = parse_approval_response("SELECT 1,3", queries) + assert decision.approved == True + assert decision.selected_queries == ["query1", "query3"] + assert decision.approval_method == "SELECT_SPECIFIC" + + # Test invalid selection + decision = parse_approval_response("SELECT invalid", queries) + assert decision.approved == False + assert decision.approval_method == "PARSE_ERROR" + + # Test out of range indices + decision = parse_approval_response("SELECT 1,5,10", queries) + assert decision.approved == True + assert decision.selected_queries == ["query1"] # Only valid indices + assert decision.approval_method == "SELECT_SPECIFIC" + + # Test unknown response + decision = parse_approval_response("maybe later", queries) + assert decision.approved == False + assert decision.approval_method == "UNKNOWN_RESPONSE" + + +def test_format_approval_request(): + """Test formatting of approval request messages.""" + message = 
format_approval_request( + main_query="Test query", + progress_summary={ + "completed_count": 5, + "avg_confidence": 0.75, + "low_confidence_count": 2, + }, + critique_points=[ + {"issue": "Missing data", "importance": "high"}, + {"issue": "Minor gap", "importance": "low"}, + ], + proposed_queries=["query1", "query2"], + ) + + assert "Test query" in message + assert "5" in message + assert "0.75" in message + assert "2 queries" in message + assert "approve" in message.lower() + assert "reject" in message.lower() + assert "Missing data" in message + + +def test_summarize_research_progress(): + """Test research progress summarization.""" + state = ResearchState( + main_query="test", + synthesized_info={ + "q1": SynthesizedInfo( + synthesized_answer="a1", confidence_level="high" + ), + "q2": SynthesizedInfo( + synthesized_answer="a2", confidence_level="medium" + ), + "q3": SynthesizedInfo( + synthesized_answer="a3", confidence_level="low" + ), + "q4": SynthesizedInfo( + synthesized_answer="a4", confidence_level="low" + ), + }, + ) + + summary = summarize_research_progress(state) + + assert summary["completed_count"] == 4 + # (1.0 + 0.5 + 0.0 + 0.0) / 4 = 1.5 / 4 = 0.375, rounded to 0.38 + assert summary["avg_confidence"] == 0.38 + assert summary["low_confidence_count"] == 2 + + +def test_format_critique_summary(): + """Test critique summary formatting.""" + # Test with no critiques + result = format_critique_summary([]) + assert result == "No critical issues identified." + + # Test with few critiques + critiques = [{"issue": "Issue 1"}, {"issue": "Issue 2"}] + result = format_critique_summary(critiques) + assert "- Issue 1" in result + assert "- Issue 2" in result + assert "more issues" not in result + + # Test with many critiques + critiques = [{"issue": f"Issue {i}"} for i in range(5)] + result = format_critique_summary(critiques) + assert "- Issue 0" in result + assert "- Issue 1" in result + assert "- Issue 2" in result + assert "- Issue 3" not in result + assert "... and 2 more issues" in result + + +def test_format_query_list(): + """Test query list formatting.""" + # Test empty list + result = format_query_list([]) + assert result == "No queries proposed." + + # Test with queries + queries = ["Query A", "Query B", "Query C"] + result = format_query_list(queries) + assert "1. Query A" in result + assert "2. Query B" in result + assert "3. 
Query C" in result + + +def test_calculate_estimated_cost(): + """Test cost estimation.""" + assert calculate_estimated_cost([]) == 0.0 + assert calculate_estimated_cost(["q1"]) == 0.01 + assert calculate_estimated_cost(["q1", "q2", "q3"]) == 0.03 + assert calculate_estimated_cost(["q1"] * 10) == 0.10 diff --git a/deep_research/tests/test_prompt_loader.py b/deep_research/tests/test_prompt_loader.py new file mode 100644 index 00000000..fdc7b7da --- /dev/null +++ b/deep_research/tests/test_prompt_loader.py @@ -0,0 +1,108 @@ +"""Unit tests for prompt loader utilities.""" + +import pytest +from utils.prompt_loader import get_prompt_for_step, load_prompts_bundle +from utils.prompt_models import PromptsBundle + + +class TestPromptLoader: + """Test cases for prompt loader functions.""" + + def test_load_prompts_bundle(self): + """Test loading prompts bundle from prompts.py.""" + bundle = load_prompts_bundle(pipeline_version="2.0.0") + + # Check it returns a PromptsBundle + assert isinstance(bundle, PromptsBundle) + + # Check pipeline version + assert bundle.pipeline_version == "2.0.0" + + # Check all core prompts are loaded + assert bundle.search_query_prompt is not None + assert bundle.query_decomposition_prompt is not None + assert bundle.synthesis_prompt is not None + assert bundle.viewpoint_analysis_prompt is not None + assert bundle.reflection_prompt is not None + assert bundle.additional_synthesis_prompt is not None + assert bundle.conclusion_generation_prompt is not None + + # Check prompts have correct metadata + assert bundle.search_query_prompt.name == "search_query_prompt" + assert ( + bundle.search_query_prompt.description + == "Generates effective search queries from sub-questions" + ) + assert bundle.search_query_prompt.version == "1.0.0" + assert "search" in bundle.search_query_prompt.tags + + # Check that actual prompt content is loaded + assert "search query" in bundle.search_query_prompt.content.lower() + assert "json schema" in bundle.search_query_prompt.content.lower() + + def test_load_prompts_bundle_default_version(self): + """Test loading prompts bundle with default version.""" + bundle = load_prompts_bundle() + assert bundle.pipeline_version == "1.0.0" + + def test_get_prompt_for_step(self): + """Test getting prompt content for specific steps.""" + bundle = load_prompts_bundle() + + # Test valid step names + test_cases = [ + ("query_decomposition", "query_decomposition_prompt"), + ("search_query_generation", "search_query_prompt"), + ("synthesis", "synthesis_prompt"), + ("viewpoint_analysis", "viewpoint_analysis_prompt"), + ("reflection", "reflection_prompt"), + ("additional_synthesis", "additional_synthesis_prompt"), + ("conclusion_generation", "conclusion_generation_prompt"), + ] + + for step_name, expected_prompt_attr in test_cases: + content = get_prompt_for_step(bundle, step_name) + expected_content = getattr(bundle, expected_prompt_attr).content + assert content == expected_content + + def test_get_prompt_for_step_invalid(self): + """Test getting prompt for invalid step name.""" + bundle = load_prompts_bundle() + + with pytest.raises( + ValueError, match="No prompt mapping found for step: invalid_step" + ): + get_prompt_for_step(bundle, "invalid_step") + + def test_all_prompts_have_content(self): + """Test that all loaded prompts have non-empty content.""" + bundle = load_prompts_bundle() + + all_prompts = bundle.list_all_prompts() + for name, prompt in all_prompts.items(): + assert prompt.content, f"Prompt {name} has empty content" + assert ( + len(prompt.content) > 
50 + ), f"Prompt {name} content seems too short" + + def test_all_prompts_have_descriptions(self): + """Test that all loaded prompts have descriptions.""" + bundle = load_prompts_bundle() + + all_prompts = bundle.list_all_prompts() + for name, prompt in all_prompts.items(): + assert prompt.description, f"Prompt {name} has no description" + assert ( + len(prompt.description) > 10 + ), f"Prompt {name} description seems too short" + + def test_all_prompts_have_tags(self): + """Test that all loaded prompts have at least one tag.""" + bundle = load_prompts_bundle() + + all_prompts = bundle.list_all_prompts() + for name, prompt in all_prompts.items(): + assert prompt.tags, f"Prompt {name} has no tags" + assert ( + len(prompt.tags) >= 1 + ), f"Prompt {name} should have at least one tag" diff --git a/deep_research/tests/test_prompt_models.py b/deep_research/tests/test_prompt_models.py new file mode 100644 index 00000000..f8b0519f --- /dev/null +++ b/deep_research/tests/test_prompt_models.py @@ -0,0 +1,183 @@ +"""Unit tests for prompt models and utilities.""" + +import pytest +from utils.prompt_models import PromptsBundle, PromptTemplate + + +class TestPromptTemplate: + """Test cases for PromptTemplate model.""" + + def test_prompt_template_creation(self): + """Test creating a prompt template with all fields.""" + prompt = PromptTemplate( + name="test_prompt", + content="This is a test prompt", + description="A test prompt for unit testing", + version="1.0.0", + tags=["test", "unit"], + ) + + assert prompt.name == "test_prompt" + assert prompt.content == "This is a test prompt" + assert prompt.description == "A test prompt for unit testing" + assert prompt.version == "1.0.0" + assert prompt.tags == ["test", "unit"] + + def test_prompt_template_minimal(self): + """Test creating a prompt template with minimal fields.""" + prompt = PromptTemplate( + name="minimal_prompt", content="Minimal content" + ) + + assert prompt.name == "minimal_prompt" + assert prompt.content == "Minimal content" + assert prompt.description == "" + assert prompt.version == "1.0.0" + assert prompt.tags == [] + + +class TestPromptsBundle: + """Test cases for PromptsBundle model.""" + + @pytest.fixture + def sample_prompts(self): + """Create sample prompts for testing.""" + return { + "search_query_prompt": PromptTemplate( + name="search_query_prompt", + content="Search query content", + description="Generates search queries", + tags=["search"], + ), + "query_decomposition_prompt": PromptTemplate( + name="query_decomposition_prompt", + content="Query decomposition content", + description="Decomposes queries", + tags=["analysis"], + ), + "synthesis_prompt": PromptTemplate( + name="synthesis_prompt", + content="Synthesis content", + description="Synthesizes information", + tags=["synthesis"], + ), + "viewpoint_analysis_prompt": PromptTemplate( + name="viewpoint_analysis_prompt", + content="Viewpoint analysis content", + description="Analyzes viewpoints", + tags=["analysis"], + ), + "reflection_prompt": PromptTemplate( + name="reflection_prompt", + content="Reflection content", + description="Reflects on research", + tags=["reflection"], + ), + "additional_synthesis_prompt": PromptTemplate( + name="additional_synthesis_prompt", + content="Additional synthesis content", + description="Additional synthesis", + tags=["synthesis"], + ), + "conclusion_generation_prompt": PromptTemplate( + name="conclusion_generation_prompt", + content="Conclusion generation content", + description="Generates conclusions", + tags=["report"], + ), + } + + 
def test_prompts_bundle_creation(self, sample_prompts): + """Test creating a prompts bundle.""" + bundle = PromptsBundle(**sample_prompts) + + assert bundle.search_query_prompt.name == "search_query_prompt" + assert ( + bundle.query_decomposition_prompt.name + == "query_decomposition_prompt" + ) + assert bundle.pipeline_version == "1.0.0" + assert isinstance(bundle.created_at, str) + assert bundle.custom_prompts == {} + + def test_prompts_bundle_with_custom_prompts(self, sample_prompts): + """Test creating a prompts bundle with custom prompts.""" + custom_prompt = PromptTemplate( + name="custom_prompt", + content="Custom prompt content", + description="A custom prompt", + ) + + bundle = PromptsBundle( + **sample_prompts, custom_prompts={"custom_prompt": custom_prompt} + ) + + assert "custom_prompt" in bundle.custom_prompts + assert bundle.custom_prompts["custom_prompt"].name == "custom_prompt" + + def test_get_prompt_by_name(self, sample_prompts): + """Test retrieving prompts by name.""" + bundle = PromptsBundle(**sample_prompts) + + # Test getting a core prompt + prompt = bundle.get_prompt_by_name("search_query_prompt") + assert prompt is not None + assert prompt.name == "search_query_prompt" + + # Test getting a non-existent prompt + prompt = bundle.get_prompt_by_name("non_existent") + assert prompt is None + + def test_get_prompt_by_name_custom(self, sample_prompts): + """Test retrieving custom prompts by name.""" + custom_prompt = PromptTemplate( + name="custom_prompt", content="Custom content" + ) + + bundle = PromptsBundle( + **sample_prompts, custom_prompts={"custom_prompt": custom_prompt} + ) + + prompt = bundle.get_prompt_by_name("custom_prompt") + assert prompt is not None + assert prompt.name == "custom_prompt" + + def test_list_all_prompts(self, sample_prompts): + """Test listing all prompts.""" + bundle = PromptsBundle(**sample_prompts) + + all_prompts = bundle.list_all_prompts() + assert len(all_prompts) == 7 # 7 core prompts + assert "search_query_prompt" in all_prompts + assert "conclusion_generation_prompt" in all_prompts + + def test_list_all_prompts_with_custom(self, sample_prompts): + """Test listing all prompts including custom ones.""" + custom_prompt = PromptTemplate( + name="custom_prompt", content="Custom content" + ) + + bundle = PromptsBundle( + **sample_prompts, custom_prompts={"custom_prompt": custom_prompt} + ) + + all_prompts = bundle.list_all_prompts() + assert len(all_prompts) == 8 # 7 core + 1 custom + assert "custom_prompt" in all_prompts + + def test_get_prompt_content(self, sample_prompts): + """Test getting prompt content by type.""" + bundle = PromptsBundle(**sample_prompts) + + content = bundle.get_prompt_content("search_query_prompt") + assert content == "Search query content" + + content = bundle.get_prompt_content("synthesis_prompt") + assert content == "Synthesis content" + + def test_get_prompt_content_invalid(self, sample_prompts): + """Test getting prompt content with invalid type.""" + bundle = PromptsBundle(**sample_prompts) + + with pytest.raises(AttributeError): + bundle.get_prompt_content("invalid_prompt_type") diff --git a/deep_research/tests/test_pydantic_final_report_step.py b/deep_research/tests/test_pydantic_final_report_step.py new file mode 100644 index 00000000..c0f13530 --- /dev/null +++ b/deep_research/tests/test_pydantic_final_report_step.py @@ -0,0 +1,167 @@ +"""Tests for the Pydantic-based final report step. 
+ +This module contains tests for the Pydantic-based implementation of +final_report_step, which uses the new Pydantic models and materializers. +""" + +from typing import Dict, List + +import pytest +from steps.pydantic_final_report_step import pydantic_final_report_step +from utils.pydantic_models import ( + ReflectionMetadata, + ResearchState, + SearchResult, + SynthesizedInfo, + ViewpointAnalysis, + ViewpointTension, +) +from zenml.types import HTMLString + + +@pytest.fixture +def sample_research_state() -> ResearchState: + """Create a sample research state for testing.""" + # Create a basic research state + state = ResearchState(main_query="What are the impacts of climate change?") + + # Add sub-questions + state.update_sub_questions(["Economic impacts", "Environmental impacts"]) + + # Add search results + search_results: Dict[str, List[SearchResult]] = { + "Economic impacts": [ + SearchResult( + url="https://example.com/economy", + title="Economic Impacts of Climate Change", + snippet="Overview of economic impacts", + content="Detailed content about economic impacts of climate change", + ) + ] + } + state.update_search_results(search_results) + + # Add synthesized info + synthesized_info: Dict[str, SynthesizedInfo] = { + "Economic impacts": SynthesizedInfo( + synthesized_answer="Climate change will have significant economic impacts...", + key_sources=["https://example.com/economy"], + confidence_level="high", + ), + "Environmental impacts": SynthesizedInfo( + synthesized_answer="Environmental impacts include rising sea levels...", + key_sources=["https://example.com/environment"], + confidence_level="high", + ), + } + state.update_synthesized_info(synthesized_info) + + # Add enhanced info (same as synthesized for this test) + state.enhanced_info = state.synthesized_info + + # Add viewpoint analysis + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Climate change is happening", + "Action is needed", + ], + areas_of_tension=[ + ViewpointTension( + topic="Economic policy", + viewpoints={ + "Progressive": "Support carbon taxes and regulations", + "Conservative": "Prefer market-based solutions", + }, + ) + ], + perspective_gaps="Indigenous perspectives are underrepresented", + integrative_insights="A balanced approach combining regulations and market incentives may be most effective", + ) + state.update_viewpoint_analysis(viewpoint_analysis) + + # Add reflection metadata + reflection_metadata = ReflectionMetadata( + critique_summary=["Need more sources for economic impacts"], + additional_questions_identified=[ + "How will climate change affect different regions?" 
+ ], + searches_performed=[ + "economic impacts of climate change", + "regional climate impacts", + ], + improvements_made=2, + ) + state.reflection_metadata = reflection_metadata + + return state + + +def test_pydantic_final_report_step_returns_tuple(): + """Test that the step returns a tuple with state and HTML.""" + # Create a simple state + state = ResearchState(main_query="What is climate change?") + state.update_sub_questions(["What causes climate change?"]) + + # Run the step + result = pydantic_final_report_step(state=state) + + # Assert that result is a tuple with 2 elements + assert isinstance(result, tuple) + assert len(result) == 2 + + # Assert first element is ResearchState + assert isinstance(result[0], ResearchState) + + # Assert second element is HTMLString + assert isinstance(result[1], HTMLString) + + +def test_pydantic_final_report_step_with_complex_state(sample_research_state): + """Test that the step handles a complex state properly.""" + # Run the step with a complex state + result = pydantic_final_report_step(state=sample_research_state) + + # Unpack the results + updated_state, html_report = result + + # Assert state contains final report HTML + assert updated_state.final_report_html != "" + + # Assert HTML report contains key elements + html_str = str(html_report) + assert "Economic impacts" in html_str + assert "Environmental impacts" in html_str + assert "Viewpoint Analysis" in html_str + assert "Progressive" in html_str + assert "Conservative" in html_str + + +def test_pydantic_final_report_step_updates_state(): + """Test that the step properly updates the state.""" + # Create an initial state without a final report + state = ResearchState( + main_query="What is climate change?", + sub_questions=["What causes climate change?"], + synthesized_info={ + "What causes climate change?": SynthesizedInfo( + synthesized_answer="Climate change is caused by greenhouse gases.", + confidence_level="high", + ) + }, + enhanced_info={ + "What causes climate change?": SynthesizedInfo( + synthesized_answer="Climate change is caused by greenhouse gases.", + confidence_level="high", + ) + }, + ) + + # Verify initial state has no report + assert state.final_report_html == "" + + # Run the step + updated_state, _ = pydantic_final_report_step(state=state) + + # Verify state was updated with a report + assert updated_state.final_report_html != "" + assert "climate change" in updated_state.final_report_html.lower() diff --git a/deep_research/tests/test_pydantic_materializer.py b/deep_research/tests/test_pydantic_materializer.py new file mode 100644 index 00000000..49cb17f8 --- /dev/null +++ b/deep_research/tests/test_pydantic_materializer.py @@ -0,0 +1,161 @@ +"""Tests for Pydantic-based materializer. + +This module contains tests for the Pydantic-based implementation of +ResearchStateMaterializer, verifying that it correctly serializes and +visualizes ResearchState objects. 
+""" + +import os +import tempfile +from typing import Dict, List + +import pytest +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.pydantic_models import ( + ResearchState, + SearchResult, + SynthesizedInfo, + ViewpointAnalysis, + ViewpointTension, +) + + +@pytest.fixture +def sample_state() -> ResearchState: + """Create a sample research state for testing.""" + # Create a basic research state + state = ResearchState(main_query="What are the impacts of climate change?") + + # Add sub-questions + state.update_sub_questions(["Economic impacts", "Environmental impacts"]) + + # Add search results + search_results: Dict[str, List[SearchResult]] = { + "Economic impacts": [ + SearchResult( + url="https://example.com/economy", + title="Economic Impacts of Climate Change", + snippet="Overview of economic impacts", + content="Detailed content about economic impacts of climate change", + ) + ] + } + state.update_search_results(search_results) + + # Add synthesized info + synthesized_info: Dict[str, SynthesizedInfo] = { + "Economic impacts": SynthesizedInfo( + synthesized_answer="Climate change will have significant economic impacts...", + key_sources=["https://example.com/economy"], + confidence_level="high", + ) + } + state.update_synthesized_info(synthesized_info) + + return state + + +def test_materializer_initialization(): + """Test that the materializer can be initialized.""" + # Create a temporary directory for artifact storage + with tempfile.TemporaryDirectory() as tmpdirname: + materializer = ResearchStateMaterializer(uri=tmpdirname) + assert materializer is not None + + +def test_materializer_save_and_load(sample_state: ResearchState): + """Test saving and loading a state using the materializer.""" + # Create a temporary directory for artifact storage + with tempfile.TemporaryDirectory() as tmpdirname: + # Initialize materializer with temporary artifact URI + materializer = ResearchStateMaterializer(uri=tmpdirname) + + # Save the state + materializer.save(sample_state) + + # Load the state + loaded_state = materializer.load(ResearchState) + + # Verify that the loaded state matches the original + assert loaded_state.main_query == sample_state.main_query + assert loaded_state.sub_questions == sample_state.sub_questions + assert len(loaded_state.search_results) == len( + sample_state.search_results + ) + assert ( + loaded_state.get_current_stage() + == sample_state.get_current_stage() + ) + + # Check that key fields were preserved + question = "Economic impacts" + assert ( + loaded_state.synthesized_info[question].synthesized_answer + == sample_state.synthesized_info[question].synthesized_answer + ) + assert ( + loaded_state.synthesized_info[question].confidence_level + == sample_state.synthesized_info[question].confidence_level + ) + + +def test_materializer_save_visualizations(sample_state: ResearchState): + """Test generating and saving visualizations.""" + # Create a temporary directory for artifact storage + with tempfile.TemporaryDirectory() as tmpdirname: + # Initialize materializer with temporary artifact URI + materializer = ResearchStateMaterializer(uri=tmpdirname) + + # Generate and save visualizations + viz_paths = materializer.save_visualizations(sample_state) + + # Verify visualization file exists + html_path = list(viz_paths.keys())[0] + assert os.path.exists(html_path) + + # Verify the file has content + with open(html_path, "r") as f: + content = f.read() + # Check for expected elements in the HTML + assert "Research State" in content + 
assert sample_state.main_query in content + assert "Economic impacts" in content + + +def test_html_generation_stages(sample_state: ResearchState): + """Test that HTML visualization reflects the correct research stage.""" + # Create the materializer + with tempfile.TemporaryDirectory() as tmpdirname: + materializer = ResearchStateMaterializer(uri=tmpdirname) + + # Generate visualization at initial state + html = materializer._generate_visualization_html(sample_state) + # Verify stage by checking for expected elements in the HTML + assert ( + "Synthesized Information" in html + ) # Should show synthesized info + + # Add viewpoint analysis + state_with_viewpoints = sample_state.model_copy(deep=True) + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=["There will be economic impacts"], + areas_of_tension=[ + ViewpointTension( + topic="Job impacts", + viewpoints={ + "Positive": "New green jobs", + "Negative": "Job losses", + }, + ) + ], + ) + state_with_viewpoints.update_viewpoint_analysis(viewpoint_analysis) + html = materializer._generate_visualization_html(state_with_viewpoints) + assert "Viewpoint Analysis" in html + assert "Points of Agreement" in html + + # Add final report + state_with_report = state_with_viewpoints.model_copy(deep=True) + state_with_report.set_final_report("Final report content") + html = materializer._generate_visualization_html(state_with_report) + assert "Final Report" in html diff --git a/deep_research/tests/test_pydantic_models.py b/deep_research/tests/test_pydantic_models.py new file mode 100644 index 00000000..b900f8d8 --- /dev/null +++ b/deep_research/tests/test_pydantic_models.py @@ -0,0 +1,303 @@ +"""Tests for Pydantic model implementations. + +This module contains tests for the Pydantic models that validate: +1. Basic model instantiation +2. Default values +3. Serialization and deserialization +4. 
Method functionality +""" + +import json +from typing import Dict, List + +from utils.pydantic_models import ( + ReflectionMetadata, + ResearchState, + SearchResult, + SynthesizedInfo, + ViewpointAnalysis, + ViewpointTension, +) + + +def test_search_result_creation(): + """Test creating a SearchResult model.""" + # Create with defaults + result = SearchResult() + assert result.url == "" + assert result.content == "" + assert result.title == "" + assert result.snippet == "" + + # Create with values + result = SearchResult( + url="https://example.com", + content="Example content", + title="Example Title", + snippet="This is a snippet", + ) + assert result.url == "https://example.com" + assert result.content == "Example content" + assert result.title == "Example Title" + assert result.snippet == "This is a snippet" + + +def test_search_result_serialization(): + """Test serializing and deserializing a SearchResult.""" + result = SearchResult( + url="https://example.com", + content="Example content", + title="Example Title", + snippet="This is a snippet", + ) + + # Serialize to dict + result_dict = result.model_dump() + assert result_dict["url"] == "https://example.com" + assert result_dict["content"] == "Example content" + + # Serialize to JSON + result_json = result.model_dump_json() + result_dict_from_json = json.loads(result_json) + assert result_dict_from_json["url"] == "https://example.com" + + # Deserialize from dict + new_result = SearchResult.model_validate(result_dict) + assert new_result.url == "https://example.com" + assert new_result.content == "Example content" + + # Deserialize from JSON + new_result_from_json = SearchResult.model_validate_json(result_json) + assert new_result_from_json.url == "https://example.com" + + +def test_viewpoint_tension_model(): + """Test the ViewpointTension model.""" + # Empty model + tension = ViewpointTension() + assert tension.topic == "" + assert tension.viewpoints == {} + + # With data + tension = ViewpointTension( + topic="Climate Change Impacts", + viewpoints={ + "Economic": "Focuses on financial costs and benefits", + "Environmental": "Emphasizes ecosystem impacts", + }, + ) + assert tension.topic == "Climate Change Impacts" + assert len(tension.viewpoints) == 2 + assert "Economic" in tension.viewpoints + + # Serialization + tension_dict = tension.model_dump() + assert tension_dict["topic"] == "Climate Change Impacts" + assert len(tension_dict["viewpoints"]) == 2 + + # Deserialization + new_tension = ViewpointTension.model_validate(tension_dict) + assert new_tension.topic == tension.topic + assert new_tension.viewpoints == tension.viewpoints + + +def test_synthesized_info_model(): + """Test the SynthesizedInfo model.""" + # Default values + info = SynthesizedInfo() + assert info.synthesized_answer == "" + assert info.key_sources == [] + assert info.confidence_level == "medium" + assert info.information_gaps == "" + assert info.improvements == [] + + # With values + info = SynthesizedInfo( + synthesized_answer="This is a synthesized answer", + key_sources=["https://source1.com", "https://source2.com"], + confidence_level="high", + information_gaps="Missing some context", + improvements=["Add more detail", "Check more sources"], + ) + assert info.synthesized_answer == "This is a synthesized answer" + assert len(info.key_sources) == 2 + assert info.confidence_level == "high" + + # Serialization and deserialization + info_dict = info.model_dump() + new_info = SynthesizedInfo.model_validate(info_dict) + assert new_info.synthesized_answer == 
info.synthesized_answer + assert new_info.key_sources == info.key_sources + + +def test_viewpoint_analysis_model(): + """Test the ViewpointAnalysis model.""" + # Create tensions for the analysis + tension1 = ViewpointTension( + topic="Economic Impact", + viewpoints={ + "Positive": "Creates jobs", + "Negative": "Increases inequality", + }, + ) + tension2 = ViewpointTension( + topic="Environmental Impact", + viewpoints={ + "Positive": "Reduces emissions", + "Negative": "Land use changes", + }, + ) + + # Create the analysis + analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Need for action", + "Technological innovation", + ], + areas_of_tension=[tension1, tension2], + perspective_gaps="Missing indigenous perspectives", + integrative_insights="Combined economic and environmental approach needed", + ) + + assert len(analysis.main_points_of_agreement) == 2 + assert len(analysis.areas_of_tension) == 2 + assert analysis.areas_of_tension[0].topic == "Economic Impact" + + # Test serialization + analysis_dict = analysis.model_dump() + assert len(analysis_dict["areas_of_tension"]) == 2 + assert analysis_dict["areas_of_tension"][0]["topic"] == "Economic Impact" + + # Test deserialization + new_analysis = ViewpointAnalysis.model_validate(analysis_dict) + assert len(new_analysis.areas_of_tension) == 2 + assert new_analysis.areas_of_tension[0].topic == "Economic Impact" + assert new_analysis.perspective_gaps == analysis.perspective_gaps + + +def test_reflection_metadata_model(): + """Test the ReflectionMetadata model.""" + metadata = ReflectionMetadata( + critique_summary=["Need more sources", "Missing detailed analysis"], + additional_questions_identified=["What about future trends?"], + searches_performed=["future climate trends", "economic impacts"], + improvements_made=3, + error=None, + ) + + assert len(metadata.critique_summary) == 2 + assert len(metadata.additional_questions_identified) == 1 + assert metadata.improvements_made == 3 + assert metadata.error is None + + # Serialization + metadata_dict = metadata.model_dump() + assert len(metadata_dict["critique_summary"]) == 2 + assert metadata_dict["improvements_made"] == 3 + + # Deserialization + new_metadata = ReflectionMetadata.model_validate(metadata_dict) + assert new_metadata.improvements_made == metadata.improvements_made + assert new_metadata.critique_summary == metadata.critique_summary + + +def test_research_state_model(): + """Test the main ResearchState model.""" + # Create with defaults + state = ResearchState() + assert state.main_query == "" + assert state.sub_questions == [] + assert state.search_results == {} + assert state.get_current_stage() == "empty" + + # Set main query + state.main_query = "What are the impacts of climate change?" 
+ assert state.get_current_stage() == "initial" + + # Test update methods + state.update_sub_questions( + ["What are economic impacts?", "What are environmental impacts?"] + ) + assert len(state.sub_questions) == 2 + assert state.get_current_stage() == "after_query_decomposition" + + # Add search results + search_results: Dict[str, List[SearchResult]] = { + "What are economic impacts?": [ + SearchResult( + url="https://example.com/economy", + title="Economic Impacts", + snippet="Overview of economic impacts", + content="Detailed content about economic impacts", + ) + ] + } + state.update_search_results(search_results) + assert state.get_current_stage() == "after_search" + assert len(state.search_results["What are economic impacts?"]) == 1 + + # Add synthesized info + synthesized_info: Dict[str, SynthesizedInfo] = { + "What are economic impacts?": SynthesizedInfo( + synthesized_answer="Economic impacts include job losses and growth opportunities", + key_sources=["https://example.com/economy"], + confidence_level="high", + ) + } + state.update_synthesized_info(synthesized_info) + assert state.get_current_stage() == "after_synthesis" + + # Add viewpoint analysis + analysis = ViewpointAnalysis( + main_points_of_agreement=["Economic changes are happening"], + areas_of_tension=[ + ViewpointTension( + topic="Job impacts", + viewpoints={ + "Positive": "New green jobs", + "Negative": "Fossil fuel job losses", + }, + ) + ], + ) + state.update_viewpoint_analysis(analysis) + assert state.get_current_stage() == "after_viewpoint_analysis" + + # Add reflection results + enhanced_info = { + "What are economic impacts?": SynthesizedInfo( + synthesized_answer="Enhanced answer with more details", + key_sources=[ + "https://example.com/economy", + "https://example.com/new-source", + ], + confidence_level="high", + improvements=["Added more context", "Added more sources"], + ) + } + metadata = ReflectionMetadata( + critique_summary=["Needed more sources"], + improvements_made=2, + ) + state.update_after_reflection(enhanced_info, metadata) + assert state.get_current_stage() == "after_reflection" + + # Set final report + state.set_final_report("Final report content") + assert state.get_current_stage() == "final_report" + assert state.final_report_html == "Final report content" + + # Test serialization and deserialization + state_dict = state.model_dump() + new_state = ResearchState.model_validate(state_dict) + + # Verify key properties were preserved + assert new_state.main_query == state.main_query + assert len(new_state.sub_questions) == len(state.sub_questions) + assert new_state.get_current_stage() == state.get_current_stage() + assert new_state.viewpoint_analysis is not None + assert len(new_state.viewpoint_analysis.areas_of_tension) == 1 + assert ( + new_state.viewpoint_analysis.areas_of_tension[0].topic == "Job impacts" + ) + assert new_state.final_report_html == state.final_report_html diff --git a/deep_research/utils/__init__.py b/deep_research/utils/__init__.py new file mode 100644 index 00000000..395e1d67 --- /dev/null +++ b/deep_research/utils/__init__.py @@ -0,0 +1,7 @@ +""" +Utilities package for the ZenML Deep Research project. + +This package contains various utility functions and helpers used throughout the project, +including data models, LLM interaction utilities, search functionality, and common helper +functions for text processing and state management. 
+""" diff --git a/deep_research/utils/approval_utils.py b/deep_research/utils/approval_utils.py new file mode 100644 index 00000000..56a8ff91 --- /dev/null +++ b/deep_research/utils/approval_utils.py @@ -0,0 +1,159 @@ +"""Utility functions for the human approval process.""" + +from typing import Any, Dict, List + +from utils.pydantic_models import ApprovalDecision + + +def summarize_research_progress(state) -> Dict[str, Any]: + """Summarize the current research progress.""" + completed_count = len(state.synthesized_info) + confidence_levels = [ + info.confidence_level for info in state.synthesized_info.values() + ] + + # Calculate average confidence (high=1.0, medium=0.5, low=0.0) + confidence_map = {"high": 1.0, "medium": 0.5, "low": 0.0} + avg_confidence = sum( + confidence_map.get(c, 0.5) for c in confidence_levels + ) / max(len(confidence_levels), 1) + + low_confidence_count = sum(1 for c in confidence_levels if c == "low") + + return { + "completed_count": completed_count, + "avg_confidence": round(avg_confidence, 2), + "low_confidence_count": low_confidence_count, + } + + +def format_critique_summary(critique_points: List[Dict[str, Any]]) -> str: + """Format critique points for display.""" + if not critique_points: + return "No critical issues identified." + + formatted = [] + for point in critique_points[:3]: # Show top 3 + issue = point.get("issue", "Unknown issue") + formatted.append(f"- {issue}") + + if len(critique_points) > 3: + formatted.append(f"- ... and {len(critique_points) - 3} more issues") + + return "\n".join(formatted) + + +def format_query_list(queries: List[str]) -> str: + """Format query list for display.""" + if not queries: + return "No queries proposed." + + formatted = [] + for i, query in enumerate(queries, 1): + formatted.append(f"{i}. 
{query}") + + return "\n".join(formatted) + + +def calculate_estimated_cost(queries: List[str]) -> float: + """Calculate estimated cost for additional queries.""" + # Rough estimate: ~$0.01 per query (including search API + LLM costs) + return round(len(queries) * 0.01, 2) + + +def format_approval_request( + main_query: str, + progress_summary: Dict[str, Any], + critique_points: List[Dict[str, Any]], + proposed_queries: List[str], + timeout: int = 3600, +) -> str: + """Format the approval request message.""" + + # High-priority critiques + high_priority = [ + c for c in critique_points if c.get("importance") == "high" + ] + + message = f"""📊 **Research Progress Update** + +**Main Query:** {main_query} + +**Current Status:** +- Sub-questions analyzed: {progress_summary["completed_count"]} +- Average confidence: {progress_summary["avg_confidence"]} +- Low confidence areas: {progress_summary["low_confidence_count"]} + +**Key Issues Identified:** +{format_critique_summary(high_priority or critique_points)} + +**Proposed Additional Research** ({len(proposed_queries)} queries): +{format_query_list(proposed_queries)} + +**Estimated Additional Time:** ~{len(proposed_queries) * 2} minutes +**Estimated Additional Cost:** ~${calculate_estimated_cost(proposed_queries)} + +**Response Options:** +- Reply with `approve`, `yes`, `ok`, or `LGTM` to proceed with all queries +- Reply with `reject`, `no`, `skip`, or `decline` to finish with current findings + +**Timeout:** Response required within {timeout // 60} minutes""" + + return message + + +def parse_approval_response( + response: str, proposed_queries: List[str] +) -> ApprovalDecision: + """Parse the approval response from user.""" + + response_upper = response.strip().upper() + + if response_upper == "APPROVE ALL": + return ApprovalDecision( + approved=True, + selected_queries=proposed_queries, + approval_method="APPROVE_ALL", + reviewer_notes=response, + ) + + elif response_upper == "SKIP": + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="SKIP", + reviewer_notes=response, + ) + + elif response_upper.startswith("SELECT"): + # Parse selection like "SELECT 1,3,5" + try: + # Extract the part after "SELECT" + selection_part = response_upper[6:].strip() + indices = [int(x.strip()) - 1 for x in selection_part.split(",")] + selected = [ + proposed_queries[i] + for i in indices + if 0 <= i < len(proposed_queries) + ] + return ApprovalDecision( + approved=True, + selected_queries=selected, + approval_method="SELECT_SPECIFIC", + reviewer_notes=response, + ) + except Exception as e: + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="PARSE_ERROR", + reviewer_notes=f"Failed to parse: {response} - {str(e)}", + ) + + else: + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="UNKNOWN_RESPONSE", + reviewer_notes=f"Unknown response: {response}", + ) diff --git a/deep_research/utils/helper_functions.py b/deep_research/utils/helper_functions.py new file mode 100644 index 00000000..256cdd28 --- /dev/null +++ b/deep_research/utils/helper_functions.py @@ -0,0 +1,192 @@ +import json +import logging +import os +from json.decoder import JSONDecodeError +from typing import Any, Dict, Optional + +import yaml + +logger = logging.getLogger(__name__) + + +def remove_reasoning_from_output(output: str) -> str: + """Remove the reasoning portion from LLM output. 
+
+    Args:
+        output: Raw output from LLM that may contain reasoning
+
+    Returns:
+        Cleaned output without the reasoning section
+    """
+    if not output:
+        return ""
+
+    if "</think>" in output:
+        return output.split("</think>")[-1].strip()
+    return output.strip()
+
+
+def clean_json_tags(text: str) -> str:
+    """Clean JSON markdown tags from text.
+
+    Args:
+        text: Text with potential JSON markdown tags
+
+    Returns:
+        Cleaned text without JSON markdown tags
+    """
+    if not text:
+        return ""
+
+    cleaned = text.replace("```json\n", "").replace("\n```", "")
+    cleaned = cleaned.replace("```json", "").replace("```", "")
+    return cleaned
+
+
+def clean_markdown_tags(text: str) -> str:
+    """Clean Markdown tags from text.
+
+    Args:
+        text: Text with potential markdown tags
+
+    Returns:
+        Cleaned text without markdown tags
+    """
+    if not text:
+        return ""
+
+    cleaned = text.replace("```markdown\n", "").replace("\n```", "")
+    cleaned = cleaned.replace("```markdown", "").replace("```", "")
+    return cleaned
+
+
+def extract_html_from_content(content: str) -> str:
+    """Attempt to extract HTML content from a response that might be wrapped in other formats.
+
+    Args:
+        content: The content to extract HTML from
+
+    Returns:
+        The extracted HTML, or a basic fallback if extraction fails
+    """
+    if not content:
+        return ""
+
+    # Try to find HTML between <html> tags
+    if "<html" in content and "</html>" in content:
+        start = content.find("<html")
+        end = content.find("</html>") + 7  # Include the closing tag
+        return content[start:end]
+
+    # Try to find div class="research-report"
+    if '<div class="research-report"' in content:
+        start = content.find('<div class="research-report"')
+        last_div = content.rfind("</div>")
+        if last_div > start:
+            return content[start : last_div + 6]  # Include the closing tag
+
+    # Look for code blocks
+    if "```html" in content and "```" in content:
+        start = content.find("```html") + 7
+        end = content.find("```", start)
+        if end > start:
+            return content[start:end].strip()
+
+    # Look for JSON with an "html" field
+    try:
+        parsed = json.loads(content)
+        if isinstance(parsed, dict) and "html" in parsed:
+            return parsed["html"]
+    except (JSONDecodeError, TypeError):
+        pass
+
+    # If all extraction attempts fail, return the original content
+    return content
+
+
+def safe_json_loads(json_str: Optional[str]) -> Dict[str, Any]:
+    """Safely parse JSON string.
+
+    Args:
+        json_str: JSON string to parse, can be None.
+
+    Returns:
+        Dict[str, Any]: Parsed JSON as dictionary or empty dict if parsing fails or input is None.
+    """
+    if json_str is None:
+        # Optionally, log a warning here if None input is unexpected for certain call sites
+        # logger.warning("safe_json_loads received None input.")
+        return {}
+    try:
+        return json.loads(json_str)
+    except (
+        JSONDecodeError,
+        TypeError,
+    ):  # Catch TypeError if json_str is not a valid type for json.loads
+        # Optionally, log the error and the problematic string (or its beginning)
+        # logger.warning(f"Failed to decode JSON string: '{str(json_str)[:200]}...'", exc_info=True)
+        return {}
+
+
+def load_pipeline_config(config_path: str) -> Dict[str, Any]:
+    """Load pipeline configuration from YAML file.
+
+    This is used only for pipeline-level configuration, not for step parameters.
+    Step parameters should be defined directly in the step functions.
+
+    Args:
+        config_path: Path to the configuration YAML file
+
+    Returns:
+        Pipeline configuration dictionary
+    """
+    # Get absolute path if relative
+    if not os.path.isabs(config_path):
+        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        config_path = os.path.join(base_dir, config_path)
+
+    # Load YAML configuration
+    try:
+        with open(config_path, "r") as f:
+            config = yaml.safe_load(f)
+        return config
+    except Exception as e:
+        logger.error(f"Error loading pipeline configuration: {e}")
+        # Return a minimal default configuration in case of loading error
+        return {
+            "pipeline": {
+                "name": "deep_research_pipeline",
+                "enable_cache": True,
+            },
+            "environment": {
+                "docker": {
+                    "requirements": [
+                        "openai>=1.0.0",
+                        "tavily-python>=0.2.8",
+                        "PyYAML>=6.0",
+                        "click>=8.0.0",
+                        "pydantic>=2.0.0",
+                        "typing_extensions>=4.0.0",
+                    ]
+                }
+            },
+            "resources": {"cpu": 1, "memory": "4Gi"},
+            "timeout": 3600,
+        }
+
+
+def check_required_env_vars(env_vars: list[str]) -> list[str]:
+    """Check if required environment variables are set.
+ + Args: + env_vars: List of environment variable names to check + + Returns: + List of missing environment variables + """ + missing_vars = [] + for var in env_vars: + if not os.environ.get(var): + missing_vars.append(var) + return missing_vars diff --git a/deep_research/utils/llm_utils.py b/deep_research/utils/llm_utils.py new file mode 100644 index 00000000..1f4dc194 --- /dev/null +++ b/deep_research/utils/llm_utils.py @@ -0,0 +1,387 @@ +import contextlib +import json +import logging +from typing import Any, Dict, List, Optional + +import litellm +from litellm import completion +from utils.helper_functions import ( + clean_json_tags, + remove_reasoning_from_output, + safe_json_loads, +) +from utils.prompts import SYNTHESIS_PROMPT +from zenml import get_step_context + +logger = logging.getLogger(__name__) + +# This module uses litellm for all LLM interactions +# Models are specified with a provider prefix (e.g., "sambanova/DeepSeek-R1-Distill-Llama-70B") +# ALL model names require a provider prefix (e.g., "sambanova/", "openai/", "anthropic/") + +litellm.callbacks = ["langfuse"] + + +def run_llm_completion( + prompt: str, + system_prompt: str, + model: str = "sambanova/Llama-4-Maverick-17B-128E-Instruct", + clean_output: bool = True, + max_tokens: int = 2000, # Increased default token limit + temperature: float = 0.2, + top_p: float = 0.9, + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> str: + """Run an LLM completion with standard error handling and output cleaning. + + Uses litellm for model inference. + + Args: + prompt: User prompt for the LLM + system_prompt: System prompt for the LLM + model: Model to use for completion (with provider prefix) + clean_output: Whether to clean reasoning and JSON tags from output. When True, + this removes any reasoning sections marked with tags and strips JSON + code block markers. + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling value + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. If provided, also converted to trace_metadata format. + + Returns: + str: Processed LLM output with optional cleaning applied + """ + try: + # Ensure model name has provider prefix + if not any( + model.startswith(prefix + "/") + for prefix in [ + "sambanova", + "openai", + "anthropic", + "meta", + "google", + "aws", + "openrouter", + ] + ): + # Raise an error if no provider prefix is specified + error_msg = f"Model '{model}' does not have a provider prefix. 
Please specify provider (e.g., 'sambanova/{model}')" + logger.error(error_msg) + raise ValueError(error_msg) + + # Get pipeline run name and id for trace_name and trace_id if running in a step + trace_name = None + trace_id = None + with contextlib.suppress(RuntimeError): + context = get_step_context() + trace_name = context.pipeline_run.name + trace_id = str(context.pipeline_run.id) + # Build metadata dict + metadata = {"project": project} + if tags is not None: + metadata["tags"] = tags + # Convert tags to trace_metadata format + metadata["trace_metadata"] = {tag: True for tag in tags} + if trace_name: + metadata["trace_name"] = trace_name + if trace_id: + metadata["trace_id"] = trace_id + + response = completion( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + metadata=metadata, + ) + + # Defensive access to content + content = None + if response and response.choices and len(response.choices) > 0: + choice = response.choices[0] + if choice and choice.message: + content = choice.message.content + + if content is None: + logger.warning("LLM response content is missing or empty.") + return "" + + if clean_output: + content = remove_reasoning_from_output(content) + content = clean_json_tags(content) + + return content + except Exception as e: + logger.error(f"Error in LLM completion: {e}") + return "" + + +def get_structured_llm_output( + prompt: str, + system_prompt: str, + model: str = "sambanova/Llama-4-Maverick-17B-128E-Instruct", + fallback_response: Optional[Dict[str, Any]] = None, + max_tokens: int = 2000, # Increased default token limit for structured outputs + temperature: float = 0.2, + top_p: float = 0.9, + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> Dict[str, Any]: + """Get structured JSON output from an LLM with error handling. + + Uses litellm for model inference. + + Args: + prompt: User prompt for the LLM + system_prompt: System prompt for the LLM + model: Model to use for completion (with provider prefix) + fallback_response: Fallback response if parsing fails + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling value + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. Defaults to ["structured_llm_output"] if None. + + Returns: + Parsed JSON response or fallback + """ + try: + # Use provided tags or default to ["structured_llm_output"] + if tags is None: + tags = ["structured_llm_output"] + + content = run_llm_completion( + prompt=prompt, + system_prompt=system_prompt, + model=model, + clean_output=True, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + project=project, + tags=tags, + ) + + if not content: + logger.warning("Empty content returned from LLM") + return fallback_response if fallback_response is not None else {} + + result = safe_json_loads(content) + + if not result and fallback_response is not None: + return fallback_response + + return result + except Exception as e: + logger.error(f"Error processing structured LLM output: {e}") + return fallback_response if fallback_response is not None else {} + + +def is_text_relevant(text1: str, text2: str, min_word_length: int = 4) -> bool: + """Determine if two pieces of text are relevant to each other. 
+ + Relevance is determined by checking if one text is contained within the other, + or if they share significant words (words longer than min_word_length). + This is a simple heuristic approach that checks for: + 1. Complete containment (one text string inside the other) + 2. Shared significant words (words longer than min_word_length) + + Args: + text1: First text to compare + text2: Second text to compare + min_word_length: Minimum length of words to check for shared content + + Returns: + bool: True if the texts are deemed relevant to each other based on the criteria + """ + if not text1 or not text2: + return False + + return ( + text1.lower() in text2.lower() + or text2.lower() in text1.lower() + or any( + word + for word in text1.lower().split() + if len(word) > min_word_length and word in text2.lower() + ) + ) + + +def find_most_relevant_string( + target: str, + options: List[str], + model: Optional[str] = "sambanova/Llama-4-Maverick-17B-128E-Instruct", + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> Optional[str]: + """Find the most relevant string from a list of options using simple text matching. + + If model is provided, uses litellm to determine relevance. + + Args: + target: The target string to find relevance for + options: List of string options to check against + model: Model to use for matching (with provider prefix) + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. Defaults to ["find_most_relevant_string"] if None. + + Returns: + The most relevant string, or None if no relevant options + """ + if not options: + return None + + if len(options) == 1: + return options[0] + + # If model is provided, use litellm for more accurate matching + if model: + try: + # Ensure model name has provider prefix + if not any( + model.startswith(prefix + "/") + for prefix in [ + "sambanova", + "openai", + "anthropic", + "meta", + "google", + "aws", + "openrouter", + ] + ): + # Raise an error if no provider prefix is specified + error_msg = f"Model '{model}' does not have a provider prefix. Please specify provider (e.g., 'sambanova/{model}')" + logger.error(error_msg) + raise ValueError(error_msg) + + system_prompt = "You are a research assistant." + prompt = f"""Given the text: "{target}" +Which of the following options is most relevant to this text? 
+{options} + +Respond with only the exact text of the most relevant option.""" + + # Get pipeline run name and id for trace_name and trace_id if running in a step + trace_name = None + trace_id = None + try: + context = get_step_context() + trace_name = context.pipeline_run.name + trace_id = str(context.pipeline_run.id) + except RuntimeError: + # Not running in a step context + pass + + # Use provided tags or default to ["find_most_relevant_string"] + if tags is None: + tags = ["find_most_relevant_string"] + + # Build metadata dict + metadata = {"project": project, "tags": tags} + # Convert tags to trace_metadata format + metadata["trace_metadata"] = {tag: True for tag in tags} + if trace_name: + metadata["trace_name"] = trace_name + if trace_id: + metadata["trace_id"] = trace_id + + response = completion( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + max_tokens=100, + temperature=0.2, + metadata=metadata, + ) + + answer = response.choices[0].message.content.strip() + + # Check if the answer is one of the options + if answer in options: + return answer + + # If not an exact match, find the closest one + for option in options: + if option in answer or answer in option: + return option + + except Exception as e: + logger.error(f"Error finding relevant string with LLM: {e}") + + # Simple relevance check - find exact matches first + for option in options: + if target.lower() == option.lower(): + return option + + # Then check partial matches + for option in options: + if is_text_relevant(target, option): + return option + + # Return the first option as a fallback + return options[0] + + +def synthesize_information( + synthesis_input: Dict[str, Any], + model: str = "sambanova/Llama-4-Maverick-17B-128E-Instruct", + system_prompt: Optional[str] = None, + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> Dict[str, Any]: + """Synthesize information from search results for a sub-question. + + Uses litellm for model inference. + + Args: + synthesis_input: Dictionary with sub-question, search results, and sources + model: Model to use (with provider prefix) + system_prompt: System prompt for the LLM + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. Defaults to ["information_synthesis"] if None. 
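A small sketch of the relevance helpers defined above (assumed to be run from the `deep_research` directory; the example strings are invented). Passing `model=None` to `find_most_relevant_string` skips the LLM call and exercises only the heuristic fallback:

```python
# Illustrative sketch of the text-relevance helpers; example strings are made up.
from utils.llm_utils import find_most_relevant_string, is_text_relevant

# True: the texts share the significant word "llmops" (longer than min_word_length).
print(is_text_relevant("LLMOps tooling", "A survey of tooling for LLMOps teams"))

options = ["LLMOps infrastructure", "Quantum computing", "MLOps maturity models"]
# model=None bypasses the litellm call, so only the heuristic matching runs.
print(find_most_relevant_string("infrastructure for LLMOps", options, model=None))
```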
+ + Returns: + Dictionary with synthesized information + """ + if system_prompt is None: + system_prompt = SYNTHESIS_PROMPT + + sub_question_for_log = synthesis_input.get( + "sub_question", "unknown question" + ) + + # Define the fallback response + fallback_response = { + "synthesized_answer": f"Synthesis failed for '{sub_question_for_log}'.", + "key_sources": synthesis_input.get("sources", [])[:1], + "confidence_level": "low", + "information_gaps": "An error occurred during the synthesis process.", + } + + # Use provided tags or default to ["information_synthesis"] + if tags is None: + tags = ["information_synthesis"] + + # Use the utility function to get structured output + result = get_structured_llm_output( + prompt=json.dumps(synthesis_input), + system_prompt=system_prompt, + model=model, + fallback_response=fallback_response, + max_tokens=3000, # Increased for more detailed synthesis + project=project, + tags=tags, + ) + + return result diff --git a/deep_research/utils/prompt_loader.py b/deep_research/utils/prompt_loader.py new file mode 100644 index 00000000..c2b3d2f7 --- /dev/null +++ b/deep_research/utils/prompt_loader.py @@ -0,0 +1,136 @@ +"""Utility functions for loading prompts into the PromptsBundle model. + +This module provides functions to create PromptsBundle instances from +the existing prompt definitions in prompts.py. +""" + +from utils import prompts +from utils.prompt_models import PromptsBundle, PromptTemplate + + +def load_prompts_bundle(pipeline_version: str = "1.2.0") -> PromptsBundle: + """Load all prompts from prompts.py into a PromptsBundle. + + Args: + pipeline_version: Version of the pipeline using these prompts + + Returns: + PromptsBundle containing all prompts + """ + # Create PromptTemplate instances for each prompt + search_query_prompt = PromptTemplate( + name="search_query_prompt", + content=prompts.DEFAULT_SEARCH_QUERY_PROMPT, + description="Generates effective search queries from sub-questions", + version="1.0.0", + tags=["search", "query", "information-gathering"], + ) + + query_decomposition_prompt = PromptTemplate( + name="query_decomposition_prompt", + content=prompts.QUERY_DECOMPOSITION_PROMPT, + description="Breaks down complex research queries into specific sub-questions", + version="1.0.0", + tags=["analysis", "decomposition", "planning"], + ) + + synthesis_prompt = PromptTemplate( + name="synthesis_prompt", + content=prompts.SYNTHESIS_PROMPT, + description="Synthesizes search results into comprehensive answers for sub-questions", + version="1.1.0", + tags=["synthesis", "integration", "analysis"], + ) + + viewpoint_analysis_prompt = PromptTemplate( + name="viewpoint_analysis_prompt", + content=prompts.VIEWPOINT_ANALYSIS_PROMPT, + description="Analyzes synthesized answers across different perspectives and viewpoints", + version="1.1.0", + tags=["analysis", "viewpoint", "perspective"], + ) + + reflection_prompt = PromptTemplate( + name="reflection_prompt", + content=prompts.REFLECTION_PROMPT, + description="Evaluates research and identifies gaps, biases, and areas for improvement", + version="1.0.0", + tags=["reflection", "critique", "improvement"], + ) + + additional_synthesis_prompt = PromptTemplate( + name="additional_synthesis_prompt", + content=prompts.ADDITIONAL_SYNTHESIS_PROMPT, + description="Enhances original synthesis with new information and addresses critique points", + version="1.1.0", + tags=["synthesis", "enhancement", "integration"], + ) + + conclusion_generation_prompt = PromptTemplate( + name="conclusion_generation_prompt", + 
content=prompts.CONCLUSION_GENERATION_PROMPT, + description="Synthesizes all research findings into a comprehensive conclusion", + version="1.0.0", + tags=["report", "conclusion", "synthesis"], + ) + + executive_summary_prompt = PromptTemplate( + name="executive_summary_prompt", + content=prompts.EXECUTIVE_SUMMARY_GENERATION_PROMPT, + description="Creates a compelling, insight-driven executive summary", + version="1.1.0", + tags=["report", "summary", "insights"], + ) + + introduction_prompt = PromptTemplate( + name="introduction_prompt", + content=prompts.INTRODUCTION_GENERATION_PROMPT, + description="Creates a contextual, engaging introduction", + version="1.1.0", + tags=["report", "introduction", "context"], + ) + + # Create and return the bundle + return PromptsBundle( + search_query_prompt=search_query_prompt, + query_decomposition_prompt=query_decomposition_prompt, + synthesis_prompt=synthesis_prompt, + viewpoint_analysis_prompt=viewpoint_analysis_prompt, + reflection_prompt=reflection_prompt, + additional_synthesis_prompt=additional_synthesis_prompt, + conclusion_generation_prompt=conclusion_generation_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + pipeline_version=pipeline_version, + ) + + +def get_prompt_for_step(bundle: PromptsBundle, step_name: str) -> str: + """Get the appropriate prompt content for a specific step. + + Args: + bundle: The PromptsBundle containing all prompts + step_name: Name of the step requesting the prompt + + Returns: + The prompt content string + + Raises: + ValueError: If no prompt mapping exists for the step + """ + # Map step names to prompt attributes + step_to_prompt_mapping = { + "query_decomposition": "query_decomposition_prompt", + "search_query_generation": "search_query_prompt", + "synthesis": "synthesis_prompt", + "viewpoint_analysis": "viewpoint_analysis_prompt", + "reflection": "reflection_prompt", + "additional_synthesis": "additional_synthesis_prompt", + "conclusion_generation": "conclusion_generation_prompt", + } + + prompt_attr = step_to_prompt_mapping.get(step_name) + if not prompt_attr: + raise ValueError(f"No prompt mapping found for step: {step_name}") + + return bundle.get_prompt_content(prompt_attr) diff --git a/deep_research/utils/prompt_models.py b/deep_research/utils/prompt_models.py new file mode 100644 index 00000000..c96a7bd0 --- /dev/null +++ b/deep_research/utils/prompt_models.py @@ -0,0 +1,123 @@ +"""Pydantic models for prompt tracking and management. + +This module contains models for bundling prompts as trackable artifacts +in the ZenML pipeline, enabling better observability and version control. +""" + +from datetime import datetime +from typing import Dict, Optional + +from pydantic import BaseModel, Field + + +class PromptTemplate(BaseModel): + """Represents a single prompt template with metadata.""" + + name: str = Field(..., description="Unique identifier for the prompt") + content: str = Field(..., description="The actual prompt template content") + description: str = Field( + "", description="Human-readable description of what this prompt does" + ) + version: str = Field("1.0.0", description="Version of the prompt template") + tags: list[str] = Field( + default_factory=list, description="Tags for categorizing prompts" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class PromptsBundle(BaseModel): + """Bundle of all prompts used in the research pipeline. 
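A usage sketch for the loader and step lookup defined above (illustrative, not part of the diff; assumed to be run from the `deep_research` project directory):

```python
# Illustrative sketch: build the prompts bundle and resolve prompts for steps.
from utils.prompt_loader import get_prompt_for_step, load_prompts_bundle

bundle = load_prompts_bundle(pipeline_version="1.2.0")

# Inspect everything in the bundle (core prompts are keyed by their `name` field).
for name, template in bundle.list_all_prompts().items():
    print(name, template.version, template.tags)

# Resolve the prompt content for the synthesis step; unknown steps raise ValueError.
synthesis_prompt = get_prompt_for_step(bundle, "synthesis")
print(synthesis_prompt[:80])

# Individual prompts can also be fetched by their name.
reflection = bundle.get_prompt_by_name("reflection_prompt")
```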
+ + This model serves as a single artifact that contains all prompts, + making them trackable, versionable, and visualizable in the ZenML dashboard. + """ + + # Core prompts used in the pipeline + search_query_prompt: PromptTemplate + query_decomposition_prompt: PromptTemplate + synthesis_prompt: PromptTemplate + viewpoint_analysis_prompt: PromptTemplate + reflection_prompt: PromptTemplate + additional_synthesis_prompt: PromptTemplate + conclusion_generation_prompt: PromptTemplate + executive_summary_prompt: PromptTemplate + introduction_prompt: PromptTemplate + + # Metadata + pipeline_version: str = Field( + "1.0.0", description="Version of the pipeline using these prompts" + ) + created_at: str = Field( + default_factory=lambda: datetime.now().isoformat(), + description="Timestamp when this bundle was created", + ) + + # Additional prompts can be stored here + custom_prompts: Dict[str, PromptTemplate] = Field( + default_factory=dict, + description="Additional custom prompts not part of the core set", + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def get_prompt_by_name(self, name: str) -> Optional[PromptTemplate]: + """Retrieve a prompt by its name. + + Args: + name: Name of the prompt to retrieve + + Returns: + PromptTemplate if found, None otherwise + """ + # Check core prompts + for field_name, field_value in self.__dict__.items(): + if ( + isinstance(field_value, PromptTemplate) + and field_value.name == name + ): + return field_value + + # Check custom prompts + return self.custom_prompts.get(name) + + def list_all_prompts(self) -> Dict[str, PromptTemplate]: + """Get all prompts as a dictionary. + + Returns: + Dictionary mapping prompt names to PromptTemplate objects + """ + all_prompts = {} + + # Add core prompts + for field_name, field_value in self.__dict__.items(): + if isinstance(field_value, PromptTemplate): + all_prompts[field_value.name] = field_value + + # Add custom prompts + all_prompts.update(self.custom_prompts) + + return all_prompts + + def get_prompt_content(self, prompt_type: str) -> str: + """Get the content of a specific prompt by its type. + + Args: + prompt_type: Type of prompt (e.g., 'search_query_prompt', 'synthesis_prompt') + + Returns: + The prompt content string + + Raises: + AttributeError: If prompt type doesn't exist + """ + prompt = getattr(self, prompt_type) + return prompt.content diff --git a/deep_research/utils/prompts.py b/deep_research/utils/prompts.py new file mode 100644 index 00000000..e072b238 --- /dev/null +++ b/deep_research/utils/prompts.py @@ -0,0 +1,1605 @@ +""" +Centralized collection of prompts used throughout the deep research pipeline. + +This module contains all system prompts used by LLM calls in various steps of the +research pipeline to ensure consistency and make prompt management easier. +""" + +# Search query generation prompt +# Used to generate effective search queries from sub-questions +DEFAULT_SEARCH_QUERY_PROMPT = """ +You are a Deep Research assistant. Given a specific research sub-question, your task is to formulate an effective search +query that will help find relevant information to answer the question. + +A good search query should: +1. Extract the key concepts from the sub-question +2. Use precise, specific terminology +3. Exclude unnecessary words or context +4. Include alternative terms or synonyms when helpful +5. 
Be concise yet comprehensive enough to find relevant results + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "search_query": {"type": "string"}, + "reasoning": {"type": "string"} + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Query decomposition prompt +# Used to break down complex research queries into specific sub-questions +QUERY_DECOMPOSITION_PROMPT = """ +You are a Deep Research assistant specializing in research design. You will be given a MAIN RESEARCH QUERY that needs to be explored comprehensively. Your task is to create diverse, insightful sub-questions that explore different dimensions of the topic. + +IMPORTANT: The main query should be interpreted as a single research question, not as a noun phrase. For example: +- If the query is "Is LLMOps a subset of MLOps?", create questions ABOUT LLMOps and MLOps, not questions like "What is 'Is LLMOps a subset of MLOps?'" +- Focus on the concepts, relationships, and implications within the query + +Create sub-questions that explore these DIFFERENT DIMENSIONS: + +1. **Definitional/Conceptual**: Define key terms and establish conceptual boundaries + Example: "What are the core components and characteristics of LLMOps?" + +2. **Comparative/Relational**: Compare and contrast the concepts mentioned + Example: "How do the workflows and tooling of LLMOps differ from traditional MLOps?" + +3. **Historical/Evolutionary**: Trace development and emergence + Example: "How did LLMOps emerge from MLOps practices?" + +4. **Structural/Technical**: Examine technical architecture and implementation + Example: "What specific tools and platforms are unique to LLMOps?" + +5. **Practical/Use Cases**: Explore real-world applications + Example: "What are the key use cases that require LLMOps but not traditional MLOps?" + +6. **Stakeholder/Industry**: Consider different perspectives and adoption + Example: "How are different industries adopting LLMOps vs MLOps?" + +7. **Challenges/Limitations**: Identify problems and constraints + Example: "What unique challenges does LLMOps face that MLOps doesn't?" + +8. **Future/Trends**: Look at emerging developments + Example: "How is the relationship between LLMOps and MLOps expected to evolve?" + +QUALITY GUIDELINES: +- Each sub-question must explore a DIFFERENT dimension - no repetitive variations +- Questions should be specific, concrete, and investigable +- Mix descriptive ("what/who") with analytical ("why/how") questions +- Ensure questions build toward answering the main query comprehensively +- Frame questions to elicit detailed, nuanced responses +- Consider technical, business, organizational, and strategic aspects + +Format the output in json with the following json schema definition: + + +{ + "type": "array", + "items": { + "type": "object", + "properties": { + "sub_question": {"type": "string"}, + "reasoning": {"type": "string"} + } + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Synthesis prompt for individual sub-questions +# Used to synthesize search results into comprehensive answers for sub-questions +SYNTHESIS_PROMPT = """ +You are a Deep Research assistant specializing in information synthesis. 
Given a sub-question and search results, your task is to synthesize the information +into a comprehensive, accurate, and well-structured answer. + +Your synthesis should: +1. Begin with a direct, concise answer to the sub-question in the first paragraph +2. Provide detailed evidence and explanation in subsequent paragraphs (at least 3-5 paragraphs total) +3. Integrate information from multiple sources, citing them within your answer +4. Acknowledge any conflicting information or contrasting viewpoints you encounter +5. Use data, statistics, examples, and quotations when available to strengthen your answer +6. Organize information logically with a clear flow between concepts +7. Identify key sources that provided the most valuable information (at least 2-3 sources) +8. Explicitly acknowledge information gaps where the search results were incomplete +9. Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +Confidence level criteria: +- HIGH: Multiple high-quality sources provide consistent information, comprehensive coverage of the topic, and few information gaps +- MEDIUM: Decent sources with some consistency, but notable information gaps or some conflicting information +- LOW: Limited sources, major information gaps, significant contradictions, or only tangentially relevant information + +Information gaps should specifically identify: +1. Aspects of the question that weren't addressed in the search results +2. Areas where more detailed or up-to-date information would be valuable +3. Perspectives or data sources that would complement the existing information + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "synthesized_answer": {"type": "string"}, + "key_sources": { + "type": "array", + "items": {"type": "string"} + }, + "confidence_level": {"type": "string", "enum": ["high", "medium", "low"]}, + "information_gaps": {"type": "string"}, + "improvements": { + "type": "array", + "items": {"type": "string"} + } + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Viewpoint analysis prompt for cross-perspective examination +# Used to analyze synthesized answers across different perspectives and viewpoints +VIEWPOINT_ANALYSIS_PROMPT = """ +You are a Deep Research assistant specializing in multi-perspective analysis. You will be given a set of synthesized answers +to sub-questions related to a main research query. Your task is to perform a thorough, nuanced analysis of how different +perspectives would interpret this information. + +Think deeply about the following viewpoint categories and how they would approach the information differently: +- Scientific: Evidence-based, empirical approach focused on data, research findings, and methodological rigor +- Political: Power dynamics, governance structures, policy implications, and ideological frameworks +- Economic: Resource allocation, financial impacts, market dynamics, and incentive structures +- Social: Cultural norms, community impacts, group dynamics, and public welfare +- Ethical: Moral principles, values considerations, rights and responsibilities, and normative judgments +- Historical: Long-term patterns, precedents, contextual development, and evolutionary change + +For each synthesized answer, analyze how these different perspectives would interpret the information by: + +1. 
Identifying 5-8 main points of agreement where multiple perspectives align (with specific examples) +2. Analyzing at least 3-5 areas of tension between perspectives with: + - A clear topic title for each tension point + - Contrasting interpretations from at least 2-3 different viewpoint categories per tension + - Specific examples or evidence showing why these perspectives differ + - The nuanced positions of each perspective, not just simplified oppositions + +3. Thoroughly examining perspective gaps by identifying: + - Which perspectives are underrepresented or missing in the current research + - How including these missing perspectives would enrich understanding + - Specific questions or dimensions that remain unexplored + - Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +4. Developing integrative insights that: + - Synthesize across multiple perspectives to form a more complete understanding + - Highlight how seemingly contradictory viewpoints can complement each other + - Suggest frameworks for reconciling tensions or finding middle-ground approaches + - Identify actionable takeaways that incorporate multiple perspectives + - Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "main_points_of_agreement": { + "type": "array", + "items": {"type": "string"} + }, + "areas_of_tension": { + "type": "array", + "items": { + "type": "object", + "properties": { + "topic": {"type": "string"}, + "viewpoints": { + "type": "object", + "additionalProperties": {"type": "string"} + } + } + } + }, + "perspective_gaps": {"type": "string"}, + "integrative_insights": {"type": "string"} + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Reflection prompt for self-critique and improvement +# Used to evaluate the research and identify gaps, biases, and areas for improvement +REFLECTION_PROMPT = """ +You are a Deep Research assistant with the ability to critique and improve your own research. You will be given: +1. The main research query +2. The sub-questions explored so far +3. The synthesized information for each sub-question +4. Any viewpoint analysis performed + +Your task is to critically evaluate this research and identify: +1. Areas where the research is incomplete or has gaps +2. Questions that are important but not yet answered +3. Aspects where additional evidence or depth would significantly improve the research +4. Potential biases or limitations in the current findings + +Be constructively critical and identify the most important improvements that would substantially enhance the research. + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "critique": { + "type": "array", + "items": { + "type": "object", + "properties": { + "area": {"type": "string"}, + "issue": {"type": "string"}, + "importance": {"type": "string", "enum": ["high", "medium", "low"]} + } + } + }, + "additional_questions": { + "type": "array", + "items": {"type": "string"} + }, + "recommended_search_queries": { + "type": "array", + "items": {"type": "string"} + } + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. 
+""" + +# Additional synthesis prompt for incorporating new information +# Used to enhance original synthesis with new information and address critique points +ADDITIONAL_SYNTHESIS_PROMPT = """ +You are a Deep Research assistant. You will be given: +1. The original synthesized information on a research topic +2. New information from additional research +3. A critique of the original synthesis + +Your task is to enhance the original synthesis by incorporating the new information and addressing the critique. +The updated synthesis should: +1. Integrate new information seamlessly +2. Address gaps identified in the critique +3. Maintain a balanced, comprehensive, and accurate representation +4. Preserve the strengths of the original synthesis +5. Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "enhanced_synthesis": {"type": "string"}, + "improvements_made": { + "type": "array", + "items": {"type": "string"} + }, + "remaining_limitations": {"type": "string"} + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Final report generation prompt +# Used to compile a comprehensive HTML research report from all synthesized information +REPORT_GENERATION_PROMPT = """ +You are a Deep Research assistant responsible for compiling an in-depth, comprehensive research report. You will be given: +1. The original research query +2. The sub-questions that were explored +3. Synthesized information for each sub-question +4. Viewpoint analysis comparing different perspectives (if available) +5. 
Reflection metadata highlighting improvements and limitations + +Your task is to create a well-structured, coherent, professional-quality research report with the following features: + +EXECUTIVE SUMMARY (250-400 words): +- Begin with a compelling, substantive executive summary that provides genuine insight +- Highlight 3-5 key findings or insights that represent the most important discoveries +- Include brief mention of methodology and limitations +- Make the summary self-contained so it can be read independently of the full report +- End with 1-2 sentences on broader implications or applications of the research + +INTRODUCTION (200-300 words): +- Provide relevant background context on the main research query +- Explain why this topic is significant or worth investigating +- Outline the methodological approach used (sub-questions, search strategy, synthesis) +- Preview the overall structure of the report + +SUB-QUESTION SECTIONS: +- For each sub-question, create a dedicated section with: + * A descriptive section title (not just repeating the sub-question) + * A brief (1 paragraph) overview of key findings for this sub-question + * A "Key Findings" box highlighting 3-4 important discoveries for scannable reading + * The detailed, synthesized answer with appropriate paragraph breaks, lists, and formatting + * Proper citation of sources within the text (e.g., "According to [Source Name]...") + * Clear confidence indicator with appropriate styling + * Information gaps clearly identified in their own subsection + * Complete list of key sources used + +VIEWPOINT ANALYSIS SECTION (if available): +- Create a detailed section that: + * Explains the purpose and value of multi-perspective analysis + * Presents points of agreement as actionable insights, not just observations + * Structures tension areas with clear topic headings and balanced presentation of viewpoints + * Uses visual elements (different background colors, icons) to distinguish different perspectives + * Integrates perspective gaps and insights into a cohesive narrative + +CONCLUSION (300-400 words): +- Synthesize the overall findings, not just summarizing each section +- Connect insights from different sub-questions to form higher-level understanding +- Address the main research query directly with evidence-based conclusions +- Acknowledge remaining uncertainties and suggestions for further research +- End with implications or applications of the research findings + +OVERALL QUALITY REQUIREMENTS: +1. Create visually scannable content with clear headings, bullet points, and short paragraphs +2. Use semantic HTML (h1, h2, h3, p, blockquote, etc.) to create proper document structure +3. Include a comprehensive table of contents with anchor links to all major sections +4. Format all sources consistently in the references section with proper linking when available +5. Use tables, lists, and blockquotes to improve readability and highlight important information +6. Apply appropriate styling for different confidence levels (high, medium, low) +7. Ensure proper HTML nesting and structure throughout the document +8. Balance sufficient detail with clarity and conciseness +9. Make all text directly actionable and insight-driven, not just descriptive + +The report should be formatted in HTML with appropriate headings, paragraphs, citations, and formatting. +Use semantic HTML (h1, h2, h3, p, blockquote, etc.) to create a structured document. +Include a table of contents at the beginning with anchor links to each section. 
+For citations, use a consistent format and collect them in a references section at the end.
+
+Include this exact CSS stylesheet in your HTML to ensure consistent styling (do not modify it):
+
+```css
+[CSS stylesheet omitted in this view]
+```
+
+The HTML structure should follow this pattern (markup stripped in this view; the outline below
+shows the section order and content placeholders):
+
+```html
+[CSS STYLESHEET GOES HERE]
+Research Report: [Main Query]
+Table of Contents
+Executive Summary: [CONCISE SUMMARY OF KEY FINDINGS]
+Introduction: [INTRODUCTION TO THE RESEARCH QUERY] / [OVERVIEW OF THE APPROACH AND SUB-QUESTIONS]
+[FOR EACH SUB-QUESTION]:
+  [INDEX]. [SUB-QUESTION TEXT]
+  Confidence Level: [LEVEL]
+  Key Findings: [KEY FINDING 1], [KEY FINDING 2], [...]
+  [DETAILED ANSWER]
+  Information Gaps: [GAPS TEXT]
+  Key Sources: [SOURCE 1], [SOURCE 2], [...]
+Viewpoint Analysis:
+  Points of Agreement: [AGREEMENT 1], [AGREEMENT 2], [...]
+  Areas of Tension, [FOR EACH TENSION]: [TENSION TOPIC] with [VIEWPOINT 1 TITLE]: [VIEWPOINT 1 CONTENT], [VIEWPOINT 2 TITLE]: [VIEWPOINT 2 CONTENT], [...]
+  Perspective Gaps: [PERSPECTIVE GAPS CONTENT]
+  Integrative Insights: [INTEGRATIVE INSIGHTS CONTENT]
+Conclusion: [CONCLUSION TEXT]
+References: [REFERENCE 1], [REFERENCE 2], [...]
    + + +``` + +Special instructions: +1. For each sub-question, display the confidence level with appropriate styling (confidence-high, confidence-medium, or confidence-low) +2. Extract 2-3 key findings from each answer to create the key-findings box +3. Format all sources consistently in the references section +4. Use tables, lists, and blockquotes where appropriate to improve readability +5. Use the notice classes (info, warning) to highlight important information or limitations +6. Ensure all sections have proper ID attributes for the table of contents links + +Return only the complete HTML code for the report, with no explanations or additional text. +""" + +# Static HTML template for direct report generation without LLM +STATIC_HTML_TEMPLATE = """ + + + + + Research Report: {main_query} + + + +
+[HTML body markup omitted in this view; the template lays out the following structure]
+Research Report: {main_query}
+Table of Contents
+Executive Summary: {executive_summary}
+Introduction: {introduction_html}
+{sub_questions_html}
+{viewpoint_analysis_html}
+Conclusion: {conclusion_html}
+References: {references_html}
    + + +""" + +# Template for sub-question section in the static HTML report +SUB_QUESTION_TEMPLATE = """ +
+[HTML markup omitted in this view; the template renders the following structure]
+{index}. {question}
+{confidence_icon} Confidence: {confidence_upper}
+{answer}
+{info_gaps_html}
+{key_sources_html}
    +""" + +# Template for viewpoint analysis section in the static HTML report +VIEWPOINT_ANALYSIS_TEMPLATE = """ +
+[HTML markup omitted in this view; the template renders the following structure]
+Viewpoint Analysis
+🤝 Points of Agreement: {agreements_html}
+⚖️ Areas of Tension: {tensions_html}
+🔍 Perspective Gaps: {perspective_gaps}
+💡 Integrative Insights: {integrative_insights}
    +""" + +# Executive Summary generation prompt +# Used to create a compelling, insight-driven executive summary +EXECUTIVE_SUMMARY_GENERATION_PROMPT = """ +You are a Deep Research assistant specializing in creating executive summaries. Given comprehensive research findings, your task is to create a compelling executive summary that captures the essence of the research and its key insights. + +Your executive summary should: + +1. **Opening Statement (1-2 sentences):** + - Start with a powerful, direct answer to the main research question + - Make it clear and definitive based on the evidence gathered + +2. **Key Findings (3-5 bullet points):** + - Extract the MOST IMPORTANT discoveries from across all sub-questions + - Focus on insights that are surprising, actionable, or paradigm-shifting + - Each finding should be specific and evidence-based, not generic + - Prioritize findings that directly address the main query + +3. **Critical Insights (2-3 sentences):** + - Synthesize patterns or themes that emerged across multiple sub-questions + - Highlight any unexpected discoveries or counter-intuitive findings + - Connect disparate findings to reveal higher-level understanding + +4. **Implications (2-3 sentences):** + - What do these findings mean for practitioners/stakeholders? + - What actions or decisions can be made based on this research? + - Why should the reader care about these findings? + +5. **Confidence and Limitations (1-2 sentences):** + - Briefly acknowledge the overall confidence level of the findings + - Note any significant gaps or areas requiring further investigation + +IMPORTANT GUIDELINES: +- Be CONCISE but INSIGHTFUL - every sentence should add value +- Use active voice and strong, definitive language where evidence supports it +- Avoid generic statements - be specific to the actual research findings +- Lead with the most important information +- Make it self-contained - reader should understand key findings without reading the full report +- Target length: 250-400 words + +Format as well-structured HTML paragraphs using

<p> tags and <ul>/<li> for bullet points. +""" + +# Introduction generation prompt +# Used to create a contextual, engaging introduction +INTRODUCTION_GENERATION_PROMPT = """ +You are a Deep Research assistant specializing in creating engaging introductions. Given a research query and the sub-questions explored, your task is to create an introduction that provides context and sets up the reader's expectations. + +Your introduction should: + +1. **Context and Relevance (2-3 sentences):** + - Why is this research question important NOW? + - What makes this topic significant or worth investigating? + - Connect to current trends, debates, or challenges in the field + +2. **Scope and Approach (2-3 sentences):** + - What specific aspects of the topic does this research explore? + - Briefly mention the key dimensions covered (based on sub-questions) + - Explain the systematic approach without being too technical + +3. **What to Expect (2-3 sentences):** + - Preview the structure of the report + - Hint at some of the interesting findings or tensions discovered + - Set expectations about the depth and breadth of analysis + +IMPORTANT GUIDELINES: +- Make it engaging - hook the reader's interest from the start +- Provide real context, not generic statements +- Connect to why this matters for the reader +- Keep it concise but informative (200-300 words) +- Use active voice and clear language +- Build anticipation for the findings without giving everything away + +Format as well-structured HTML paragraphs using <p>

      tags. Do NOT include any headings or section titles. +""" + +# Conclusion generation prompt +# Used to synthesize all research findings into a comprehensive conclusion +CONCLUSION_GENERATION_PROMPT = """ +You are a Deep Research assistant specializing in synthesizing comprehensive research conclusions. Given all the research findings from a deep research study, your task is to create a thoughtful, evidence-based conclusion that ties together the overall findings. + +Your conclusion should: + +1. **Synthesis and Integration (150-200 words):** + - Connect insights from different sub-questions to form a higher-level understanding + - Identify overarching themes and patterns that emerge from the research + - Highlight how different findings relate to and support each other + - Avoid simply summarizing each section separately + +2. **Direct Response to Main Query (100-150 words):** + - Address the original research question directly with evidence-based conclusions + - State what the research definitively established vs. what remains uncertain + - Provide a clear, actionable answer based on the synthesized evidence + +3. **Limitations and Future Directions (100-120 words):** + - Acknowledge remaining uncertainties and information gaps across all sections + - Suggest specific areas where additional research would be most valuable + - Identify what types of evidence or perspectives would strengthen the findings + +4. **Implications and Applications (80-100 words):** + - Explain the practical significance of the research findings + - Suggest how the insights might be applied or what they mean for stakeholders + - Connect findings to broader contexts or implications + +Format your output as a well-structured conclusion section in HTML format with appropriate paragraph breaks and formatting. Use

<p> tags for paragraphs and organize the content logically with clear transitions between the different aspects outlined above. + +IMPORTANT: Do NOT include any headings like "Conclusion" or any other heading
      tags - the section already has a heading. Start directly with the conclusion content in paragraph form. Just create flowing, well-structured paragraphs that cover all four aspects naturally. + +Ensure the conclusion feels cohesive and draws meaningful connections between findings rather than just listing them sequentially. +""" diff --git a/deep_research/utils/pydantic_models.py b/deep_research/utils/pydantic_models.py new file mode 100644 index 00000000..9fca23a3 --- /dev/null +++ b/deep_research/utils/pydantic_models.py @@ -0,0 +1,300 @@ +"""Pydantic model definitions for the research pipeline. + +This module contains all the Pydantic models that represent the state of the research +pipeline. These models replace the previous dataclasses implementation and leverage +Pydantic's validation, serialization, and integration with ZenML. +""" + +import time +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field +from typing_extensions import Literal + + +class SearchResult(BaseModel): + """Represents a search result for a sub-question.""" + + url: str = "" + content: str = "" + title: str = "" + snippet: str = "" + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + + model_config = { + "extra": "ignore", # Ignore extra fields during deserialization + "frozen": False, # Allow attribute updates + "validate_assignment": True, # Validate when attributes are set + } + + +class ViewpointTension(BaseModel): + """Represents a tension between different viewpoints on a topic.""" + + topic: str = "" + viewpoints: Dict[str, str] = Field(default_factory=dict) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class SynthesizedInfo(BaseModel): + """Represents synthesized information for a sub-question.""" + + synthesized_answer: str = "" + key_sources: List[str] = Field(default_factory=list) + confidence_level: Literal["high", "medium", "low"] = "medium" + information_gaps: str = "" + improvements: List[str] = Field(default_factory=list) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ViewpointAnalysis(BaseModel): + """Represents the analysis of different viewpoints on the research topic.""" + + main_points_of_agreement: List[str] = Field(default_factory=list) + areas_of_tension: List[ViewpointTension] = Field(default_factory=list) + perspective_gaps: str = "" + integrative_insights: str = "" + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ReflectionMetadata(BaseModel): + """Metadata about the reflection process.""" + + critique_summary: List[str] = Field(default_factory=list) + additional_questions_identified: List[str] = Field(default_factory=list) + searches_performed: List[str] = Field(default_factory=list) + improvements_made: float = Field( + default=0 + ) # Changed from int to float to handle timestamp values + error: Optional[str] = None + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ResearchState(BaseModel): + """Comprehensive state object for the enhanced research pipeline.""" + + # Initial query information + main_query: str = "" + sub_questions: List[str] = Field(default_factory=list) + + # Information gathering results + search_results: Dict[str, List[SearchResult]] = Field(default_factory=dict) + + # Synthesized information + synthesized_info: Dict[str, SynthesizedInfo] = Field(default_factory=dict) + + # Viewpoint 
analysis + viewpoint_analysis: Optional[ViewpointAnalysis] = None + + # Reflection results + enhanced_info: Dict[str, SynthesizedInfo] = Field(default_factory=dict) + reflection_metadata: Optional[ReflectionMetadata] = None + + # Final report + final_report_html: str = "" + + # Search cost tracking + search_costs: Dict[str, float] = Field( + default_factory=dict, + description="Total costs by search provider (e.g., {'exa': 0.0, 'tavily': 0.0})", + ) + search_cost_details: List[Dict[str, Any]] = Field( + default_factory=list, + description="Detailed log of each search with cost information", + ) + # Format: [{"provider": "exa", "query": "...", "cost": 0.0, "timestamp": ..., "step": "...", "sub_question": "..."}] + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def get_current_stage(self) -> str: + """Determine the current stage of research based on filled data.""" + if self.final_report_html: + return "final_report" + elif self.enhanced_info: + return "after_reflection" + elif self.viewpoint_analysis: + return "after_viewpoint_analysis" + elif self.synthesized_info: + return "after_synthesis" + elif self.search_results: + return "after_search" + elif self.sub_questions: + return "after_query_decomposition" + elif self.main_query: + return "initial" + else: + return "empty" + + def update_sub_questions(self, sub_questions: List[str]) -> None: + """Update the sub-questions list.""" + self.sub_questions = sub_questions + + def update_search_results( + self, search_results: Dict[str, List[SearchResult]] + ) -> None: + """Update the search results.""" + self.search_results = search_results + + def update_synthesized_info( + self, synthesized_info: Dict[str, SynthesizedInfo] + ) -> None: + """Update the synthesized information.""" + self.synthesized_info = synthesized_info + + def update_viewpoint_analysis( + self, viewpoint_analysis: ViewpointAnalysis + ) -> None: + """Update the viewpoint analysis.""" + self.viewpoint_analysis = viewpoint_analysis + + def update_after_reflection( + self, + enhanced_info: Dict[str, SynthesizedInfo], + metadata: ReflectionMetadata, + ) -> None: + """Update with reflection results.""" + self.enhanced_info = enhanced_info + self.reflection_metadata = metadata + + def set_final_report(self, html: str) -> None: + """Set the final report HTML.""" + self.final_report_html = html + + +class ReflectionOutput(BaseModel): + """Output from the reflection generation step.""" + + state: ResearchState + recommended_queries: List[str] = Field(default_factory=list) + critique_summary: List[Dict[str, Any]] = Field(default_factory=list) + additional_questions: List[str] = Field(default_factory=list) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ApprovalDecision(BaseModel): + """Approval decision from human reviewer.""" + + approved: bool = False + selected_queries: List[str] = Field(default_factory=list) + approval_method: str = "" # "APPROVE_ALL", "SKIP", "SELECT_SPECIFIC" + reviewer_notes: str = "" + timestamp: float = Field(default_factory=lambda: time.time()) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class PromptTypeMetrics(BaseModel): + """Metrics for a specific prompt type.""" + + prompt_type: str + total_cost: float + input_tokens: int + output_tokens: int + call_count: int + avg_cost_per_call: float + percentage_of_total_cost: float + + model_config = { + "extra": "ignore", + "frozen": False, + 
"validate_assignment": True, + } + + +class TracingMetadata(BaseModel): + """Metadata about token usage, costs, and performance for a pipeline run.""" + + # Pipeline information + pipeline_run_name: str = "" + pipeline_run_id: str = "" + + # Token usage + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_tokens: int = 0 + + # Cost information + total_cost: float = 0.0 + cost_breakdown_by_model: Dict[str, float] = Field(default_factory=dict) + + # Performance metrics + total_latency_seconds: float = 0.0 + formatted_latency: str = "" + observation_count: int = 0 + + # Model usage + models_used: List[str] = Field(default_factory=list) + model_token_breakdown: Dict[str, Dict[str, int]] = Field( + default_factory=dict + ) + # Format: {"model_name": {"input_tokens": X, "output_tokens": Y, "total_tokens": Z}} + + # Trace information + trace_id: str = "" + trace_name: str = "" + trace_tags: List[str] = Field(default_factory=list) + trace_metadata: Dict[str, Any] = Field(default_factory=dict) + + # Step-by-step breakdown + step_costs: Dict[str, float] = Field(default_factory=dict) + step_tokens: Dict[str, Dict[str, int]] = Field(default_factory=dict) + # Format: {"step_name": {"input_tokens": X, "output_tokens": Y}} + + # Prompt-level metrics + prompt_metrics: List[PromptTypeMetrics] = Field( + default_factory=list, description="Cost breakdown by prompt type" + ) + + # Search provider costs + search_costs: Dict[str, float] = Field( + default_factory=dict, description="Total costs by search provider" + ) + search_queries_count: Dict[str, int] = Field( + default_factory=dict, + description="Number of queries by search provider", + ) + search_cost_details: List[Dict[str, Any]] = Field( + default_factory=list, description="Detailed search cost information" + ) + + # Timestamp + collected_at: float = Field(default_factory=lambda: time.time()) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } diff --git a/deep_research/utils/search_utils.py b/deep_research/utils/search_utils.py new file mode 100644 index 00000000..a9b04d51 --- /dev/null +++ b/deep_research/utils/search_utils.py @@ -0,0 +1,721 @@ +import logging +import os +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from tavily import TavilyClient + +try: + from exa_py import Exa + + EXA_AVAILABLE = True +except ImportError: + EXA_AVAILABLE = False + Exa = None + +from utils.llm_utils import get_structured_llm_output +from utils.prompts import DEFAULT_SEARCH_QUERY_PROMPT +from utils.pydantic_models import SearchResult + +logger = logging.getLogger(__name__) + + +class SearchProvider(Enum): + TAVILY = "tavily" + EXA = "exa" + BOTH = "both" + + +class SearchEngineConfig: + """Configuration for search engines""" + + def __init__(self): + self.tavily_api_key = os.getenv("TAVILY_API_KEY") + self.exa_api_key = os.getenv("EXA_API_KEY") + self.default_provider = os.getenv("DEFAULT_SEARCH_PROVIDER", "tavily") + self.enable_parallel_search = ( + os.getenv("ENABLE_PARALLEL_SEARCH", "false").lower() == "true" + ) + + +def get_search_client(provider: Union[str, SearchProvider]) -> Optional[Any]: + """Get the appropriate search client based on provider.""" + if isinstance(provider, str): + provider = SearchProvider(provider.lower()) + + config = SearchEngineConfig() + + if provider == SearchProvider.TAVILY: + if not config.tavily_api_key: + raise ValueError("TAVILY_API_KEY environment variable not set") + return TavilyClient(api_key=config.tavily_api_key) + + elif provider 
== SearchProvider.EXA: + if not EXA_AVAILABLE: + raise ImportError( + "exa-py is not installed. Please install it with: pip install exa-py" + ) + if not config.exa_api_key: + raise ValueError("EXA_API_KEY environment variable not set") + return Exa(config.exa_api_key) + + return None + + +def tavily_search( + query: str, + include_raw_content: bool = True, + max_results: int = 3, + cap_content_length: int = 20000, +) -> Dict[str, Any]: + """Perform a search using the Tavily API. + + Args: + query: Search query + include_raw_content: Whether to include raw content in results + max_results: Maximum number of results to return + cap_content_length: Maximum length of content to return + + Returns: + Dict[str, Any]: Search results from Tavily in the following format: + { + "query": str, # The original query + "results": List[Dict], # List of search result objects + "error": str, # Error message (if an error occurred, otherwise omitted) + } + + Each result in "results" has the following structure: + { + "url": str, # URL of the search result + "raw_content": str, # Raw content of the page (if include_raw_content=True) + "title": str, # Title of the page + "snippet": str, # Snippet of the page content + } + """ + try: + tavily_client = get_search_client(SearchProvider.TAVILY) + + # First try with advanced search + results = tavily_client.search( + query=query, + include_raw_content=include_raw_content, + max_results=max_results, + search_depth="advanced", # Use advanced search for better results + include_domains=[], # No domain restrictions + exclude_domains=[], # No exclusions + include_answer=False, # We don't need the answer field + include_images=False, # We don't need images + # Note: 'include_snippets' is not a supported parameter + ) + + # Check if we got good results (with non-None and non-empty content) + if include_raw_content and "results" in results: + bad_content_count = sum( + 1 + for r in results["results"] + if "raw_content" in r + and ( + r["raw_content"] is None or r["raw_content"].strip() == "" + ) + ) + + # If more than half of results have bad content, try a different approach + if bad_content_count > len(results["results"]) / 2: + logger.warning( + f"{bad_content_count}/{len(results['results'])} results have None or empty content. " + "Trying to use 'content' field instead of 'raw_content'..." + ) + + # Try to use the 'content' field which comes by default + for result in results["results"]: + if ( + "raw_content" in result + and ( + result["raw_content"] is None + or result["raw_content"].strip() == "" + ) + ) and "content" in result: + result["raw_content"] = result["content"] + logger.info( + f"Using 'content' field as 'raw_content' for URL {result.get('url', 'unknown')}" + ) + + # Re-check after our fix + bad_content_count = sum( + 1 + for r in results["results"] + if "raw_content" in r + and ( + r["raw_content"] is None + or r["raw_content"].strip() == "" + ) + ) + + if bad_content_count > 0: + logger.warning( + f"Still have {bad_content_count}/{len(results['results'])} results with bad content after fixes." 
+ ) + + # Try alternative approach - search with 'include_answer=True' + try: + # Search with include_answer=True which may give us better content + logger.info( + "Trying alternative search with include_answer=True" + ) + alt_results = tavily_client.search( + query=query, + include_raw_content=include_raw_content, + max_results=max_results, + search_depth="advanced", + include_domains=[], + exclude_domains=[], + include_answer=True, # Include answer this time + include_images=False, + ) + + # Check if we got any improved content + if "results" in alt_results: + # Create a merged results set taking the best content + for i, result in enumerate(alt_results["results"]): + if i < len(results["results"]): + if ( + "raw_content" in result + and result["raw_content"] + and ( + results["results"][i].get( + "raw_content" + ) + is None + or results["results"][i] + .get("raw_content", "") + .strip() + == "" + ) + ): + # Replace the bad content with better content from alt_results + results["results"][i]["raw_content"] = ( + result["raw_content"] + ) + logger.info( + f"Replaced bad content with better content from alternative search for URL {result.get('url', 'unknown')}" + ) + + # If answer is available, add it as a special result + if "answer" in alt_results and alt_results["answer"]: + answer_text = alt_results["answer"] + answer_result = { + "url": "tavily-generated-answer", + "title": "Generated Answer", + "raw_content": f"Generated Answer based on search results:\n\n{answer_text}", + "content": answer_text, + } + results["results"].append(answer_result) + logger.info( + "Added Tavily generated answer as additional search result" + ) + + except Exception as alt_error: + logger.warning( + f"Failed to get better results with alternative search: {alt_error}" + ) + + # Cap content length if specified + if cap_content_length > 0 and "results" in results: + for result in results["results"]: + if "raw_content" in result and result["raw_content"]: + result["raw_content"] = result["raw_content"][ + :cap_content_length + ] + + return results + except Exception as e: + logger.error(f"Error in Tavily search: {e}") + # Return an error structure that's compatible with our expected format + return {"query": query, "results": [], "error": str(e)} + + +def exa_search( + query: str, + max_results: int = 3, + cap_content_length: int = 20000, + search_mode: str = "auto", + include_highlights: bool = False, +) -> Dict[str, Any]: + """Perform a search using the Exa API. 
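A minimal sketch of calling the Tavily wrapper above directly (illustrative, not part of the diff; assumes `TAVILY_API_KEY` is set and the query string is an invented example):

```python
# Illustrative sketch: direct call to the Tavily wrapper defined above.
from utils.search_utils import tavily_search

raw = tavily_search("LLMOps vs MLOps tooling", max_results=3, cap_content_length=5000)

if "error" in raw:
    print(f"Search failed: {raw['error']}")
else:
    for r in raw.get("results", []):
        # raw_content is capped at cap_content_length characters by the wrapper
        print(r.get("url"), len(r.get("raw_content") or ""))
```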
+ + Args: + query: Search query + max_results: Maximum number of results to return + cap_content_length: Maximum length of content to return + search_mode: Search mode ("neural", "keyword", or "auto") + include_highlights: Whether to include highlights in results + + Returns: + Dict[str, Any]: Search results from Exa in a format compatible with Tavily + """ + try: + exa_client = get_search_client(SearchProvider.EXA) + + # Configure content options + text_options = {"max_characters": cap_content_length} + + kwargs = { + "query": query, + "num_results": max_results, + "type": search_mode, # "neural", "keyword", or "auto" + "text": text_options, + } + + if include_highlights: + kwargs["highlights"] = { + "highlights_per_url": 2, + "num_sentences": 3, + } + + response = exa_client.search_and_contents(**kwargs) + + # Extract cost information + exa_cost = 0.0 + if hasattr(response, "cost_dollars") and hasattr( + response.cost_dollars, "total" + ): + exa_cost = response.cost_dollars.total + logger.info( + f"Exa search cost for query '{query}': ${exa_cost:.4f}" + ) + + # Convert to standardized format compatible with Tavily + results = {"query": query, "results": [], "exa_cost": exa_cost} + + for r in response.results: + result_dict = { + "url": r.url, + "title": r.title or "", + "snippet": "", + "raw_content": getattr(r, "text", ""), + "content": getattr(r, "text", ""), + } + + # Add highlights as snippet if available + if hasattr(r, "highlights") and r.highlights: + result_dict["snippet"] = " ".join(r.highlights[:1]) + + # Store additional metadata + result_dict["_metadata"] = { + "provider": "exa", + "score": getattr(r, "score", None), + "published_date": getattr(r, "published_date", None), + "author": getattr(r, "author", None), + } + + results["results"].append(result_dict) + + return results + + except Exception as e: + logger.error(f"Error in Exa search: {e}") + return {"query": query, "results": [], "error": str(e)} + + +def unified_search( + query: str, + provider: Union[str, SearchProvider, None] = None, + max_results: int = 3, + cap_content_length: int = 20000, + search_mode: str = "auto", + include_highlights: bool = False, + compare_results: bool = False, + **kwargs, +) -> Union[List[SearchResult], Dict[str, List[SearchResult]]]: + """Unified search interface supporting multiple providers. 
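A sketch of the Exa path (illustrative; assumes `exa-py` is installed and `EXA_API_KEY` is set, and the query is an invented example):

```python
# Illustrative sketch: Exa search with neural mode and highlights enabled.
from utils.search_utils import exa_search

raw = exa_search(
    "emerging LLMOps platforms",
    max_results=3,
    search_mode="neural",     # "neural", "keyword", or "auto"
    include_highlights=True,  # populates the snippet field from Exa highlights
)

print(f"Reported Exa cost: ${raw.get('exa_cost', 0.0):.4f}")
for r in raw.get("results", []):
    print(r["url"], r.get("_metadata", {}).get("score"))
```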
+ + Args: + query: Search query + provider: Search provider to use (tavily, exa, both) + max_results: Maximum number of results + cap_content_length: Maximum content length + search_mode: Search mode for Exa ("neural", "keyword", "auto") + include_highlights: Include highlights for Exa results + compare_results: Return results from both providers separately + + Returns: + List[SearchResult] or Dict mapping provider to results (when compare_results=True or provider="both") + """ + # Use default provider if not specified + if provider is None: + config = SearchEngineConfig() + provider = config.default_provider + + # Convert string to enum if needed + if isinstance(provider, str): + provider = SearchProvider(provider.lower()) + + # Handle single provider case + if provider == SearchProvider.TAVILY: + results = tavily_search( + query, + max_results=max_results, + cap_content_length=cap_content_length, + ) + extracted, cost = extract_search_results(results, provider="tavily") + return extracted if not compare_results else {"tavily": extracted} + + elif provider == SearchProvider.EXA: + results = exa_search( + query=query, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + extracted, cost = extract_search_results(results, provider="exa") + return extracted if not compare_results else {"exa": extracted} + + elif provider == SearchProvider.BOTH: + # Run both searches + tavily_results = tavily_search( + query, + max_results=max_results, + cap_content_length=cap_content_length, + ) + exa_results = exa_search( + query=query, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + + # Extract results from both + tavily_extracted, tavily_cost = extract_search_results( + tavily_results, provider="tavily" + ) + exa_extracted, exa_cost = extract_search_results( + exa_results, provider="exa" + ) + + if compare_results: + return {"tavily": tavily_extracted, "exa": exa_extracted} + else: + # Merge results, interleaving them + merged = [] + max_len = max(len(tavily_extracted), len(exa_extracted)) + for i in range(max_len): + if i < len(tavily_extracted): + merged.append(tavily_extracted[i]) + if i < len(exa_extracted): + merged.append(exa_extracted[i]) + return merged[:max_results] # Limit to requested number + + else: + raise ValueError(f"Unknown provider: {provider}") + + +def extract_search_results( + search_results: Dict[str, Any], provider: str = "tavily" +) -> tuple[List[SearchResult], float]: + """Extract SearchResult objects from provider-specific API responses. + + Args: + search_results: Results from search API + provider: Which provider the results came from + + Returns: + Tuple of (List[SearchResult], float): List of converted SearchResult objects with standardized fields + and the search cost (0.0 if not available). + SearchResult is a Pydantic model defined in data_models.py that includes: + - url: The URL of the search result + - content: The raw content of the page + - title: The title of the page + - snippet: A brief snippet of the page content + """ + results_list = [] + search_cost = search_results.get( + "exa_cost", 0.0 + ) # Extract cost if present + + if "results" in search_results: + for result in search_results["results"]: + if "url" in result: + # Get fields with defaults + url = result["url"] + title = result.get("title", "") + + # Try to extract the best content available: + # 1. 
First try raw_content (if we requested it) + # 2. Then try regular content (always available) + # 3. Then try to use snippet combined with title + # 4. Last resort: use just title + + raw_content = result.get("raw_content", None) + regular_content = result.get("content", "") + snippet = result.get("snippet", "") + + # Set our final content - prioritize raw_content if available and not None + if raw_content is not None and raw_content.strip(): + content = raw_content + # Next best is the regular content field + elif regular_content and regular_content.strip(): + content = regular_content + logger.info( + f"Using 'content' field for URL {url} because raw_content was not available" + ) + # Try to create a usable content from snippet and title + elif snippet: + content = f"Title: {title}\n\nContent: {snippet}" + logger.warning( + f"Using title and snippet as content fallback for {url}" + ) + # Last resort - just use the title + elif title: + content = ( + f"Title: {title}\n\nNo content available for this URL." + ) + logger.warning( + f"Using only title as content fallback for {url}" + ) + # Nothing available + else: + content = "" + logger.warning( + f"No content available for URL {url}, using empty string" + ) + + # Create SearchResult with provider metadata + search_result = SearchResult( + url=url, + content=content, + title=title, + snippet=snippet, + ) + + # Add provider info to metadata if available + if "_metadata" in result: + search_result.metadata = result["_metadata"] + else: + search_result.metadata = {"provider": provider} + + results_list.append(search_result) + + # If we got the answer (Tavily specific), add it as a special result + if ( + provider == "tavily" + and "answer" in search_results + and search_results["answer"] + ): + answer_text = search_results["answer"] + results_list.append( + SearchResult( + url="tavily-generated-answer", + content=f"Generated Answer based on search results:\n\n{answer_text}", + title="Tavily Generated Answer", + snippet=answer_text[:100] + "..." + if len(answer_text) > 100 + else answer_text, + metadata={"provider": "tavily", "type": "generated_answer"}, + ) + ) + logger.info("Added Tavily generated answer as a search result") + + return results_list, search_cost + + +def generate_search_query( + sub_question: str, + model: str = "sambanova/Llama-4-Maverick-17B-128E-Instruct", + system_prompt: Optional[str] = None, + project: str = "deep-research", +) -> Dict[str, Any]: + """Generate an optimized search query for a sub-question. + + Uses litellm for model inference via get_structured_llm_output. 
+
+    Args:
+        sub_question: The sub-question to generate a search query for
+        model: Model to use (with provider prefix)
+        system_prompt: System prompt for the LLM, defaults to DEFAULT_SEARCH_QUERY_PROMPT
+        project: Langfuse project name for LLM tracking
+
+    Returns:
+        Dictionary with search query and reasoning
+    """
+    if system_prompt is None:
+        system_prompt = DEFAULT_SEARCH_QUERY_PROMPT
+
+    fallback_response = {"search_query": sub_question, "reasoning": ""}
+
+    return get_structured_llm_output(
+        prompt=sub_question,
+        system_prompt=system_prompt,
+        model=model,
+        fallback_response=fallback_response,
+        project=project,
+    )
+
+
+def search_and_extract_results(
+    query: str,
+    max_results: int = 3,
+    cap_content_length: int = 20000,
+    max_retries: int = 2,
+    provider: Optional[Union[str, SearchProvider]] = None,
+    search_mode: str = "auto",
+    include_highlights: bool = False,
+) -> tuple[List[SearchResult], float]:
+    """Perform a search and extract results in one step.
+
+    Args:
+        query: Search query
+        max_results: Maximum number of results to return
+        cap_content_length: Maximum length of content to return
+        max_retries: Maximum number of retries in case of failure
+        provider: Search provider to use (tavily, exa, both)
+        search_mode: Search mode for Exa ("neural", "keyword", "auto")
+        include_highlights: Include highlights for Exa results
+
+    Returns:
+        Tuple of (List of SearchResult objects, search cost)
+    """
+    # Resolve the provider up front (falling back to the configured default) and
+    # normalize strings to the SearchProvider enum so values like "both" take the
+    # same code path as SearchProvider.BOTH below.
+    if provider is None:
+        provider = SearchEngineConfig().default_provider
+    if isinstance(provider, str):
+        provider = SearchProvider(provider.lower())
+
+    results = []
+    total_cost = 0.0
+    retry_count = 0
+
+    # List of alternative query formats to try if the original query fails
+    # to yield good results with non-None content
+    alternative_queries = [
+        query,  # Original query first
+        f'"{query}"',  # Try exact phrase matching
+        f"about {query}",  # Try broader context
+        f"research on {query}",  # Try research-oriented results
+        query.replace(" OR ", " "),  # Try without OR operator
+    ]
+
+    while retry_count <= max_retries and retry_count < len(
+        alternative_queries
+    ):
+        try:
+            current_query = alternative_queries[retry_count]
+            logger.info(
+                f"Searching with query ({retry_count + 1}/{max_retries + 1}): {current_query}"
+            )
+
+            # Determine if we're using Exa so its per-query cost can be tracked
+            using_exa = provider in [
+                SearchProvider.EXA,
+                SearchProvider.BOTH,
+            ]
+
+            # Perform search based on provider
+            if using_exa and provider != SearchProvider.BOTH:
+                # Direct Exa search
+                search_results = exa_search(
+                    query=current_query,
+                    max_results=max_results,
+                    cap_content_length=cap_content_length,
+                    search_mode=search_mode,
+                    include_highlights=include_highlights,
+                )
+                results, cost = extract_search_results(
+                    search_results, provider="exa"
+                )
+                total_cost += cost
+            elif provider == SearchProvider.BOTH:
+                # Search with both providers
+                tavily_results = tavily_search(
+                    current_query,
+                    max_results=max_results,
+                    cap_content_length=cap_content_length,
+                )
+                exa_results = exa_search(
+                    query=current_query,
+                    max_results=max_results,
+                    cap_content_length=cap_content_length,
+                    search_mode=search_mode,
+                    include_highlights=include_highlights,
+                )
+
+                # Extract results from both
+                tavily_extracted, _ = extract_search_results(
+                    tavily_results, provider="tavily"
+                )
+                exa_extracted, exa_cost = extract_search_results(
+                    exa_results, provider="exa"
+                )
+                total_cost += exa_cost
+
+                # Merge results
+                results = []
+                max_len = 
max(len(tavily_extracted), len(exa_extracted)) + for i in range(max_len): + if i < len(tavily_extracted): + results.append(tavily_extracted[i]) + if i < len(exa_extracted): + results.append(exa_extracted[i]) + results = results[:max_results] + else: + # Tavily search or unified search + results = unified_search( + query=current_query, + provider=provider, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + + # Handle case where unified_search returns a dict + if isinstance(results, dict): + all_results = [] + for provider_results in results.values(): + all_results.extend(provider_results) + results = all_results[:max_results] + + # Check if we got results with actual content + if results: + # Count results with non-empty content + content_results = sum(1 for r in results if r.content.strip()) + + if content_results >= max(1, len(results) // 2): + logger.info( + f"Found {content_results}/{len(results)} results with content" + ) + return results, total_cost + else: + logger.warning( + f"Only found {content_results}/{len(results)} results with content. " + f"Trying alternative query..." + ) + + # If we didn't get good results but haven't hit max retries yet, try again + if retry_count < max_retries: + logger.warning( + f"Inadequate search results. Retrying with alternative query... ({retry_count + 1}/{max_retries})" + ) + retry_count += 1 + else: + # If we're out of retries, return whatever we have + logger.warning( + f"Out of retries. Returning best results found ({len(results)} results)." + ) + return results, total_cost + + except Exception as e: + if retry_count < max_retries: + logger.warning( + f"Search failed with error: {e}. Retrying... ({retry_count + 1}/{max_retries})" + ) + retry_count += 1 + else: + logger.error(f"Search failed after {max_retries} retries: {e}") + return [], 0.0 + + # If we've exhausted all retries, return the best results we have + return results, total_cost diff --git a/deep_research/utils/tracing_metadata_utils.py b/deep_research/utils/tracing_metadata_utils.py new file mode 100644 index 00000000..59c7b37e --- /dev/null +++ b/deep_research/utils/tracing_metadata_utils.py @@ -0,0 +1,745 @@ +"""Utilities for collecting and analyzing tracing metadata from Langfuse.""" + +import time +from datetime import datetime, timedelta, timezone +from functools import wraps +from typing import Any, Dict, List, Optional, Tuple + +from langfuse import Langfuse +from langfuse.api.core import ApiError +from langfuse.client import ObservationsView, TraceWithDetails +from rich import print +from rich.console import Console +from rich.table import Table + +console = Console() + +langfuse = Langfuse() + +# Prompt type identification keywords +PROMPT_IDENTIFIERS = { + "query_decomposition": [ + "MAIN RESEARCH QUERY", + "DIFFERENT DIMENSIONS", + "sub-questions", + ], + "search_query": ["Deep Research assistant", "effective search query"], + "synthesis": [ + "information synthesis", + "comprehensive answer", + "confidence level", + ], + "viewpoint_analysis": [ + "multi-perspective analysis", + "viewpoint categories", + ], + "reflection": ["critique and improve", "information gaps"], + "additional_synthesis": ["enhance the original synthesis"], + "conclusion_generation": [ + "Synthesis and Integration", + "Direct Response to Main Query", + ], + "executive_summary": [ + "executive summaries", + "Key Findings", + "250-400 words", + ], + "introduction": ["engaging introductions", "Context and 
Relevance"], +} + +# Rate limiting configuration +# Adjust these based on your Langfuse tier: +# - Hobby: 30 req/min for Other APIs -> ~2s between requests +# - Core: 100 req/min -> ~0.6s between requests +# - Pro: 1000 req/min -> ~0.06s between requests +RATE_LIMIT_DELAY = 0.1 # 100ms between requests (safe for most tiers) +MAX_RETRIES = 3 +INITIAL_BACKOFF = 1.0 # Initial backoff in seconds + +# Batch processing configuration +BATCH_DELAY = 0.5 # Additional delay between batches of requests + + +def rate_limited(func): + """Decorator to add rate limiting between API calls.""" + + @wraps(func) + def wrapper(*args, **kwargs): + time.sleep(RATE_LIMIT_DELAY) + return func(*args, **kwargs) + + return wrapper + + +def retry_with_backoff(func): + """Decorator to retry functions with exponential backoff on rate limit errors.""" + + @wraps(func) + def wrapper(*args, **kwargs): + backoff = INITIAL_BACKOFF + last_exception = None + + for attempt in range(MAX_RETRIES): + try: + return func(*args, **kwargs) + except ApiError as e: + if e.status_code == 429: # Rate limit error + last_exception = e + if attempt < MAX_RETRIES - 1: + wait_time = backoff * (2**attempt) + console.print( + f"[yellow]Rate limit hit. Retrying in {wait_time:.1f}s...[/yellow]" + ) + time.sleep(wait_time) + continue + raise + except Exception: + # For non-rate limit errors, raise immediately + raise + + # If we've exhausted all retries + if last_exception: + raise last_exception + + return wrapper + + +@rate_limited +@retry_with_backoff +def fetch_traces_safe(limit: Optional[int] = None) -> List[TraceWithDetails]: + """Safely fetch traces with rate limiting and retry logic.""" + return langfuse.fetch_traces(limit=limit).data + + +@rate_limited +@retry_with_backoff +def fetch_observations_safe(trace_id: str) -> List[ObservationsView]: + """Safely fetch observations with rate limiting and retry logic.""" + return langfuse.fetch_observations(trace_id=trace_id).data + + +def get_total_trace_cost(trace_id: str) -> float: + """Calculate the total cost for a single trace by summing all observation costs. + + Args: + trace_id: The ID of the trace to calculate cost for + + Returns: + Total cost across all observations in the trace + """ + try: + observations = fetch_observations_safe(trace_id=trace_id) + total_cost = 0.0 + + for obs in observations: + # Check multiple possible cost fields + if ( + hasattr(obs, "calculated_total_cost") + and obs.calculated_total_cost + ): + total_cost += obs.calculated_total_cost + elif hasattr(obs, "total_price") and obs.total_price: + total_cost += obs.total_price + elif hasattr(obs, "total_cost") and obs.total_cost: + total_cost += obs.total_cost + # If cost details are available, calculate from input/output costs + elif hasattr(obs, "calculated_input_cost") and hasattr( + obs, "calculated_output_cost" + ): + if obs.calculated_input_cost and obs.calculated_output_cost: + total_cost += ( + obs.calculated_input_cost + obs.calculated_output_cost + ) + + return total_cost + except Exception as e: + print(f"[red]Error calculating trace cost: {e}[/red]") + return 0.0 + + +def get_total_tokens_used(trace_id: str) -> Tuple[int, int]: + """Calculate total input and output tokens used for a trace. 
+ + Args: + trace_id: The ID of the trace to calculate tokens for + + Returns: + Tuple of (input_tokens, output_tokens) + """ + try: + observations = fetch_observations_safe(trace_id=trace_id) + total_input_tokens = 0 + total_output_tokens = 0 + + for obs in observations: + # Check for token fields in different possible locations + if hasattr(obs, "usage") and obs.usage: + if hasattr(obs.usage, "input") and obs.usage.input: + total_input_tokens += obs.usage.input + if hasattr(obs.usage, "output") and obs.usage.output: + total_output_tokens += obs.usage.output + # Also check for direct token fields + elif hasattr(obs, "promptTokens") and hasattr( + obs, "completionTokens" + ): + if obs.promptTokens: + total_input_tokens += obs.promptTokens + if obs.completionTokens: + total_output_tokens += obs.completionTokens + + return total_input_tokens, total_output_tokens + except Exception as e: + print(f"[red]Error calculating tokens: {e}[/red]") + return 0, 0 + + +def get_trace_stats(trace: TraceWithDetails) -> Dict[str, Any]: + """Get comprehensive statistics for a trace. + + Args: + trace: The trace object to analyze + + Returns: + Dictionary containing trace statistics including cost, latency, tokens, and metadata + """ + try: + # Get cost and token data + total_cost = get_total_trace_cost(trace.id) + input_tokens, output_tokens = get_total_tokens_used(trace.id) + + # Get observation count + observations = fetch_observations_safe(trace_id=trace.id) + observation_count = len(observations) + + # Extract model information from observations + models_used = set() + for obs in observations: + if hasattr(obs, "model") and obs.model: + models_used.add(obs.model) + + stats = { + "trace_id": trace.id, + "timestamp": trace.timestamp, + "total_cost": total_cost, + "latency_seconds": trace.latency + if hasattr(trace, "latency") + else 0, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "observation_count": observation_count, + "models_used": list(models_used), + "metadata": trace.metadata if hasattr(trace, "metadata") else {}, + "tags": trace.tags if hasattr(trace, "tags") else [], + "user_id": trace.user_id if hasattr(trace, "user_id") else None, + "session_id": trace.session_id + if hasattr(trace, "session_id") + else None, + } + + # Add formatted latency + if stats["latency_seconds"]: + minutes = int(stats["latency_seconds"] // 60) + seconds = stats["latency_seconds"] % 60 + stats["latency_formatted"] = f"{minutes}m {seconds:.1f}s" + else: + stats["latency_formatted"] = "0m 0.0s" + + return stats + except Exception as e: + print(f"[red]Error getting trace stats: {e}[/red]") + return {} + + +def get_traces_by_name(name: str, limit: int = 1) -> List[TraceWithDetails]: + """Get traces by name using Langfuse API. + + Args: + name: The name of the trace to search for + limit: Maximum number of traces to return (default: 1) + + Returns: + List of traces matching the name + """ + try: + # Use the Langfuse API to get traces by name + traces_response = langfuse.get_traces(name=name, limit=limit) + return traces_response.data + except Exception as e: + print(f"[red]Error fetching traces by name: {e}[/red]") + return [] + + +def get_observations_for_trace(trace_id: str) -> List[ObservationsView]: + """Get all observations for a specific trace. 
+ + Args: + trace_id: The ID of the trace + + Returns: + List of observations for the trace + """ + try: + observations_response = langfuse.get_observations(trace_id=trace_id) + return observations_response.data + except Exception as e: + print(f"[red]Error fetching observations: {e}[/red]") + return [] + + +def filter_traces_by_date_range( + start_date: datetime, end_date: datetime, limit: Optional[int] = None +) -> List[TraceWithDetails]: + """Filter traces within a specific date range. + + Args: + start_date: Start of the date range (inclusive) + end_date: End of the date range (inclusive) + limit: Maximum number of traces to return + + Returns: + List of traces within the date range + """ + try: + # Ensure dates are timezone-aware + if start_date.tzinfo is None: + start_date = start_date.replace(tzinfo=timezone.utc) + if end_date.tzinfo is None: + end_date = end_date.replace(tzinfo=timezone.utc) + + # Fetch all traces (or up to API maximum limit of 100) + all_traces = fetch_traces_safe(limit=limit or 100) + + # Filter by date range + filtered_traces = [ + trace + for trace in all_traces + if start_date <= trace.timestamp <= end_date + ] + + # Sort by timestamp (most recent first) + filtered_traces.sort(key=lambda x: x.timestamp, reverse=True) + + # Apply limit if specified + if limit: + filtered_traces = filtered_traces[:limit] + + return filtered_traces + except Exception as e: + print(f"[red]Error filtering traces by date range: {e}[/red]") + return [] + + +def get_traces_last_n_days( + days: int, limit: Optional[int] = None +) -> List[TraceWithDetails]: + """Get traces from the last N days. + + Args: + days: Number of days to look back + limit: Maximum number of traces to return + + Returns: + List of traces from the last N days + """ + end_date = datetime.now(timezone.utc) + start_date = end_date - timedelta(days=days) + + return filter_traces_by_date_range(start_date, end_date, limit) + + +def get_trace_stats_batch( + traces: List[TraceWithDetails], show_progress: bool = True +) -> List[Dict[str, Any]]: + """Get statistics for multiple traces efficiently with progress tracking. + + Args: + traces: List of traces to analyze + show_progress: Whether to show progress bar + + Returns: + List of dictionaries containing trace statistics + """ + stats_list = [] + + for i, trace in enumerate(traces): + if show_progress and i % 5 == 0: + console.print( + f"[dim]Processing trace {i + 1}/{len(traces)}...[/dim]" + ) + + stats = get_trace_stats(trace) + stats_list.append(stats) + + return stats_list + + +def get_aggregate_stats_for_traces( + traces: List[TraceWithDetails], +) -> Dict[str, Any]: + """Calculate aggregate statistics for a list of traces. 
+ + Args: + traces: List of traces to analyze + + Returns: + Dictionary containing aggregate statistics + """ + if not traces: + return { + "trace_count": 0, + "total_cost": 0.0, + "total_input_tokens": 0, + "total_output_tokens": 0, + "total_tokens": 0, + "average_cost_per_trace": 0.0, + "average_latency_seconds": 0.0, + "total_observations": 0, + } + + total_cost = 0.0 + total_input_tokens = 0 + total_output_tokens = 0 + total_latency = 0.0 + total_observations = 0 + all_models = set() + + for trace in traces: + stats = get_trace_stats(trace) + total_cost += stats.get("total_cost", 0) + total_input_tokens += stats.get("input_tokens", 0) + total_output_tokens += stats.get("output_tokens", 0) + total_latency += stats.get("latency_seconds", 0) + total_observations += stats.get("observation_count", 0) + all_models.update(stats.get("models_used", [])) + + return { + "trace_count": len(traces), + "total_cost": total_cost, + "total_input_tokens": total_input_tokens, + "total_output_tokens": total_output_tokens, + "total_tokens": total_input_tokens + total_output_tokens, + "average_cost_per_trace": total_cost / len(traces) if traces else 0, + "average_latency_seconds": total_latency / len(traces) + if traces + else 0, + "total_observations": total_observations, + "models_used": list(all_models), + } + + +def display_trace_stats_table( + traces: List[TraceWithDetails], title: str = "Trace Statistics" +): + """Display trace statistics in a formatted table. + + Args: + traces: List of traces to display + title: Title for the table + """ + table = Table(title=title, show_header=True, header_style="bold magenta") + table.add_column("Trace ID", style="cyan", no_wrap=True) + table.add_column("Timestamp", style="yellow") + table.add_column("Cost ($)", justify="right", style="green") + table.add_column("Tokens (In/Out)", justify="right") + table.add_column("Latency", justify="right") + table.add_column("Observations", justify="right") + + for trace in traces[:10]: # Limit to 10 for display + stats = get_trace_stats(trace) + table.add_row( + stats["trace_id"][:12] + "...", + stats["timestamp"].strftime("%Y-%m-%d %H:%M"), + f"${stats['total_cost']:.4f}", + f"{stats['input_tokens']:,}/{stats['output_tokens']:,}", + stats["latency_formatted"], + str(stats["observation_count"]), + ) + + console.print(table) + + +def identify_prompt_type(observation: ObservationsView) -> str: + """Identify the prompt type based on keywords in the observation's input. + + Examines the system prompt in observation.input['messages'][0]['content'] + for unique keywords that identify each prompt type. + + Args: + observation: The observation to analyze + + Returns: + str: The prompt type name, or "unknown" if not identified + """ + try: + # Access the system prompt from the messages + if hasattr(observation, "input") and observation.input: + messages = observation.input.get("messages", []) + if messages and len(messages) > 0: + system_content = messages[0].get("content", "") + + # Check each prompt type's keywords + for prompt_type, keywords in PROMPT_IDENTIFIERS.items(): + # Check if any keyword is in the system prompt + for keyword in keywords: + if keyword in system_content: + return prompt_type + + return "unknown" + except Exception as e: + console.print( + f"[yellow]Warning: Could not identify prompt type: {e}[/yellow]" + ) + return "unknown" + + +def get_costs_by_prompt_type(trace_id: str) -> Dict[str, Dict[str, float]]: + """Get cost breakdown by prompt type for a given trace. 
+ + Uses observation.usage.input/output for token counts and + observation.calculated_total_cost for costs. + + Args: + trace_id: The ID of the trace to analyze + + Returns: + Dict mapping prompt_type to { + 'cost': float, + 'input_tokens': int, + 'output_tokens': int, + 'count': int # number of calls + } + """ + try: + observations = fetch_observations_safe(trace_id=trace_id) + prompt_metrics = {} + + for obs in observations: + # Identify prompt type + prompt_type = identify_prompt_type(obs) + + # Initialize metrics for this prompt type if needed + if prompt_type not in prompt_metrics: + prompt_metrics[prompt_type] = { + "cost": 0.0, + "input_tokens": 0, + "output_tokens": 0, + "count": 0, + } + + # Add cost + cost = 0.0 + if ( + hasattr(obs, "calculated_total_cost") + and obs.calculated_total_cost + ): + cost = obs.calculated_total_cost + prompt_metrics[prompt_type]["cost"] += cost + + # Add tokens + if hasattr(obs, "usage") and obs.usage: + if hasattr(obs.usage, "input") and obs.usage.input: + prompt_metrics[prompt_type]["input_tokens"] += ( + obs.usage.input + ) + if hasattr(obs.usage, "output") and obs.usage.output: + prompt_metrics[prompt_type]["output_tokens"] += ( + obs.usage.output + ) + + # Increment count + prompt_metrics[prompt_type]["count"] += 1 + + return prompt_metrics + except Exception as e: + print(f"[red]Error getting costs by prompt type: {e}[/red]") + return {} + + +def get_prompt_type_statistics(trace_id: str) -> Dict[str, Dict[str, Any]]: + """Get detailed statistics for each prompt type. + + Args: + trace_id: The ID of the trace to analyze + + Returns: + Dict mapping prompt_type to { + 'cost': float, + 'input_tokens': int, + 'output_tokens': int, + 'count': int, + 'avg_cost_per_call': float, + 'avg_input_tokens': float, + 'avg_output_tokens': float, + 'percentage_of_total_cost': float + } + """ + try: + # Get basic metrics + prompt_metrics = get_costs_by_prompt_type(trace_id) + + # Calculate total cost for percentage calculation + total_cost = sum( + metrics["cost"] for metrics in prompt_metrics.values() + ) + + # Enhance with statistics + enhanced_metrics = {} + for prompt_type, metrics in prompt_metrics.items(): + count = metrics["count"] + enhanced_metrics[prompt_type] = { + "cost": metrics["cost"], + "input_tokens": metrics["input_tokens"], + "output_tokens": metrics["output_tokens"], + "count": count, + "avg_cost_per_call": metrics["cost"] / count + if count > 0 + else 0, + "avg_input_tokens": metrics["input_tokens"] / count + if count > 0 + else 0, + "avg_output_tokens": metrics["output_tokens"] / count + if count > 0 + else 0, + "percentage_of_total_cost": ( + metrics["cost"] / total_cost * 100 + ) + if total_cost > 0 + else 0, + } + + return enhanced_metrics + except Exception as e: + print(f"[red]Error getting prompt type statistics: {e}[/red]") + return {} + + +if __name__ == "__main__": + print( + "[bold cyan]ZenML Deep Research - Tracing Metadata Utilities Demo[/bold cyan]\n" + ) + + try: + # Fetch recent traces + print("[yellow]Fetching recent traces...[/yellow]") + traces = fetch_traces_safe(limit=5) + + if not traces: + print("[red]No traces found![/red]") + exit(1) + except ApiError as e: + if e.status_code == 429: + print("[red]Rate limit exceeded. 
Please try again later.[/red]") + print( + "[yellow]Tip: Consider upgrading your Langfuse tier for higher rate limits.[/yellow]" + ) + else: + print(f"[red]API Error: {e}[/red]") + exit(1) + except Exception as e: + print(f"[red]Error fetching traces: {e}[/red]") + exit(1) + + # Demo 1: Get stats for a single trace + print("\n[bold]1. Single Trace Statistics:[/bold]") + first_trace = traces[0] + stats = get_trace_stats(first_trace) + + console.print(f"Trace ID: [cyan]{stats['trace_id']}[/cyan]") + console.print(f"Timestamp: [yellow]{stats['timestamp']}[/yellow]") + console.print(f"Total Cost: [green]${stats['total_cost']:.4f}[/green]") + console.print( + f"Tokens - Input: [blue]{stats['input_tokens']:,}[/blue], Output: [blue]{stats['output_tokens']:,}[/blue]" + ) + console.print(f"Latency: [magenta]{stats['latency_formatted']}[/magenta]") + console.print(f"Observations: [white]{stats['observation_count']}[/white]") + console.print( + f"Models Used: [cyan]{', '.join(stats['models_used'])}[/cyan]" + ) + + # Demo 2: Get traces from last 7 days + print("\n[bold]2. Traces from Last 7 Days:[/bold]") + recent_traces = get_traces_last_n_days(7, limit=10) + print( + f"Found [green]{len(recent_traces)}[/green] traces in the last 7 days" + ) + + if recent_traces: + display_trace_stats_table(recent_traces, "Last 7 Days Traces") + + # Demo 3: Filter traces by date range + print("\n[bold]3. Filter Traces by Date Range:[/bold]") + end_date = datetime.now(timezone.utc) + start_date = end_date - timedelta(days=3) + + filtered_traces = filter_traces_by_date_range(start_date, end_date) + print( + f"Found [green]{len(filtered_traces)}[/green] traces between {start_date.strftime('%Y-%m-%d')} and {end_date.strftime('%Y-%m-%d')}" + ) + + # Demo 4: Aggregate statistics + print("\n[bold]4. Aggregate Statistics for All Recent Traces:[/bold]") + agg_stats = get_aggregate_stats_for_traces(traces) + + table = Table( + title="Aggregate Statistics", + show_header=True, + header_style="bold magenta", + ) + table.add_column("Metric", style="cyan") + table.add_column("Value", justify="right", style="yellow") + + table.add_row("Total Traces", str(agg_stats["trace_count"])) + table.add_row("Total Cost", f"${agg_stats['total_cost']:.4f}") + table.add_row( + "Average Cost per Trace", f"${agg_stats['average_cost_per_trace']:.4f}" + ) + table.add_row("Total Input Tokens", f"{agg_stats['total_input_tokens']:,}") + table.add_row( + "Total Output Tokens", f"{agg_stats['total_output_tokens']:,}" + ) + table.add_row("Total Tokens", f"{agg_stats['total_tokens']:,}") + table.add_row( + "Average Latency", f"{agg_stats['average_latency_seconds']:.1f}s" + ) + table.add_row("Total Observations", str(agg_stats["total_observations"])) + + console.print(table) + + # Demo 5: Cost breakdown by observation + print("\n[bold]5. 
Cost Breakdown for First Trace:[/bold]") + observations = fetch_observations_safe(trace_id=first_trace.id) + + if observations: + table = Table( + title="Observation Cost Breakdown", + show_header=True, + header_style="bold magenta", + ) + table.add_column("Observation", style="cyan", no_wrap=True) + table.add_column("Model", style="yellow") + table.add_column("Tokens (In/Out)", justify="right") + table.add_column("Cost", justify="right", style="green") + + for i, obs in enumerate(observations[:5]): # Show first 5 + cost = 0.0 + if hasattr(obs, "calculated_total_cost"): + cost = obs.calculated_total_cost or 0.0 + + in_tokens = 0 + out_tokens = 0 + if hasattr(obs, "usage") and obs.usage: + in_tokens = obs.usage.input or 0 + out_tokens = obs.usage.output or 0 + elif hasattr(obs, "promptTokens"): + in_tokens = obs.promptTokens or 0 + out_tokens = obs.completionTokens or 0 + + table.add_row( + f"Obs {i + 1}", + obs.model if hasattr(obs, "model") else "Unknown", + f"{in_tokens:,}/{out_tokens:,}", + f"${cost:.4f}", + ) + + console.print(table)