From 758a71270ef4878152d6b93becaf9460b7ad1112 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 27 May 2025 19:57:12 +0200 Subject: [PATCH 01/11] commented files --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 851a4f94..f27e166a 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,7 @@ nohup.out # Claude .claude/ + +# design folder +design/ +deep_research/design From 09e27e29a36e64cf371c4e800bc9262685eaefd4 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 27 May 2025 19:57:37 +0200 Subject: [PATCH 02/11] initial commit --- deep_research/README.md | 604 +++++++ deep_research/__init__.py | 0 deep_research/configs/balanced_research.yaml | 79 + deep_research/configs/compare_viewpoints.yaml | 43 + deep_research/configs/daily_trends.yaml | 45 + deep_research/configs/deep_research.yaml | 81 + deep_research/configs/enhanced_research.yaml | 71 + .../enhanced_research_with_approval.yaml | 77 + deep_research/configs/parallel_research.yaml | 93 + deep_research/configs/pipeline_config.yaml | 58 + deep_research/configs/quick_research.yaml | 59 + deep_research/configs/rapid_research.yaml | 59 + deep_research/configs/thorough_research.yaml | 63 + deep_research/logging_config.py | 42 + deep_research/materializers/__init__.py | 21 + .../approval_decision_materializer.py | 281 +++ .../materializers/prompts_materializer.py | 509 ++++++ .../materializers/pydantic_materializer.py | 764 ++++++++ .../reflection_output_materializer.py | 279 +++ .../tracing_metadata_materializer.py | 603 +++++++ deep_research/pipelines/__init__.py | 11 + .../pipelines/parallel_research_pipeline.py | 133 ++ deep_research/requirements.txt | 10 + deep_research/run.py | 330 ++++ deep_research/steps/__init__.py | 7 + deep_research/steps/approval_step.py | 308 ++++ .../steps/collect_tracing_metadata_step.py | 234 +++ deep_research/steps/cross_viewpoint_step.py | 228 +++ .../steps/execute_approved_searches_step.py | 423 +++++ .../steps/generate_reflection_step.py | 167 ++ .../steps/initialize_prompts_step.py | 45 + .../steps/iterative_reflection_step.py | 385 ++++ deep_research/steps/merge_results_step.py | 265 +++ .../steps/process_sub_question_step.py | 289 +++ .../steps/pydantic_final_report_step.py | 1250 +++++++++++++ .../steps/query_decomposition_step.py | 174 ++ deep_research/tests/__init__.py | 1 + deep_research/tests/conftest.py | 11 + deep_research/tests/test_approval_utils.py | 150 ++ deep_research/tests/test_prompt_loader.py | 108 ++ deep_research/tests/test_prompt_models.py | 183 ++ .../tests/test_pydantic_final_report_step.py | 167 ++ .../tests/test_pydantic_materializer.py | 161 ++ deep_research/tests/test_pydantic_models.py | 303 ++++ deep_research/utils/__init__.py | 7 + deep_research/utils/approval_utils.py | 159 ++ deep_research/utils/helper_functions.py | 192 ++ deep_research/utils/llm_utils.py | 387 ++++ deep_research/utils/prompt_loader.py | 136 ++ deep_research/utils/prompt_models.py | 123 ++ deep_research/utils/prompts.py | 1605 +++++++++++++++++ deep_research/utils/pydantic_models.py | 300 +++ deep_research/utils/search_utils.py | 721 ++++++++ deep_research/utils/tracing_metadata_utils.py | 745 ++++++++ 54 files changed, 13549 insertions(+) create mode 100644 deep_research/README.md create mode 100644 deep_research/__init__.py create mode 100644 deep_research/configs/balanced_research.yaml create mode 100644 deep_research/configs/compare_viewpoints.yaml create mode 100644 deep_research/configs/daily_trends.yaml create mode 
100644 deep_research/configs/deep_research.yaml create mode 100644 deep_research/configs/enhanced_research.yaml create mode 100644 deep_research/configs/enhanced_research_with_approval.yaml create mode 100644 deep_research/configs/parallel_research.yaml create mode 100644 deep_research/configs/pipeline_config.yaml create mode 100644 deep_research/configs/quick_research.yaml create mode 100644 deep_research/configs/rapid_research.yaml create mode 100644 deep_research/configs/thorough_research.yaml create mode 100644 deep_research/logging_config.py create mode 100644 deep_research/materializers/__init__.py create mode 100644 deep_research/materializers/approval_decision_materializer.py create mode 100644 deep_research/materializers/prompts_materializer.py create mode 100644 deep_research/materializers/pydantic_materializer.py create mode 100644 deep_research/materializers/reflection_output_materializer.py create mode 100644 deep_research/materializers/tracing_metadata_materializer.py create mode 100644 deep_research/pipelines/__init__.py create mode 100644 deep_research/pipelines/parallel_research_pipeline.py create mode 100644 deep_research/requirements.txt create mode 100644 deep_research/run.py create mode 100644 deep_research/steps/__init__.py create mode 100644 deep_research/steps/approval_step.py create mode 100644 deep_research/steps/collect_tracing_metadata_step.py create mode 100644 deep_research/steps/cross_viewpoint_step.py create mode 100644 deep_research/steps/execute_approved_searches_step.py create mode 100644 deep_research/steps/generate_reflection_step.py create mode 100644 deep_research/steps/initialize_prompts_step.py create mode 100644 deep_research/steps/iterative_reflection_step.py create mode 100644 deep_research/steps/merge_results_step.py create mode 100644 deep_research/steps/process_sub_question_step.py create mode 100644 deep_research/steps/pydantic_final_report_step.py create mode 100644 deep_research/steps/query_decomposition_step.py create mode 100644 deep_research/tests/__init__.py create mode 100644 deep_research/tests/conftest.py create mode 100644 deep_research/tests/test_approval_utils.py create mode 100644 deep_research/tests/test_prompt_loader.py create mode 100644 deep_research/tests/test_prompt_models.py create mode 100644 deep_research/tests/test_pydantic_final_report_step.py create mode 100644 deep_research/tests/test_pydantic_materializer.py create mode 100644 deep_research/tests/test_pydantic_models.py create mode 100644 deep_research/utils/__init__.py create mode 100644 deep_research/utils/approval_utils.py create mode 100644 deep_research/utils/helper_functions.py create mode 100644 deep_research/utils/llm_utils.py create mode 100644 deep_research/utils/prompt_loader.py create mode 100644 deep_research/utils/prompt_models.py create mode 100644 deep_research/utils/prompts.py create mode 100644 deep_research/utils/pydantic_models.py create mode 100644 deep_research/utils/search_utils.py create mode 100644 deep_research/utils/tracing_metadata_utils.py diff --git a/deep_research/README.md b/deep_research/README.md new file mode 100644 index 00000000..c057d85b --- /dev/null +++ b/deep_research/README.md @@ -0,0 +1,604 @@ +# 🔍 ZenML Deep Research Agent + +A production-ready MLOps pipeline for conducting deep, comprehensive research on any topic using LLMs and web search capabilities. + +
+ *Research Pipeline Visualization: ZenML Deep Research pipeline flow*
+ +## 🎯 Overview + +The ZenML Deep Research Agent is a scalable, modular pipeline that automates in-depth research on any topic. It: + +- Creates a structured outline based on your research query +- Researches each section through targeted web searches and LLM analysis +- Iteratively refines content through reflection cycles +- Produces a comprehensive, well-formatted research report +- Visualizes the research process and report structure in the ZenML dashboard + +This project transforms exploratory notebook-based research into a production-grade, reproducible, and transparent process using the ZenML MLOps framework. + +## 📝 Example Research Results + +The Deep Research Agent produces comprehensive, well-structured reports on any topic. Here's an example of research conducted on quantum computing: + +
+ *Sample Research Report generated by the Deep Research Agent*
+ +## 🚀 Pipeline Architecture + +The pipeline uses a parallel processing architecture for efficiency and breaks down the research process into granular steps for maximum modularity and control: + +1. **Initialize Prompts**: Load and track all prompts as versioned artifacts +2. **Query Decomposition**: Break down the main query into specific sub-questions +3. **Parallel Information Gathering**: Process multiple sub-questions concurrently for faster results +4. **Merge Results**: Combine results from parallel processing into a unified state +5. **Cross-Viewpoint Analysis**: Analyze discrepancies and agreements between different perspectives +6. **Reflection Generation**: Generate recommendations for improving research quality +7. **Human Approval** (optional): Get human approval for additional searches +8. **Execute Approved Searches**: Perform approved additional searches to fill gaps +9. **Final Report Generation**: Compile all synthesized information into a coherent HTML report +10. **Collect Tracing Metadata**: Gather comprehensive metrics about token usage, costs, and performance + +This architecture enables: +- Better reproducibility and caching of intermediate results +- Parallel processing for faster research completion +- Easier debugging and monitoring of specific research stages +- More flexible reconfiguration of individual components +- Enhanced transparency into how the research is conducted +- Human oversight and control over iterative research expansions + +## 💡 Under the Hood + +- **LLM Integration**: Uses litellm for flexible access to various LLM providers +- **Web Research**: Utilizes Tavily API for targeted internet searches +- **ZenML Orchestration**: Manages pipeline flow, artifacts, and caching +- **Reproducibility**: Track every step, parameter, and output via ZenML +- **Visualizations**: Interactive visualizations of the research structure and progress +- **Report Generation**: Uses static HTML templates for consistent, high-quality reports +- **Human-in-the-Loop**: Optional approval mechanism via ZenML alerters (Discord, Slack, etc.) +- **LLM Observability**: Integrated Langfuse tracking for monitoring LLM usage, costs, and performance + +## 🛠️ Getting Started + +### Prerequisites + +- Python 3.9+ +- ZenML installed and configured +- API key for your preferred LLM provider (configured with litellm) +- Tavily API key +- Langfuse account for LLM tracking (optional but recommended) + +### Installation + +```bash +# Clone the repository +git clone +cd zenml_deep_research + +# Install dependencies +pip install -r requirements.txt + +# Set up API keys +export OPENAI_API_KEY=your_openai_key # Or another LLM provider key +export TAVILY_API_KEY=your_tavily_key # For Tavily search (default) +export EXA_API_KEY=your_exa_key # For Exa search (optional) + +# Set up Langfuse for LLM tracking (optional) +export LANGFUSE_PUBLIC_KEY=your_public_key +export LANGFUSE_SECRET_KEY=your_secret_key +export LANGFUSE_HOST=https://cloud.langfuse.com # Or your self-hosted URL + +# Initialize ZenML (if needed) +zenml init +``` + +### Setting up Langfuse for LLM Tracking + +The pipeline integrates with [Langfuse](https://langfuse.com) for comprehensive LLM observability and tracking. This allows you to monitor LLM usage, costs, and performance across all pipeline runs. + +#### 1. Create a Langfuse Account + +1. Sign up at [cloud.langfuse.com](https://cloud.langfuse.com) or set up a self-hosted instance +2. Create a new project in your Langfuse dashboard (e.g., "deep-research") +3. 
Navigate to Settings → API Keys to get your credentials + +#### 2. Configure Environment Variables + +Set the following environment variables with your Langfuse credentials: + +```bash +export LANGFUSE_PUBLIC_KEY=pk-lf-... # Your public key +export LANGFUSE_SECRET_KEY=sk-lf-... # Your secret key +export LANGFUSE_HOST=https://cloud.langfuse.com # Or your self-hosted URL +``` + +#### 3. Configure Project Name + +The Langfuse project name can be configured in any of the pipeline configuration files: + +```yaml +# configs/enhanced_research.yaml +langfuse_project_name: "deep-research" # Change to match your Langfuse project +``` + +**Note**: The project must already exist in your Langfuse dashboard before running the pipeline. + +#### What Gets Tracked + +When Langfuse is configured, the pipeline automatically tracks: + +- **All LLM calls** with their prompts, responses, and token usage +- **Pipeline trace information** including: + - `trace_name`: The ZenML pipeline run name for easy identification + - `trace_id`: The unique ZenML pipeline run ID for correlation +- **Tagged operations** such as: + - `structured_llm_output`: JSON generation calls + - `information_synthesis`: Research synthesis operations + - `find_most_relevant_string`: Relevance matching operations +- **Performance metrics**: Latency, token counts, and costs +- **Project organization**: All traces are organized under your configured project + +This integration provides full observability into your research pipeline's LLM usage, making it easy to optimize performance, track costs, and debug issues. + +### Running the Pipeline + +#### Basic Usage + +```bash +# Run with default configuration +python run.py +``` + +The default configuration and research query are defined in `configs/enhanced_research.yaml`. 
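+
+For reference, the sketch below shows one way the Langfuse trace metadata described above (pipeline run name, run ID, and operation tags) can be attached to an individual LLM call through litellm's Langfuse callback. The helper and model names are illustrative assumptions; the pipeline's actual wiring lives in `utils/llm_utils.py` and may differ.
+
+```python
+import litellm
+
+# Route successful LLM calls to Langfuse. The callback reads
+# LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY and LANGFUSE_HOST from the
+# environment variables configured above.
+litellm.success_callback = ["langfuse"]
+
+
+def traced_completion(prompt: str, run_name: str, run_id: str) -> str:
+    """Illustrative helper: call an LLM and tag the Langfuse trace with ZenML run info."""
+    response = litellm.completion(
+        model="openrouter/google/gemini-2.5-flash-preview-05-20",  # any litellm-supported model
+        messages=[{"role": "user", "content": prompt}],
+        metadata={
+            "trace_name": run_name,  # ZenML pipeline run name
+            "trace_id": run_id,  # ZenML pipeline run ID
+            "tags": ["structured_llm_output"],
+        },
+    )
+    return response.choices[0].message.content
+```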
+ +#### Using Research Mode Presets + +The pipeline includes three pre-configured research modes for different use cases: + +```bash +# Rapid mode - Quick overview with minimal depth +python run.py --mode rapid + +# Balanced mode - Standard research depth (default) +python run.py --mode balanced + +# Deep mode - Comprehensive analysis with maximum depth +python run.py --mode deep +``` + +**Mode Comparison:** + +| Mode | Sub-Questions | Search Results* | Additional Searches | Best For | +|------|---------------|----------------|-------------------|----------| +| **Rapid** | 5 | 2 per search | 0 | Quick overviews, time-sensitive research | +| **Balanced** | 10 | 3 per search | 2 | Most research tasks, good depth/speed ratio | +| **Deep** | 15 | 5 per search | 4 | Comprehensive analysis, academic research | + +*Can be overridden with `--num-results` + +#### Using Different Configurations + +```bash +# Run with a custom configuration file +python run.py --config configs/custom_enhanced_config.yaml + +# Override the research query from command line +python run.py --query "My research topic" + +# Specify maximum number of sub-questions to process in parallel +python run.py --max-sub-questions 15 + +# Combine mode with other options +python run.py --mode deep --query "Complex topic" --require-approval + +# Combine multiple options +python run.py --config configs/custom_enhanced_config.yaml --query "My research topic" --max-sub-questions 12 +``` + +### Advanced Options + +```bash +# Enable debug logging +python run.py --debug + +# Disable caching for a fresh run +python run.py --no-cache + +# Specify a log file +python run.py --log-file research.log + +# Enable human-in-the-loop approval for additional research +python run.py --require-approval + +# Set approval timeout (in seconds) +python run.py --require-approval --approval-timeout 7200 + +# Use a different search provider (default: tavily) +python run.py --search-provider exa # Use Exa search +python run.py --search-provider both # Use both providers +python run.py --search-provider exa --search-mode neural # Exa with neural search + +# Control the number of search results per query +python run.py --num-results 5 # Get 5 results per search +python run.py --num-results 10 --search-provider exa # 10 results with Exa +``` + +### Search Providers + +The pipeline supports multiple search providers for flexibility and comparison: + +#### Available Providers + +1. **Tavily** (Default) + - Traditional keyword-based search + - Good for factual information and current events + - Requires `TAVILY_API_KEY` environment variable + +2. **Exa** + - Neural search engine with semantic understanding + - Better for conceptual and research-oriented queries + - Supports three search modes: + - `auto` (default): Automatically chooses between neural and keyword + - `neural`: Semantic search for conceptual understanding + - `keyword`: Traditional keyword matching + - Requires `EXA_API_KEY` environment variable + +3. 
**Both** + - Runs searches on both providers + - Useful for comprehensive research or comparing results + - Requires both API keys + +#### Usage Examples + +```bash +# Use Exa with neural search +python run.py --search-provider exa --search-mode neural + +# Compare results from both providers +python run.py --search-provider both + +# Use Exa with keyword search for exact matches +python run.py --search-provider exa --search-mode keyword + +# Combine with other options +python run.py --mode deep --search-provider exa --require-approval +``` + +### Human-in-the-Loop Approval + +The pipeline supports human approval for additional research queries identified during the reflection phase: + +```bash +# Enable approval with default 1-hour timeout +python run.py --require-approval + +# Custom timeout (2 hours) +python run.py --require-approval --approval-timeout 7200 + +# Approval works with any configuration +python run.py --config configs/thorough_research.yaml --require-approval +``` + +When enabled, the pipeline will: +1. Pause after the initial research phase +2. Send an approval request via your configured ZenML alerter (Discord, Slack, etc.) +3. Present research progress, identified gaps, and proposed additional queries +4. Wait for your approval before conducting additional searches +5. Continue with approved queries or finalize the report based on your decision + +**Note**: You need a ZenML stack with an alerter configured (e.g., Discord or Slack) for approval functionality to work. + +**Tip**: When using `--mode deep`, the pipeline will suggest enabling `--require-approval` for better control over the comprehensive research process. + +## 📊 Visualizing Research Process + +The pipeline includes built-in visualizations to help you understand and monitor the research process: + +### Viewing Visualizations + +After running the pipeline, you can view the visualizations in the ZenML dashboard: + +1. Start the ZenML dashboard: + ```bash + zenml up + ``` + +2. Navigate to the "Runs" tab in the dashboard +3. Select your pipeline run +4. 
Explore visualizations for each step: + - **initialize_prompts_step**: View all prompts used in the pipeline + - **initial_query_decomposition_step**: See how the query was broken down + - **process_sub_question_step**: Track progress for each sub-question + - **cross_viewpoint_analysis_step**: View viewpoint analysis results + - **generate_reflection_step**: See reflection and recommendations + - **get_research_approval_step**: View approval decisions + - **pydantic_final_report_step**: Access the final research state + - **collect_tracing_metadata_step**: View comprehensive cost and performance metrics + +### Visualization Features + +The visualizations provide: +- An overview of the report structure +- Details of each paragraph's research status +- Search history and source information +- Progress through reflection iterations +- Professionally formatted HTML reports with static templates + +### Sample Visualization + +Here's what the report structure visualization looks like: + +``` +Report Structure: +├── Introduction +│ └── Initial understanding of the topic +├── Historical Background +│ └── Evolution and key developments +├── Current State +│ └── Latest advancements and implementations +└── Conclusion + └── Summary and future implications +``` + +## 📁 Project Structure + +``` +zenml_deep_research/ +├── configs/ # Configuration files +│ ├── __init__.py +│ └── enhanced_research.yaml # Main configuration file +├── materializers/ # Custom materializers for artifact storage +│ ├── __init__.py +│ └── pydantic_materializer.py +├── pipelines/ # ZenML pipeline definitions +│ ├── __init__.py +│ └── parallel_research_pipeline.py +├── steps/ # ZenML pipeline steps +│ ├── __init__.py +│ ├── approval_step.py # Human approval step for additional research +│ ├── cross_viewpoint_step.py +│ ├── execute_approved_searches_step.py # Execute approved searches +│ ├── generate_reflection_step.py # Generate reflection without execution +│ ├── iterative_reflection_step.py # Legacy combined reflection step +│ ├── merge_results_step.py +│ ├── process_sub_question_step.py +│ ├── pydantic_final_report_step.py +│ └── query_decomposition_step.py +├── utils/ # Utility functions and helpers +│ ├── __init__.py +│ ├── approval_utils.py # Human approval utilities +│ ├── helper_functions.py +│ ├── llm_utils.py # LLM integration utilities +│ ├── prompts.py # Contains prompt templates and HTML templates +│ ├── pydantic_models.py # Data models using Pydantic +│ └── search_utils.py # Web search functionality +├── __init__.py +├── requirements.txt # Project dependencies +├── logging_config.py # Logging configuration +├── README.md # Project documentation +└── run.py # Main script to run the pipeline +``` + +## 🔧 Customization + +The project supports two levels of customization: + +### 1. Command-Line Parameters + +You can customize the research behavior directly through command-line parameters: + +```bash +# Specify your research query +python run.py --query "Your research topic" + +# Control parallelism with max-sub-questions +python run.py --max-sub-questions 15 + +# Combine multiple options +python run.py --query "Your research topic" --max-sub-questions 12 --no-cache +``` + +These settings control how the parallel pipeline processes your research query. + +### 2. 
Pipeline Configuration + +For more detailed settings, modify the configuration file: + +```yaml +# configs/enhanced_research.yaml + +# Enhanced Deep Research Pipeline Configuration +enable_cache: true + +# Research query parameters +query: "Climate change policy debates" + +# Step configurations +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: ["scientific", "political", "economic", "social", "ethical", "historical"] + + iterative_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_additional_searches: 2 + num_results_per_search: 3 + + # Human approval configuration (when using --require-approval) + get_research_approval_step: + parameters: + timeout: 3600 # 1 hour timeout for approval + max_queries: 2 # Maximum queries to present for approval + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 +``` + +To use a custom configuration file: + +```bash +python run.py --config configs/custom_research.yaml +``` + +### Available Configurations + +**Mode-Based Configurations** (automatically selected when using `--mode`): + +| Config File | Mode | Description | +|-------------|------|-------------| +| `rapid_research.yaml` | `--mode rapid` | Quick overview with minimal depth | +| `balanced_research.yaml` | `--mode balanced` | Standard research with moderate depth | +| `deep_research.yaml` | `--mode deep` | Comprehensive analysis with maximum depth | + +**Specialized Configurations:** + +| Config File | Description | Key Parameters | +|-------------|-------------|----------------| +| `enhanced_research.yaml` | Default research configuration | Standard settings, 2 additional searches | +| `thorough_research.yaml` | In-depth analysis | 12 sub-questions, 5 results per search | +| `quick_research.yaml` | Faster results | 5 sub-questions, 2 results per search | +| `daily_trends.yaml` | Research on recent topics | 24-hour search recency, disable cache | +| `compare_viewpoints.yaml` | Focus on comparing perspectives | Extended viewpoint categories | +| `parallel_research.yaml` | Optimized for parallel execution | Configured for distributed orchestrators | + +You can create additional configuration files by copying and modifying the base configuration files above. + +## 🎯 Prompts Tracking and Management + +The pipeline includes a sophisticated prompts tracking system that allows you to track all prompts as versioned artifacts in ZenML. This provides better observability, version control, and visualization of the prompts used in your research pipeline. + +### Overview + +The prompts tracking system enables: +- **Artifact Tracking**: All prompts are tracked as versioned artifacts in ZenML +- **Beautiful Visualizations**: HTML interface in the dashboard with search, copy, and expand features +- **Version Control**: Prompts are versioned alongside your code +- **Pipeline Integration**: Prompts are passed through the pipeline as artifacts, not hardcoded imports + +### Components + +1. 
**PromptsBundle Model** (`utils/prompt_models.py`) + - Pydantic model containing all prompts used in the pipeline + - Each prompt includes metadata: name, content, description, version, and tags + +2. **PromptsBundleMaterializer** (`materializers/prompts_materializer.py`) + - Custom materializer creating HTML visualizations in the ZenML dashboard + - Features: search, copy-to-clipboard, expandable content, tag categorization + +3. **Prompt Loader** (`utils/prompt_loader.py`) + - Utility to load prompts from `prompts.py` into a PromptsBundle + +### Integration Guide + +To integrate prompts tracking into a pipeline: + +1. **Initialize prompts as the first step:** + ```python + from steps.initialize_prompts_step import initialize_prompts_step + + @pipeline + def my_pipeline(): + prompts_bundle = initialize_prompts_step(pipeline_version="1.0.0") + ``` + +2. **Update steps to receive prompts_bundle:** + ```python + @step + def my_step(state: ResearchState, prompts_bundle: PromptsBundle): + prompt = prompts_bundle.get_prompt_content("synthesis_prompt") + # Use prompt in your step logic + ``` + +3. **Pass prompts_bundle through the pipeline:** + ```python + state = synthesis_step(state=state, prompts_bundle=prompts_bundle) + ``` + +### Benefits + +- **Full Tracking**: Every pipeline run tracks which exact prompts were used +- **Version History**: See how prompts evolved across different runs +- **Debugging**: Easily identify which prompts produced specific outputs +- **A/B Testing**: Compare results using different prompt versions + +### Visualization Features + +The HTML visualization in the ZenML dashboard includes: +- Pipeline version and creation timestamp +- Statistics (total prompts, tagged prompts, custom prompts) +- Search functionality across all prompt content +- Expandable/collapsible prompt content +- One-click copy to clipboard +- Tag-based categorization with visual indicators + +## 📊 Cost and Performance Tracking + +The pipeline includes comprehensive tracking of costs and performance metrics through the `collect_tracing_metadata_step`, which runs at the end of each pipeline execution. + +### Tracked Metrics + +- **LLM Costs**: Detailed breakdown by model and prompt type +- **Search Costs**: Tracking for both Tavily and Exa search providers +- **Token Usage**: Input/output tokens per model and step +- **Performance**: Latency and execution time metrics +- **Cost Attribution**: See which steps and prompts consume the most resources + +### Viewing Metrics + +After pipeline execution, the tracing metadata is available in the ZenML dashboard: + +1. Navigate to your pipeline run +2. Find the `collect_tracing_metadata_step` +3. 
View the comprehensive cost visualization including: + - Total pipeline cost (LLM + Search) + - Cost breakdown by model + - Token usage distribution + - Performance metrics + +This helps you: +- Optimize pipeline costs by identifying expensive operations +- Monitor token usage to stay within limits +- Track performance over time +- Make informed decisions about model selection + +## 📈 Example Use Cases + +- **Academic Research**: Rapidly generate preliminary research on academic topics +- **Business Intelligence**: Stay informed on industry trends and competitive landscape +- **Content Creation**: Develop well-researched content for articles, blogs, or reports +- **Decision Support**: Gather comprehensive information for informed decision-making + +## 🔄 Integration Possibilities + +This pipeline can integrate with: + +- **Document Storage**: Save reports to database or document management systems +- **Web Applications**: Power research functionality in web interfaces +- **Alerting Systems**: Schedule research on key topics and receive regular reports +- **Other ZenML Pipelines**: Chain with downstream analysis or processing + +## 📄 License + +This project is licensed under the Apache License 2.0. diff --git a/deep_research/__init__.py b/deep_research/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/deep_research/configs/balanced_research.yaml b/deep_research/configs/balanced_research.yaml new file mode 100644 index 00000000..4f8bfa23 --- /dev/null +++ b/deep_research/configs/balanced_research.yaml @@ -0,0 +1,79 @@ +# Deep Research Pipeline Configuration - Balanced Mode +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "balanced", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for balanced research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 10 # Balanced number of sub-questions + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + cap_search_length: 20000 # Standard cap for search length + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + ] # Standard viewpoints + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + timeout: 3600 # 1 hour timeout + max_queries: 2 # Moderate additional queries + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + cap_search_length: 20000 + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/compare_viewpoints.yaml b/deep_research/configs/compare_viewpoints.yaml new file mode 100644 index 00000000..26a59dd4 --- /dev/null +++ b/deep_research/configs/compare_viewpoints.yaml @@ -0,0 +1,43 @@ +# Deep Research Pipeline Configuration - Compare Viewpoints +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "viewpoints", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for comparing different viewpoints +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: ["scientific", "political", "economic", "social", "ethical", "historical"] + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/configs/daily_trends.yaml b/deep_research/configs/daily_trends.yaml new file mode 100644 index 00000000..0f2b6587 --- /dev/null +++ b/deep_research/configs/daily_trends.yaml @@ -0,0 +1,45 @@ +# Deep Research Pipeline Configuration - Daily Trends Research +enable_cache: false # Disable cache to always get fresh results for daily trends + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "daily_trends", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for daily trending topics +parameters: + query: "Latest developments in artificial intelligence" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/deep_research.yaml b/deep_research/configs/deep_research.yaml new file mode 100644 index 00000000..61cc4c2b --- /dev/null +++ b/deep_research/configs/deep_research.yaml @@ -0,0 +1,81 @@ +# Deep Research Pipeline Configuration - Deep Comprehensive Mode +enable_cache: false # Disable cache for fresh comprehensive analysis + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "deep", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for deep comprehensive research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 15 # Maximum sub-questions for comprehensive analysis + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + cap_search_length: 30000 # Higher cap for more comprehensive data + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + "technological", + "philosophical", + ] # Extended viewpoints for comprehensive analysis + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + timeout: 7200 # 2 hour timeout for deep research + max_queries: 4 # Maximum additional queries for deep mode + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + cap_search_length: 30000 + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/enhanced_research.yaml b/deep_research/configs/enhanced_research.yaml new file mode 100644 index 00000000..1933efa9 --- /dev/null +++ b/deep_research/configs/enhanced_research.yaml @@ -0,0 +1,71 @@ +# Enhanced Deep Research Pipeline Configuration +enable_cache: false + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "enhanced", + ] + use_cases: "Research on a given query." 
+ +# Research query parameters +query: "Climate change policy debates" + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Step configurations +steps: + initial_query_decomposition_step: + parameters: + llm_model: "openrouter/google/gemini-2.5-flash-preview-05-20" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "openrouter/google/gemini-2.5-flash-preview-05-20" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + ] + + generate_reflection_step: + parameters: + llm_model: "openrouter/google/gemini-2.5-flash-preview-05-20" + + get_research_approval_step: + parameters: + timeout: 3600 + max_queries: 2 + + execute_approved_searches_step: + parameters: + llm_model: "openrouter/google/gemini-2.5-flash-preview-05-20" + + pydantic_final_report_step: + parameters: + llm_model: "openrouter/google/gemini-2.5-flash-preview-05-20" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/configs/enhanced_research_with_approval.yaml b/deep_research/configs/enhanced_research_with_approval.yaml new file mode 100644 index 00000000..73d6fe42 --- /dev/null +++ b/deep_research/configs/enhanced_research_with_approval.yaml @@ -0,0 +1,77 @@ +# Enhanced Deep Research Pipeline Configuration with Human Approval +enable_cache: false + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "enhanced_approval", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research query parameters +query: "Climate change policy debates" + +# Pipeline parameters +parameters: + require_approval: true # Enable human-in-the-loop approval + approval_timeout: 1800 # 30 minutes timeout for approval + max_additional_searches: 3 # Allow up to 3 additional searches + +# Step configurations +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: + [ + "scientific", + "political", + "economic", + "social", + "ethical", + "historical", + ] + + # New reflection steps (replacing iterative_reflection_step) + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + alerter_type: "slack" # or "email" if configured + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/parallel_research.yaml b/deep_research/configs/parallel_research.yaml new file mode 100644 index 00000000..25ea6df0 --- /dev/null +++ b/deep_research/configs/parallel_research.yaml @@ -0,0 +1,93 @@ +# Deep Research Pipeline Configuration - Parallelized Version +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "parallel", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Pipeline parameters +parameters: + query: "How are people balancing MLOps and Agents/LLMOps?" 
+ max_sub_questions: 10 + +# Step parameters +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + cap_search_length: 20000 + + merge_sub_question_results_step: + parameters: + step_prefix: "process_question_" + output_name: "output" + + cross_viewpoint_analysis_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + viewpoint_categories: ["scientific", "political", "economic", "social", "ethical", "historical"] + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + timeout: 3600 + max_queries: 2 + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + cap_search_length: 20000 + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 + + # Uncomment and customize these settings when running with orchestrators that support parallelization + # orchestrator.kubeflow: + # synchronous: false + # resources: + # cpu_request: "1" + # memory_request: "2Gi" + # cpu_limit: "2" + # memory_limit: "4Gi" + + # orchestrator.kubernetes: + # synchronous: false + # resources: + # process_sub_question_step: + # cpu_request: "1" + # memory_request: "2Gi" \ No newline at end of file diff --git a/deep_research/configs/pipeline_config.yaml b/deep_research/configs/pipeline_config.yaml new file mode 100644 index 00000000..84e2520b --- /dev/null +++ b/deep_research/configs/pipeline_config.yaml @@ -0,0 +1,58 @@ +# Deep Research Pipeline Configuration +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "pipeline", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters +parameters: + query: "Default research query" # The research query/topic to investigate + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + auto_approve: false + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 \ No newline at end of file diff --git a/deep_research/configs/quick_research.yaml b/deep_research/configs/quick_research.yaml new file mode 100644 index 00000000..b210f18f --- /dev/null +++ b/deep_research/configs/quick_research.yaml @@ -0,0 +1,59 @@ +# Deep Research Pipeline Configuration - Quick Research +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "quick", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for quick research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 5 # Limit to fewer sub-questions for quick research + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + auto_approve: true # Auto-approve for quick research + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/configs/rapid_research.yaml b/deep_research/configs/rapid_research.yaml new file mode 100644 index 00000000..e69982bf --- /dev/null +++ b/deep_research/configs/rapid_research.yaml @@ -0,0 +1,59 @@ +# Deep Research Pipeline Configuration - Quick Research +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "rapid", + ] + use_cases: "Research on a given query." 
+ +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for quick research +parameters: + query: "Default research query" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 5 # Limit to fewer sub-questions for quick research + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + auto_approve: true # Auto-approve for quick research + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/configs/thorough_research.yaml b/deep_research/configs/thorough_research.yaml new file mode 100644 index 00000000..a798577e --- /dev/null +++ b/deep_research/configs/thorough_research.yaml @@ -0,0 +1,63 @@ +# Deep Research Pipeline Configuration - Thorough Research +enable_cache: true + +# ZenML MCP +model: + name: "deep_research" + description: "Parallelized ZenML pipelines for deep research on a given query." + tags: + [ + "research", + "exa", + "tavily", + "openrouter", + "sambanova", + "langfuse", + "thorough", + ] + use_cases: "Research on a given query." + +# Langfuse project name for LLM tracking +langfuse_project_name: "deep-research" + +# Research parameters for more thorough research +parameters: + query: "Quantum computing applications in cryptography" + +steps: + initial_query_decomposition_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + max_sub_questions: 12 # More sub-questions for thorough analysis + + process_sub_question_step: + parameters: + llm_model_search: "sambanova/Meta-Llama-3.3-70B-Instruct" + llm_model_synthesis: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + generate_reflection_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + + get_research_approval_step: + parameters: + auto_approve: false + + execute_approved_searches_step: + parameters: + llm_model: "sambanova/Meta-Llama-3.3-70B-Instruct" + + pydantic_final_report_step: + parameters: + llm_model: "sambanova/DeepSeek-R1-Distill-Llama-70B" + +# Environment settings +settings: + docker: + requirements: + - openai>=1.0.0 + - tavily-python>=0.2.8 + - PyYAML>=6.0 + - click>=8.0.0 + - pydantic>=2.0.0 + - typing_extensions>=4.0.0 diff --git a/deep_research/logging_config.py b/deep_research/logging_config.py new file mode 100644 index 00000000..2b93c3e0 --- /dev/null +++ b/deep_research/logging_config.py @@ -0,0 +1,42 @@ +import logging +import sys +from typing import Optional + + +def configure_logging( + level: int = logging.INFO, log_file: Optional[str] = None +): + """Configure logging for the application. 
+ + Args: + level: The log level (default: INFO) + log_file: Optional path to a log file + """ + # Create formatter + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + # Configure root logger + root_logger = logging.getLogger() + root_logger.setLevel(level) + + # Remove existing handlers to avoid duplicate logs + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(formatter) + root_logger.addHandler(console_handler) + + # File handler if log_file is provided + if log_file: + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + root_logger.addHandler(file_handler) + + # Reduce verbosity for noisy third-party libraries + logging.getLogger("LiteLLM").setLevel(logging.WARNING) + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("urllib3").setLevel(logging.WARNING) diff --git a/deep_research/materializers/__init__.py b/deep_research/materializers/__init__.py new file mode 100644 index 00000000..1479f72b --- /dev/null +++ b/deep_research/materializers/__init__.py @@ -0,0 +1,21 @@ +""" +Materializers package for the ZenML Deep Research project. + +This package contains custom ZenML materializers that handle serialization and +deserialization of complex data types used in the research pipeline, particularly +the ResearchState object that tracks the state of the research process. +""" + +from .approval_decision_materializer import ApprovalDecisionMaterializer +from .prompts_materializer import PromptsBundleMaterializer +from .pydantic_materializer import ResearchStateMaterializer +from .reflection_output_materializer import ReflectionOutputMaterializer +from .tracing_metadata_materializer import TracingMetadataMaterializer + +__all__ = [ + "ApprovalDecisionMaterializer", + "PromptsBundleMaterializer", + "ReflectionOutputMaterializer", + "ResearchStateMaterializer", + "TracingMetadataMaterializer", +] diff --git a/deep_research/materializers/approval_decision_materializer.py b/deep_research/materializers/approval_decision_materializer.py new file mode 100644 index 00000000..0a4ed2c3 --- /dev/null +++ b/deep_research/materializers/approval_decision_materializer.py @@ -0,0 +1,281 @@ +"""Materializer for ApprovalDecision with custom visualization.""" + +import os +from datetime import datetime +from typing import Dict + +from utils.pydantic_models import ApprovalDecision +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class ApprovalDecisionMaterializer(PydanticMaterializer): + """Materializer for the ApprovalDecision class with visualizations.""" + + ASSOCIATED_TYPES = (ApprovalDecision,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: ApprovalDecision + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the ApprovalDecision. 
+ + Args: + data: The ApprovalDecision to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "approval_decision.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, decision: ApprovalDecision) -> str: + """Generate HTML visualization for the approval decision. + + Args: + decision: The ApprovalDecision to visualize + + Returns: + HTML string + """ + # Format timestamp + decision_time = datetime.fromtimestamp(decision.timestamp).strftime( + "%Y-%m-%d %H:%M:%S" + ) + + # Determine status color and icon + if decision.approved: + status_color = "#27ae60" + status_icon = "✅" + status_text = "APPROVED" + else: + status_color = "#e74c3c" + status_icon = "❌" + status_text = "NOT APPROVED" + + # Format approval method + method_display = { + "APPROVE_ALL": "Approve All Queries", + "SKIP": "Skip Additional Research", + "SELECT_SPECIFIC": "Select Specific Queries", + }.get(decision.approval_method, decision.approval_method or "Unknown") + + html = f""" + + + + Approval Decision + + + +
+

+ 🔒 Approval Decision +
+ {status_icon} + {status_text} +
+

+ +
+
+
Approval Method
+
{method_display}
+
+
+
Decision Time
+
{decision_time}
+
+
+
Queries Selected
+
{len(decision.selected_queries)}
+
+
+ """ + + # Add selected queries section if any + if decision.selected_queries: + html += """ +
+

📋 Selected Queries

+
+ """ + + for i, query in enumerate(decision.selected_queries, 1): + html += f""" +
+
{i}
+
{query}
+
+ """ + + html += """ +
+
+ """ + else: + html += """ +
+

📋 Selected Queries

+
+ No queries were selected for additional research +
+
+ """ + + # Add reviewer notes if any + if decision.reviewer_notes: + html += f""" +
+

📝 Reviewer Notes

+
+ {decision.reviewer_notes} +
+
+ """ + + # Add timestamp footer + html += f""" +
+ Decision recorded at: {decision_time} +
+
+ + + """ + + return html diff --git a/deep_research/materializers/prompts_materializer.py b/deep_research/materializers/prompts_materializer.py new file mode 100644 index 00000000..96e35dde --- /dev/null +++ b/deep_research/materializers/prompts_materializer.py @@ -0,0 +1,509 @@ +"""Materializer for PromptsBundle with custom HTML visualization. + +This module provides a materializer that creates beautiful HTML visualizations +for prompt bundles in the ZenML dashboard. +""" + +import os +from typing import Dict + +from utils.prompt_models import PromptsBundle +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class PromptsBundleMaterializer(PydanticMaterializer): + """Materializer for PromptsBundle with custom visualization.""" + + ASSOCIATED_TYPES = (PromptsBundle,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: PromptsBundle + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the PromptsBundle. + + Args: + data: The PromptsBundle to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "prompts_bundle.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, bundle: PromptsBundle) -> str: + """Generate HTML visualization for the prompts bundle. + + Args: + bundle: The PromptsBundle to visualize + + Returns: + HTML string + """ + # Create HTML content + html = f""" + + + + Prompts Bundle - {bundle.created_at} + + + +
+
+

🎯 Prompts Bundle

+ +
+
+ {len(bundle.list_all_prompts())} + Total Prompts +
+
+ {len([p for p in bundle.list_all_prompts().values() if p.tags])} + Tagged Prompts +
+
+ {len(bundle.custom_prompts)} + Custom Prompts +
+
+
+ + + +
+ """ + + # Add each prompt + prompts = [ + ( + "search_query_prompt", + bundle.search_query_prompt, + ["search", "query"], + ), + ( + "query_decomposition_prompt", + bundle.query_decomposition_prompt, + ["analysis", "decomposition"], + ), + ( + "synthesis_prompt", + bundle.synthesis_prompt, + ["synthesis", "integration"], + ), + ( + "viewpoint_analysis_prompt", + bundle.viewpoint_analysis_prompt, + ["analysis", "viewpoint"], + ), + ( + "reflection_prompt", + bundle.reflection_prompt, + ["reflection", "critique"], + ), + ( + "additional_synthesis_prompt", + bundle.additional_synthesis_prompt, + ["synthesis", "enhancement"], + ), + ( + "conclusion_generation_prompt", + bundle.conclusion_generation_prompt, + ["report", "conclusion"], + ), + ] + + for prompt_type, prompt, default_tags in prompts: + # Use provided tags or default tags + tags = prompt.tags if prompt.tags else default_tags + + html += f""" +
+
+

{prompt.name}

+ v{prompt.version} +
+ {f'

{prompt.description}

' if prompt.description else ""} +
+ {"".join([f'{tag}' for tag in tags])} +
+
+ + {self._escape_html(prompt.content)} +
+ +
+ """ + + # Add custom prompts if any + for name, prompt in bundle.custom_prompts.items(): + tags = prompt.tags if prompt.tags else ["custom"] + html += f""" +
+
+

{prompt.name}

+ v{prompt.version} +
+ {f'

{prompt.description}

' if prompt.description else ""} +
+ {"".join([f'{tag}' for tag in tags])} +
+
+ + {self._escape_html(prompt.content)} +
+ +
+ """ + + html += """ +
+ +
+ + + + + """ + + return html + + def _escape_html(self, text: str) -> str: + """Escape HTML special characters. + + Args: + text: Text to escape + + Returns: + Escaped text + """ + return ( + text.replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'") + ) diff --git a/deep_research/materializers/pydantic_materializer.py b/deep_research/materializers/pydantic_materializer.py new file mode 100644 index 00000000..ee01281b --- /dev/null +++ b/deep_research/materializers/pydantic_materializer.py @@ -0,0 +1,764 @@ +"""Pydantic materializer for research state objects. + +This module contains an extended version of ZenML's PydanticMaterializer +that adds visualization capabilities for the ResearchState model. +""" + +import os +from typing import Dict + +from utils.pydantic_models import ResearchState +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class ResearchStateMaterializer(PydanticMaterializer): + """Materializer for the ResearchState class with visualizations.""" + + ASSOCIATED_TYPES = (ResearchState,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: ResearchState + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the ResearchState. + + Args: + data: The ResearchState to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "research_state.html") + + # Create HTML content based on current stage + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, state: ResearchState) -> str: + """Generate HTML visualization for the research state. + + Args: + state: The ResearchState to visualize + + Returns: + HTML string + """ + # Base structure for the HTML + html = f""" + + + + Research State: {state.main_query} + + + + +
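The entity strings inside _escape_html above do not display correctly in this patch; the intended behaviour is standard HTML entity escaping of &, <, >, double and single quotes. A sketch of an equivalent helper built on Python's standard library, assuming that same behaviour is what is wanted:

    import html


    def escape_html(text: str) -> str:
        """Escape &, <, >, and both quote styles for safe HTML embedding."""
        # html.escape replaces & first, then < and >, and with quote=True
        # (the default) also double and single quotes.
        return html.escape(text, quote=True)


    assert escape_html('<b>"AT&T"</b>') == "&lt;b&gt;&quot;AT&amp;T&quot;&lt;/b&gt;"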
+

Research State

+ + +
+
Initial Query
+
Query Decomposition
+
Information Gathering
+
Information Synthesis
+
Viewpoint Analysis
+
Reflection & Enhancement
+
Final Report
+
+ + +
+
+
+ + +
    + """ + + # Determine which tab should be active based on current stage + current_stage = state.get_current_stage() + + # Map stages to tabs + stage_to_tab = { + "empty": "overview", + "initial": "overview", + "after_query_decomposition": "sub-questions", + "after_search": "search-results", + "after_synthesis": "synthesis", + "after_viewpoint_analysis": "viewpoints", + "after_reflection": "reflection", + "final_report": "final-report", + } + + # Get the default active tab based on stage + default_active_tab = stage_to_tab.get(current_stage, "overview") + + # Create tab headers dynamically based on available data + tabs_created = [] + + # Overview tab is always shown + is_active = default_active_tab == "overview" + html += f'
  • Overview
  • ' + tabs_created.append("overview") + + if state.sub_questions: + is_active = default_active_tab == "sub-questions" + html += f'
  • Sub-Questions
  • ' + tabs_created.append("sub-questions") + + if state.search_results: + is_active = default_active_tab == "search-results" + html += f'
  • Search Results
  • ' + tabs_created.append("search-results") + + if state.synthesized_info: + is_active = default_active_tab == "synthesis" + html += f'
  • Synthesis
  • ' + tabs_created.append("synthesis") + + if state.viewpoint_analysis: + is_active = default_active_tab == "viewpoints" + html += f'
  • Viewpoints
  • ' + tabs_created.append("viewpoints") + + if state.enhanced_info or state.reflection_metadata: + is_active = default_active_tab == "reflection" + html += f'
  • Reflection
  • ' + tabs_created.append("reflection") + + if state.final_report_html: + is_active = default_active_tab == "final-report" + html += f'
  • Final Report
  • ' + tabs_created.append("final-report") + + # Ensure the active tab actually exists in the created tabs + # If not, fallback to the first available tab + if default_active_tab not in tabs_created and tabs_created: + default_active_tab = tabs_created[0] + + html += """ +
+ + + """ + + # Overview tab content (always shown) + is_active = default_active_tab == "overview" + html += f""" +
+
+

Main Query

+
+ """ + + if state.main_query: + html += f"

{state.main_query}

" + else: + html += "

No main query specified

" + + html += """ +
+
+
+ """ + + # Sub-questions tab content + if state.sub_questions: + is_active = default_active_tab == "sub-questions" + html += f""" +
+
+

Sub-Questions ({len(state.sub_questions)})

+
+ """ + + for i, question in enumerate(state.sub_questions): + html += f""" +
+ {i + 1}. {question} +
+ """ + + html += """ +
+
+
+ """ + + # Search results tab content + if state.search_results: + is_active = default_active_tab == "search-results" + html += f""" +
+
+

Search Results

+ """ + + for question, results in state.search_results.items(): + html += f""" +

{question}

+

Found {len(results)} results

+
    + """ + + for result in results: + # Extract domain from URL or use special handling for generated content + if result.url == "tavily-generated-answer": + domain = "Tavily" + else: + domain = "" + try: + from urllib.parse import urlparse + + parsed_url = urlparse(result.url) + domain = parsed_url.netloc + # Strip www. prefix to save space + if domain.startswith("www."): + domain = domain[4:] + except: + domain = ( + result.url.split("/")[2] + if len(result.url.split("/")) > 2 + else "" + ) + # Strip www. prefix to save space + if domain.startswith("www."): + domain = domain[4:] + + html += f""" +
  • + {result.title} ({domain}) +
  • + """ + + html += """ +
+ """ + + html += """ +
+
+ """ + + # Synthesized information tab content + if state.synthesized_info: + is_active = default_active_tab == "synthesis" + html += f""" +
+
+

Synthesized Information

+ """ + + for question, info in state.synthesized_info.items(): + html += f""" +

{question} {info.confidence_level}

+
+

{info.synthesized_answer}

+ """ + + if info.key_sources: + html += """ +
+

Key Sources:

+
    + """ + + for source in info.key_sources[:3]: + html += f""" +
  • {source[:50]}...
  • + """ + + if len(info.key_sources) > 3: + html += f"
  • ...and {len(info.key_sources) - 3} more sources
  • " + + html += """ +
+
+ """ + + if info.information_gaps: + html += f""" + + """ + + html += """ +
+ """ + + html += """ +
+
+ """ + + # Viewpoint analysis tab content + if state.viewpoint_analysis: + is_active = default_active_tab == "viewpoints" + html += f""" +
+
+

Viewpoint Analysis

+
+ """ + + # Points of agreement + if state.viewpoint_analysis.main_points_of_agreement: + html += """ +

Points of Agreement

+
    + """ + + for point in state.viewpoint_analysis.main_points_of_agreement: + html += f""" +
  • {point}
  • + """ + + html += """ +
+ """ + + # Areas of tension + if state.viewpoint_analysis.areas_of_tension: + html += """ +

Areas of Tension

+ """ + + for tension in state.viewpoint_analysis.areas_of_tension: + html += f""" +
+

{tension.topic}

+
    + """ + + for viewpoint, description in tension.viewpoints.items(): + html += f""" +
  • {viewpoint}: {description}
  • + """ + + html += """ +
+
+ """ + + # Perspective gaps and integrative insights + if state.viewpoint_analysis.perspective_gaps: + html += f""" +

Perspective Gaps

+

{state.viewpoint_analysis.perspective_gaps}

+ """ + + if state.viewpoint_analysis.integrative_insights: + html += f""" +

Integrative Insights

+

{state.viewpoint_analysis.integrative_insights}

+ """ + + html += """ +
+
+
+ """ + + # Reflection & Enhancement tab content + if state.enhanced_info or state.reflection_metadata: + is_active = default_active_tab == "reflection" + html += f""" +
+
+

Reflection & Enhancement

+ """ + + # Reflection metadata + if state.reflection_metadata: + html += """ +
+ """ + + if state.reflection_metadata.critique_summary: + html += """ +

Critique Summary

+
    + """ + + for critique in state.reflection_metadata.critique_summary: + html += f""" +
  • {critique}
  • + """ + + html += """ +
+ """ + + if state.reflection_metadata.additional_questions_identified: + html += """ +

Additional Questions Identified

+
    + """ + + for question in state.reflection_metadata.additional_questions_identified: + html += f""" +
  • {question}
  • + """ + + html += """ +
+ """ + + html += f""" + +
+ """ + + # Enhanced information + if state.enhanced_info: + html += """ +

Enhanced Information

+ """ + + for question, info in state.enhanced_info.items(): + # Show only for questions with improvements + if info.improvements: + html += f""" +
+

{question} {info.confidence_level}

+ +
+

Improvements Made:

+
    + """ + + for improvement in info.improvements: + html += f""" +
  • {improvement}
  • + """ + + html += """ +
+
+
+ """ + + html += """ +
+
+ """ + + # Final report tab + if state.final_report_html: + is_active = default_active_tab == "final-report" + html += f""" +
+
+

Final Report

+

Final HTML report is available but not displayed here. View the HTML artifact to see the complete report.

+
+
+ """ + + # Close HTML tags + html += """ +
+ + + """ + + return html + + def _get_stage_class(self, state: ResearchState, stage: str) -> str: + """Get CSS class for a stage based on current progress. + + Args: + state: ResearchState object + stage: Stage name + + Returns: + CSS class string + """ + current_stage = state.get_current_stage() + + # These are the stages in order + stages = [ + "empty", + "initial", + "after_query_decomposition", + "after_search", + "after_synthesis", + "after_viewpoint_analysis", + "after_reflection", + "final_report", + ] + + current_index = ( + stages.index(current_stage) if current_stage in stages else 0 + ) + stage_index = stages.index(stage) if stage in stages else 0 + + if stage_index == current_index: + return "active" + elif stage_index < current_index: + return "completed" + else: + return "" + + def _calculate_progress(self, state: ResearchState) -> int: + """Calculate overall progress percentage. + + Args: + state: ResearchState object + + Returns: + Progress percentage (0-100) + """ + # Map stages to progress percentages + stage_percentages = { + "empty": 0, + "initial": 5, + "after_query_decomposition": 20, + "after_search": 40, + "after_synthesis": 60, + "after_viewpoint_analysis": 75, + "after_reflection": 90, + "final_report": 100, + } + + current_stage = state.get_current_stage() + return stage_percentages.get(current_stage, 0) diff --git a/deep_research/materializers/reflection_output_materializer.py b/deep_research/materializers/reflection_output_materializer.py new file mode 100644 index 00000000..1e8b37ae --- /dev/null +++ b/deep_research/materializers/reflection_output_materializer.py @@ -0,0 +1,279 @@ +"""Materializer for ReflectionOutput with custom visualization.""" + +import os +from typing import Dict + +from utils.pydantic_models import ReflectionOutput +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class ReflectionOutputMaterializer(PydanticMaterializer): + """Materializer for the ReflectionOutput class with visualizations.""" + + ASSOCIATED_TYPES = (ReflectionOutput,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: ReflectionOutput + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the ReflectionOutput. + + Args: + data: The ReflectionOutput to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "reflection_output.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, output: ReflectionOutput) -> str: + """Generate HTML visualization for the reflection output. + + Args: + output: The ReflectionOutput to visualize + + Returns: + HTML string + """ + html = f""" + + + + Reflection Output + + + +
+

🔍 Reflection & Analysis Output

+ + + +
+

+ 📝Critique Summary + {} +

+ """.format(len(output.critique_summary)) + + if output.critique_summary: + for critique in output.critique_summary: + html += """ +
+ """ + + # Handle different critique formats + if isinstance(critique, dict): + for key, value in critique.items(): + html += f""" +
{key}:
+
{value}
+ """ + else: + html += f""" +
{critique}
+ """ + + html += """ +
+ """ + else: + html += """ +

No critique summary available

+ """ + + html += """ +
+ +
+

+ Additional Questions Identified + {} +

+ """.format(len(output.additional_questions)) + + if output.additional_questions: + for question in output.additional_questions: + html += f""" +
+ {question} +
+ """ + else: + html += """ +

No additional questions identified

+ """ + + html += """ +
+ +
+

📊Research State Summary

+
+

Main Query: {}

+

Current Stage: {}

+

Sub-questions: {}

+

Search Results: {} queries with results

+

Synthesized Info: {} topics synthesized

+
+
+ """.format( + output.state.main_query, + output.state.get_current_stage().replace("_", " ").title(), + len(output.state.sub_questions), + len(output.state.search_results), + len(output.state.synthesized_info), + ) + + # Add metadata section + html += """ + +
+ + + """ + + return html diff --git a/deep_research/materializers/tracing_metadata_materializer.py b/deep_research/materializers/tracing_metadata_materializer.py new file mode 100644 index 00000000..7cf7b51b --- /dev/null +++ b/deep_research/materializers/tracing_metadata_materializer.py @@ -0,0 +1,603 @@ +"""Materializer for TracingMetadata with custom visualization.""" + +import os +from typing import Dict + +from utils.pydantic_models import TracingMetadata +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class TracingMetadataMaterializer(PydanticMaterializer): + """Materializer for the TracingMetadata class with visualizations.""" + + ASSOCIATED_TYPES = (TracingMetadata,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: TracingMetadata + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the TracingMetadata. + + Args: + data: The TracingMetadata to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "tracing_metadata.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, metadata: TracingMetadata) -> str: + """Generate HTML visualization for the tracing metadata. + + Args: + metadata: The TracingMetadata to visualize + + Returns: + HTML string + """ + # Calculate some derived values + avg_cost_per_token = metadata.total_cost / max( + metadata.total_tokens, 1 + ) + + # Base structure for the HTML + html = f""" + + + + Pipeline Tracing Metadata + + + +
+

Pipeline Tracing Metadata

+ +
+
+
Pipeline Run
+
{metadata.pipeline_run_name}
+
+ +
+
LLM Cost
+
${metadata.total_cost:.4f}
+
+ +
+
Total Tokens
+
{metadata.total_tokens:,}
+
+ +
+
Duration
+
{metadata.formatted_latency}
+
+
+ +

Token Usage

+
+
+
Input Tokens
+
{metadata.total_input_tokens:,}
+
+ +
+
Output Tokens
+
{metadata.total_output_tokens:,}
+
+ +
+
Observations
+
{metadata.observation_count}
+
+ +
+
Avg Cost per Token
+
${avg_cost_per_token:.6f}
+
+
+ +

Model Usage Breakdown

+ + + + + + + + + + + + """ + + # Add model breakdown + for model in metadata.models_used: + tokens = metadata.model_token_breakdown.get(model, {}) + cost = metadata.cost_breakdown_by_model.get(model, 0.0) + html += f""" + + + + + + + + """ + + html += """ + +
ModelInput TokensOutput TokensTotal TokensCost
{model}{tokens.get("input_tokens", 0):,}{tokens.get("output_tokens", 0):,}{tokens.get("total_tokens", 0):,}${cost:.4f}
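The table rows above are driven by metadata.models_used together with the model_token_breakdown and cost_breakdown_by_model dictionaries. A purely illustrative example of the shapes those fields are expected to hold (the model name matches the default used elsewhere in this commit; the numbers are invented):

    # Hypothetical contents; real values come from the Langfuse observations
    # aggregated in collect_tracing_metadata_step.
    models_used = ["sambanova/DeepSeek-R1-Distill-Llama-70B"]

    model_token_breakdown = {
        "sambanova/DeepSeek-R1-Distill-Llama-70B": {
            "input_tokens": 52_340,
            "output_tokens": 18_911,
            "total_tokens": 71_251,
        }
    }

    cost_breakdown_by_model = {
        "sambanova/DeepSeek-R1-Distill-Llama-70B": 0.4312,
    }

    for model in models_used:
        tokens = model_token_breakdown.get(model, {})
        cost = cost_breakdown_by_model.get(model, 0.0)
        print(model, tokens.get("total_tokens", 0), f"${cost:.4f}")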
+ """ + + # Add prompt-level metrics visualization + if metadata.prompt_metrics: + html += """ +

Cost Analysis by Prompt Type

+ + + + + +
+ +
+ + +
+ +
+ + +

Prompt Type Efficiency

+ + + + + + + + + + + + + + """ + + # Add prompt metrics rows + for metric in metadata.prompt_metrics: + # Format prompt type name nicely + prompt_type_display = metric.prompt_type.replace( + "_", " " + ).title() + html += f""" + + + + + + + + + + """ + + html += """ + +
Prompt TypeTotal CostCallsAvg $/Call% of TotalInput TokensOutput Tokens
{prompt_type_display}${metric.total_cost:.4f}{metric.call_count}${metric.avg_cost_per_call:.4f}{metric.percentage_of_total_cost:.1f}%{metric.input_tokens:,}{metric.output_tokens:,}
+ + + """ + + # Add search cost visualization if available + if metadata.search_costs and any(metadata.search_costs.values()): + total_search_cost = sum(metadata.search_costs.values()) + total_combined_cost = metadata.total_cost + total_search_cost + + html += """ +

Search Provider Costs

+
+ """ + + for provider, cost in metadata.search_costs.items(): + if cost > 0: + query_count = metadata.search_queries_count.get( + provider, 0 + ) + avg_cost_per_query = ( + cost / query_count if query_count > 0 else 0 + ) + html += f""" +
+
{provider.upper()} Search
+
${cost:.4f}
+
+ {query_count} queries • ${avg_cost_per_query:.4f}/query +
+
+ """ + + html += ( + f""" +
+
Total Search Cost
+
${total_search_cost:.4f}
+
+ {sum(metadata.search_queries_count.values())} total queries +
+
+
+ +

Combined Cost Summary

+
+
+
LLM Cost
+
${metadata.total_cost:.4f}
+
+ {(metadata.total_cost / total_combined_cost * 100):.1f}% of total +
+
+
+
Search Cost
+
${total_search_cost:.4f}
+
+ {(total_search_cost / total_combined_cost * 100):.1f}% of total +
+
+
+
Total Pipeline Cost
+
${total_combined_cost:.4f}
+
+
+ +

Cost Breakdown Chart

+ + + """ + ) + + # Add trace metadata + if metadata.trace_tags or metadata.trace_metadata: + html += """ +

Trace Information

+
+ """ + + if metadata.trace_tags: + html += """ +

Tags

+
+ """ + for tag in metadata.trace_tags: + html += f'{tag}' + html += """ +
+ """ + + if metadata.trace_metadata: + html += """ +

Metadata

+
+                """
+                import json
+
+                html += json.dumps(metadata.trace_metadata, indent=2)
+                html += """
+                    
+ """ + + html += """ +
+ """ + + # Add footer with collection info + from datetime import datetime + + collection_time = datetime.fromtimestamp( + metadata.collected_at + ).strftime("%Y-%m-%d %H:%M:%S") + + html += f""" + +
+ + + """ + + return html diff --git a/deep_research/pipelines/__init__.py b/deep_research/pipelines/__init__.py new file mode 100644 index 00000000..7f4ea5eb --- /dev/null +++ b/deep_research/pipelines/__init__.py @@ -0,0 +1,11 @@ +""" +Pipelines package for the ZenML Deep Research project. + +This package contains the ZenML pipeline definitions for running deep research +workflows. Each pipeline orchestrates a sequence of steps for comprehensive +research on a given query topic. +""" + +from .parallel_research_pipeline import parallelized_deep_research_pipeline + +__all__ = ["parallelized_deep_research_pipeline"] diff --git a/deep_research/pipelines/parallel_research_pipeline.py b/deep_research/pipelines/parallel_research_pipeline.py new file mode 100644 index 00000000..bd7afe14 --- /dev/null +++ b/deep_research/pipelines/parallel_research_pipeline.py @@ -0,0 +1,133 @@ +from steps.approval_step import get_research_approval_step +from steps.collect_tracing_metadata_step import collect_tracing_metadata_step +from steps.cross_viewpoint_step import cross_viewpoint_analysis_step +from steps.execute_approved_searches_step import execute_approved_searches_step +from steps.generate_reflection_step import generate_reflection_step +from steps.initialize_prompts_step import initialize_prompts_step +from steps.merge_results_step import merge_sub_question_results_step +from steps.process_sub_question_step import process_sub_question_step +from steps.pydantic_final_report_step import pydantic_final_report_step +from steps.query_decomposition_step import initial_query_decomposition_step +from utils.pydantic_models import ResearchState +from zenml import pipeline + + +@pipeline(enable_cache=False) +def parallelized_deep_research_pipeline( + query: str = "What is ZenML?", + max_sub_questions: int = 10, + require_approval: bool = False, + approval_timeout: int = 3600, + max_additional_searches: int = 2, + search_provider: str = "tavily", + search_mode: str = "auto", + num_results_per_search: int = 3, + langfuse_project_name: str = "deep-research", +) -> None: + """Parallelized ZenML pipeline for deep research on a given query. + + This pipeline uses the fan-out/fan-in pattern for parallel processing of sub-questions, + potentially improving execution time when using distributed orchestrators. 
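The fan-out/fan-in wiring described above boils down to launching one step invocation per index and then passing the collected outputs to the merge step via the after argument so it waits for every branch. A stripped-down sketch of that wiring with placeholder steps (do_work and merge are illustrative names, not steps from this commit):

    from zenml import pipeline, step


    @step
    def do_work(index: int) -> int:
        # Stand-in for process_sub_question_step: handle one branch.
        return index * index


    @step
    def merge(expected_branches: int) -> int:
        # Stand-in for merge_sub_question_results_step: in the real pipeline the
        # branch outputs are fetched from the run context rather than passed in.
        return expected_branches


    @pipeline(enable_cache=False)
    def fan_out_fan_in(branches: int = 3) -> None:
        after = []
        for i in range(branches):
            result = do_work(index=i, id=f"do_work_{i + 1}")  # fan out
            after.append(result)

        # Fan in: the after list makes this step wait for every branch above.
        merge(expected_branches=branches, after=after)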
+ + Args: + query: The research query/topic + max_sub_questions: Maximum number of sub-questions to process in parallel + require_approval: Whether to require human approval for additional searches + approval_timeout: Timeout in seconds for human approval + max_additional_searches: Maximum number of additional searches to perform + search_provider: Search provider to use (tavily, exa, or both) + search_mode: Search mode for Exa provider (neural, keyword, or auto) + num_results_per_search: Number of search results to return per query + langfuse_project_name: Langfuse project name for LLM tracking + + Returns: + Formatted research report as HTML + """ + # Initialize prompts bundle for tracking + prompts_bundle = initialize_prompts_step(pipeline_version="1.0.0") + + # Initialize the research state with the main query + state = ResearchState(main_query=query) + + # Step 1: Decompose the query into sub-questions, limiting to max_sub_questions + decomposed_state = initial_query_decomposition_step( + state=state, + prompts_bundle=prompts_bundle, + max_sub_questions=max_sub_questions, + langfuse_project_name=langfuse_project_name, + ) + + # Fan out: Process each sub-question in parallel + # Collect artifacts to establish dependencies for the merge step + after = [] + for i in range(max_sub_questions): + # Process the i-th sub-question (if it exists) + sub_state = process_sub_question_step( + state=decomposed_state, + prompts_bundle=prompts_bundle, + question_index=i, + search_provider=search_provider, + search_mode=search_mode, + num_results_per_search=num_results_per_search, + langfuse_project_name=langfuse_project_name, + id=f"process_question_{i + 1}", + ) + after.append(sub_state) + + # Fan in: Merge results from all parallel processing + # The 'after' parameter ensures this step runs after all processing steps + # It doesn't directly use the processed_states input + merged_state = merge_sub_question_results_step( + original_state=decomposed_state, + step_prefix="process_question_", + output_name="output", + after=after, # This creates the dependency + ) + + # Continue with subsequent steps + analyzed_state = cross_viewpoint_analysis_step( + state=merged_state, + prompts_bundle=prompts_bundle, + langfuse_project_name=langfuse_project_name, + ) + + # New 3-step reflection flow with optional human approval + # Step 1: Generate reflection and recommendations (no searches yet) + reflection_output = generate_reflection_step( + state=analyzed_state, + prompts_bundle=prompts_bundle, + langfuse_project_name=langfuse_project_name, + ) + + # Step 2: Get approval for recommended searches + approval_decision = get_research_approval_step( + reflection_output=reflection_output, + require_approval=require_approval, + timeout=approval_timeout, + max_queries=max_additional_searches, + ) + + # Step 3: Execute approved searches (if any) + reflected_state = execute_approved_searches_step( + reflection_output=reflection_output, + approval_decision=approval_decision, + prompts_bundle=prompts_bundle, + search_provider=search_provider, + search_mode=search_mode, + num_results_per_search=num_results_per_search, + langfuse_project_name=langfuse_project_name, + ) + + # Use our new Pydantic-based final report step + # This returns a tuple (state, html_report) + final_state, final_report = pydantic_final_report_step( + state=reflected_state, + prompts_bundle=prompts_bundle, + langfuse_project_name=langfuse_project_name, + ) + + # Collect tracing metadata for the entire pipeline run + _, tracing_metadata = 
collect_tracing_metadata_step( + state=final_state, + langfuse_project_name=langfuse_project_name, + ) diff --git a/deep_research/requirements.txt b/deep_research/requirements.txt new file mode 100644 index 00000000..6d2f9166 --- /dev/null +++ b/deep_research/requirements.txt @@ -0,0 +1,10 @@ +zenml>=0.82.0 +litellm>=1.70.0,<2.0.0 +tavily-python>=0.2.8 +exa-py>=1.0.0 +PyYAML>=6.0 +click>=8.0.0 +pydantic>=2.0.0 +typing_extensions>=4.0.0 +requests +langfuse>=2.0.0 diff --git a/deep_research/run.py b/deep_research/run.py new file mode 100644 index 00000000..aa1745c3 --- /dev/null +++ b/deep_research/run.py @@ -0,0 +1,330 @@ +import logging +import os + +import click +import yaml +from logging_config import configure_logging +from pipelines.parallel_research_pipeline import ( + parallelized_deep_research_pipeline, +) +from utils.helper_functions import check_required_env_vars + +logger = logging.getLogger(__name__) + + +# Research mode presets for easy configuration +RESEARCH_MODES = { + "rapid": { + "max_sub_questions": 5, + "num_results_per_search": 2, + "max_additional_searches": 0, + "description": "Quick research with minimal depth - great for getting a fast overview", + }, + "balanced": { + "max_sub_questions": 10, + "num_results_per_search": 3, + "max_additional_searches": 2, + "description": "Balanced research with moderate depth - ideal for most use cases", + }, + "deep": { + "max_sub_questions": 15, + "num_results_per_search": 5, + "max_additional_searches": 4, + "description": "Comprehensive research with maximum depth - for thorough analysis", + "suggest_approval": True, # Suggest using approval for deep mode + }, +} + + +@click.command( + help=""" +Deep Research Agent - ZenML Pipeline for Comprehensive Research + +Run a deep research pipeline that: +1. Generates a structured report outline +2. Researches each topic with web searches and LLM analysis +3. Refines content through multiple reflection cycles +4. 
Produces a formatted, comprehensive research report + +Examples: + + \b + # Run with default configuration + python run.py + + \b + # Use a research mode preset for easy configuration + python run.py --mode rapid # Quick overview + python run.py --mode balanced # Standard research (default) + python run.py --mode deep # Comprehensive analysis + + \b + # Run with a custom pipeline configuration file + python run.py --config configs/custom_pipeline.yaml + + \b + # Override the research query + python run.py --query "My research topic" + + \b + # Combine mode with other options + python run.py --mode deep --query "Complex topic" --require-approval + + \b + # Run with a custom number of sub-questions + python run.py --max-sub-questions 15 +""" +) +@click.option( + "--mode", + type=click.Choice(["rapid", "balanced", "deep"], case_sensitive=False), + default=None, + help="Research mode preset: rapid (fast overview), balanced (standard), or deep (comprehensive)", +) +@click.option( + "--config", + type=str, + default="configs/enhanced_research.yaml", + help="Path to the pipeline configuration YAML file", +) +@click.option( + "--no-cache", + is_flag=True, + default=False, + help="Disable caching for the pipeline run", +) +@click.option( + "--log-file", + type=str, + default=None, + help="Path to log file (if not provided, logs only go to console)", +) +@click.option( + "--debug", + is_flag=True, + default=False, + help="Enable debug logging", +) +@click.option( + "--query", + type=str, + default=None, + help="Research query (overrides the query in the config file)", +) +@click.option( + "--max-sub-questions", + type=int, + default=10, + help="Maximum number of sub-questions to process in parallel", +) +@click.option( + "--require-approval", + is_flag=True, + default=False, + help="Enable human-in-the-loop approval for additional searches", +) +@click.option( + "--approval-timeout", + type=int, + default=3600, + help="Timeout in seconds for human approval (default: 3600)", +) +@click.option( + "--search-provider", + type=click.Choice(["tavily", "exa", "both"], case_sensitive=False), + default=None, + help="Search provider to use: tavily (default), exa, or both", +) +@click.option( + "--search-mode", + type=click.Choice(["neural", "keyword", "auto"], case_sensitive=False), + default="auto", + help="Search mode for Exa provider: neural, keyword, or auto (default: auto)", +) +@click.option( + "--num-results", + type=int, + default=3, + help="Number of search results to return per query (default: 3)", +) +def main( + mode: str = None, + config: str = "configs/enhanced_research.yaml", + no_cache: bool = False, + log_file: str = None, + debug: bool = False, + query: str = None, + max_sub_questions: int = 10, + require_approval: bool = False, + approval_timeout: int = 3600, + search_provider: str = None, + search_mode: str = "auto", + num_results: int = 3, +): + """Run the deep research pipeline. 
+ + Args: + mode: Research mode preset (rapid, balanced, or deep) + config: Path to the pipeline configuration YAML file + no_cache: Disable caching for the pipeline run + log_file: Path to log file + debug: Enable debug logging + query: Research query (overrides the query in the config file) + max_sub_questions: Maximum number of sub-questions to process in parallel + require_approval: Enable human-in-the-loop approval for additional searches + approval_timeout: Timeout in seconds for human approval + search_provider: Search provider to use (tavily, exa, or both) + search_mode: Search mode for Exa provider (neural, keyword, or auto) + num_results: Number of search results to return per query + """ + # Configure logging + log_level = logging.DEBUG if debug else logging.INFO + configure_logging(level=log_level, log_file=log_file) + + # Apply mode presets if specified + if mode: + mode_config = RESEARCH_MODES[mode.lower()] + logger.info(f"\n{'=' * 80}") + logger.info(f"Using research mode: {mode.upper()}") + logger.info(f"Description: {mode_config['description']}") + + # Apply mode parameters (can be overridden by explicit arguments) + if max_sub_questions == 10: # Default value - apply mode preset + max_sub_questions = mode_config["max_sub_questions"] + logger.info(f" - Max sub-questions: {max_sub_questions}") + + # Store mode config for later use + mode_max_additional_searches = mode_config["max_additional_searches"] + + # Use mode's num_results_per_search only if user didn't override with --num-results + if num_results == 3: # Default value - apply mode preset + num_results = mode_config["num_results_per_search"] + + logger.info( + f" - Max additional searches: {mode_max_additional_searches}" + ) + logger.info(f" - Results per search: {num_results}") + + # Check if a mode-specific config exists and user didn't override config + if config == "configs/enhanced_research.yaml": # Default config + mode_specific_config = f"configs/{mode.lower()}_research.yaml" + if os.path.exists(mode_specific_config): + config = mode_specific_config + logger.info(f" - Using mode-specific config: {config}") + + # Suggest approval for deep mode if not already enabled + if mode_config.get("suggest_approval") and not require_approval: + logger.info(f"\n{'!' * 60}") + logger.info( + f"! TIP: Consider using --require-approval with {mode} mode" + ) + logger.info(f"! for better control over comprehensive research") + logger.info(f"{'!' 
* 60}") + + logger.info(f"{'=' * 80}\n") + else: + # Default values if no mode specified + mode_max_additional_searches = 2 + + # Check that required environment variables are present using the helper function + required_vars = ["SAMBANOVA_API_KEY"] + + # Add provider-specific API key requirements + if search_provider in {"exa", "both"}: + required_vars.append("EXA_API_KEY") + if search_provider in {"tavily", "both", None}: # Default is tavily + required_vars.append("TAVILY_API_KEY") + + if missing_vars := check_required_env_vars(required_vars): + logger.error( + f"The following required environment variables are not set: {', '.join(missing_vars)}" + ) + logger.info("Please set them with:") + for var in missing_vars: + logger.info(f" export {var}=your_{var.lower()}_here") + return + + # Set pipeline options + pipeline_options = {"config_path": config} + + if no_cache: + pipeline_options["enable_cache"] = False + + logger.info("\n" + "=" * 80) + logger.info("Starting Deep Research") + logger.info("Using parallel pipeline for efficient execution") + + # Log search provider settings + if search_provider: + logger.info(f"Search provider: {search_provider.upper()}") + if search_provider == "exa": + logger.info(f" - Search mode: {search_mode}") + elif search_provider == "both": + logger.info(f" - Running both Tavily and Exa searches") + logger.info(f" - Exa search mode: {search_mode}") + else: + logger.info("Search provider: TAVILY (default)") + + # Log num_results if custom value or no mode preset + if num_results != 3 or not mode: + logger.info(f"Results per search: {num_results}") + + langfuse_project_name = "deep-research" # default + try: + with open(config, "r") as f: + config_data = yaml.safe_load(f) + langfuse_project_name = config_data.get( + "langfuse_project_name", "deep-research" + ) + except Exception as e: + logger.warning( + f"Could not load langfuse_project_name from config: {e}" + ) + + # Set up the pipeline with the parallelized version as default + pipeline = parallelized_deep_research_pipeline.with_options( + **pipeline_options + ) + + # Execute the pipeline + if query: + logger.info( + f"Using query: {query} with max {max_sub_questions} parallel sub-questions" + ) + if require_approval: + logger.info( + f"Human approval enabled with {approval_timeout}s timeout" + ) + pipeline( + query=query, + max_sub_questions=max_sub_questions, + require_approval=require_approval, + approval_timeout=approval_timeout, + max_additional_searches=mode_max_additional_searches, + search_provider=search_provider or "tavily", + search_mode=search_mode, + num_results_per_search=num_results, + langfuse_project_name=langfuse_project_name, + ) + else: + logger.info( + f"Using query from config file with max {max_sub_questions} parallel sub-questions" + ) + if require_approval: + logger.info( + f"Human approval enabled with {approval_timeout}s timeout" + ) + pipeline( + max_sub_questions=max_sub_questions, + require_approval=require_approval, + approval_timeout=approval_timeout, + max_additional_searches=mode_max_additional_searches, + search_provider=search_provider or "tavily", + search_mode=search_mode, + num_results_per_search=num_results, + langfuse_project_name=langfuse_project_name, + ) + + +if __name__ == "__main__": + main() diff --git a/deep_research/steps/__init__.py b/deep_research/steps/__init__.py new file mode 100644 index 00000000..1d454e49 --- /dev/null +++ b/deep_research/steps/__init__.py @@ -0,0 +1,7 @@ +""" +Steps package for the ZenML Deep Research project. 
+ +This package contains individual ZenML steps used in the research pipelines. +Each step is responsible for a specific part of the research process, such as +query decomposition, searching, synthesis, and report generation. +""" diff --git a/deep_research/steps/approval_step.py b/deep_research/steps/approval_step.py new file mode 100644 index 00000000..97a283a1 --- /dev/null +++ b/deep_research/steps/approval_step.py @@ -0,0 +1,308 @@ +import logging +import time +from typing import Annotated + +from materializers.approval_decision_materializer import ( + ApprovalDecisionMaterializer, +) +from utils.approval_utils import ( + format_approval_request, + summarize_research_progress, +) +from utils.pydantic_models import ApprovalDecision, ReflectionOutput +from zenml import log_metadata, step +from zenml.client import Client + +logger = logging.getLogger(__name__) + + +@step( + enable_cache=False, output_materializers=ApprovalDecisionMaterializer +) # Never cache approval decisions +def get_research_approval_step( + reflection_output: ReflectionOutput, + require_approval: bool = True, + alerter_type: str = "slack", + timeout: int = 3600, + max_queries: int = 2, +) -> Annotated[ApprovalDecision, "approval_decision"]: + """ + Get human approval for additional research queries. + + Always returns an ApprovalDecision object. If require_approval is False, + automatically approves all queries. + + Args: + reflection_output: Output from the reflection generation step + require_approval: Whether to require human approval + alerter_type: Type of alerter to use (slack, email, etc.) + timeout: Timeout in seconds for approval response + max_queries: Maximum number of queries to approve + + Returns: + ApprovalDecision object with approval status and selected queries + """ + start_time = time.time() + + # Limit queries to max_queries + limited_queries = reflection_output.recommended_queries[:max_queries] + + # If approval not required, auto-approve all + if not require_approval: + logger.info( + f"Auto-approving {len(limited_queries)} recommended queries (approval not required)" + ) + + # Log metadata for auto-approval + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": False, + "approval_method": "AUTO_APPROVED", + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": len(limited_queries), + "max_queries_allowed": max_queries, + "approval_status": "approved", + "wait_time_seconds": 0, + } + } + ) + + return ApprovalDecision( + approved=True, + selected_queries=limited_queries, + approval_method="AUTO_APPROVED", + reviewer_notes="Approval not required by configuration", + ) + + # If no queries to approve, skip + if not limited_queries: + logger.info("No additional queries recommended") + + # Log metadata for no queries + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "NO_QUERIES", + "num_queries_recommended": 0, + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "skipped", + "wait_time_seconds": 0, + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="NO_QUERIES", + reviewer_notes="No additional queries recommended", + ) + + # Prepare approval request + progress_summary = 
summarize_research_progress(reflection_output.state) + message = format_approval_request( + main_query=reflection_output.state.main_query, + progress_summary=progress_summary, + critique_points=reflection_output.critique_summary, + proposed_queries=limited_queries, + timeout=timeout, + ) + + # Log the approval request for visibility + logger.info("=" * 80) + logger.info("APPROVAL REQUEST:") + logger.info(message) + logger.info("=" * 80) + + try: + # Get the alerter from the active stack + client = Client() + alerter = client.active_stack.alerter + + if not alerter: + logger.warning("No alerter configured in stack, auto-approving") + + # Log metadata for no alerter scenario + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "NO_ALERTER_AUTO_APPROVED", + "alerter_type": "none", + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": len(limited_queries), + "max_queries_allowed": max_queries, + "approval_status": "auto_approved", + "wait_time_seconds": 0, + } + } + ) + + return ApprovalDecision( + approved=True, + selected_queries=limited_queries, + approval_method="NO_ALERTER_AUTO_APPROVED", + reviewer_notes="No alerter configured - auto-approved", + ) + + # Use the alerter's ask method for interactive approval + try: + # Send the message to Discord and wait for response + logger.info( + f"Sending approval request to {alerter.flavor} alerter" + ) + + # Format message for Discord (Discord has message length limits) + discord_message = ( + f"**Research Approval Request**\n\n{message[:1900]}" + ) + if len(message) > 1900: + discord_message += ( + "\n\n*(Message truncated due to Discord limits)*" + ) + + # Add instructions for Discord responses + discord_message += "\n\n**How to respond:**\n" + discord_message += "✅ Type `yes`, `approve`, `ok`, or `LGTM` to approve ALL queries\n" + discord_message += "❌ Type `no`, `skip`, `reject`, or `decline` to skip additional research\n" + discord_message += f"⏱️ Response timeout: {timeout} seconds" + + # Use the ask method to get user response + logger.info("Waiting for approval response from Discord...") + wait_start_time = time.time() + approved = alerter.ask(discord_message) + wait_end_time = time.time() + wait_time = wait_end_time - wait_start_time + + logger.info( + f"Received Discord response: {'approved' if approved else 'rejected'}" + ) + + if approved: + # Log metadata for approved decision + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "DISCORD_APPROVED", + "alerter_type": alerter_type, + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": len(limited_queries), + "max_queries_allowed": max_queries, + "approval_status": "approved", + "wait_time_seconds": wait_time, + "timeout_configured": timeout, + } + } + ) + + return ApprovalDecision( + approved=True, + selected_queries=limited_queries, + approval_method="DISCORD_APPROVED", + reviewer_notes="Approved via Discord", + ) + else: + # Log metadata for rejected decision + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "DISCORD_REJECTED", + 
"alerter_type": alerter_type, + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "rejected", + "wait_time_seconds": wait_time, + "timeout_configured": timeout, + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="DISCORD_REJECTED", + reviewer_notes="Rejected via Discord", + ) + + except Exception as e: + logger.error(f"Failed to get approval from alerter: {e}") + + # Log metadata for alerter error + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "ALERTER_ERROR", + "alerter_type": alerter_type, + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "error", + "error_message": str(e), + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="ALERTER_ERROR", + reviewer_notes=f"Failed to get approval: {str(e)}", + ) + + except Exception as e: + logger.error(f"Approval step failed: {e}") + + # Log metadata for general error + execution_time = time.time() - start_time + log_metadata( + metadata={ + "approval_decision": { + "execution_time_seconds": execution_time, + "approval_required": require_approval, + "approval_method": "ERROR", + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": 0, + "max_queries_allowed": max_queries, + "approval_status": "error", + "error_message": str(e), + } + } + ) + + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="ERROR", + reviewer_notes=f"Approval failed: {str(e)}", + ) diff --git a/deep_research/steps/collect_tracing_metadata_step.py b/deep_research/steps/collect_tracing_metadata_step.py new file mode 100644 index 00000000..4f01505e --- /dev/null +++ b/deep_research/steps/collect_tracing_metadata_step.py @@ -0,0 +1,234 @@ +"""Step to collect tracing metadata from Langfuse for the pipeline run.""" + +import logging +from typing import Annotated, Tuple + +from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.tracing_metadata_materializer import ( + TracingMetadataMaterializer, +) +from utils.pydantic_models import ( + PromptTypeMetrics, + ResearchState, + TracingMetadata, +) +from utils.tracing_metadata_utils import ( + get_observations_for_trace, + get_prompt_type_statistics, + get_trace_stats, + get_traces_by_name, +) +from zenml import get_step_context, step + +logger = logging.getLogger(__name__) + + +@step( + enable_cache=False, + output_materializers={ + "state": ResearchStateMaterializer, + "tracing_metadata": TracingMetadataMaterializer, + }, +) +def collect_tracing_metadata_step( + state: ResearchState, + langfuse_project_name: str, +) -> Tuple[ + Annotated[ResearchState, "state"], + Annotated[TracingMetadata, "tracing_metadata"], +]: + """Collect tracing metadata from Langfuse for the current pipeline run. + + This step gathers comprehensive metrics about token usage, costs, and performance + for the entire pipeline run, providing insights into resource consumption. 
+ + Args: + state: The final research state + langfuse_project_name: Langfuse project name for accessing traces + + Returns: + Tuple of (ResearchState, TracingMetadata) - the state is passed through unchanged + """ + ctx = get_step_context() + pipeline_run_name = ctx.pipeline_run.name + pipeline_run_id = str(ctx.pipeline_run.id) + + logger.info( + f"Collecting tracing metadata for pipeline run: {pipeline_run_name} (ID: {pipeline_run_id})" + ) + + # Initialize the metadata object + metadata = TracingMetadata( + pipeline_run_name=pipeline_run_name, + pipeline_run_id=pipeline_run_id, + trace_name=pipeline_run_name, + trace_id=pipeline_run_id, + ) + + try: + # Fetch the trace for this pipeline run + # The trace_name is the pipeline run name + traces = get_traces_by_name(name=pipeline_run_name, limit=1) + + if not traces: + logger.warning( + f"No trace found for pipeline run: {pipeline_run_name}" + ) + return state, metadata + + trace = traces[0] + + # Get comprehensive trace stats + trace_stats = get_trace_stats(trace) + + # Update metadata with trace stats + metadata.trace_id = trace.id + metadata.total_cost = trace_stats["total_cost"] + metadata.total_input_tokens = trace_stats["input_tokens"] + metadata.total_output_tokens = trace_stats["output_tokens"] + metadata.total_tokens = ( + trace_stats["input_tokens"] + trace_stats["output_tokens"] + ) + metadata.total_latency_seconds = trace_stats["latency_seconds"] + metadata.formatted_latency = trace_stats["latency_formatted"] + metadata.observation_count = trace_stats["observation_count"] + metadata.models_used = trace_stats["models_used"] + metadata.trace_tags = trace_stats.get("tags", []) + metadata.trace_metadata = trace_stats.get("metadata", {}) + + # Get model-specific breakdown + observations = get_observations_for_trace(trace_id=trace.id) + model_costs = {} + model_tokens = {} + step_costs = {} + step_tokens = {} + + for obs in observations: + if obs.model: + # Track by model + if obs.model not in model_costs: + model_costs[obs.model] = 0.0 + model_tokens[obs.model] = { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + + if obs.calculated_total_cost: + model_costs[obs.model] += obs.calculated_total_cost + + if obs.usage: + input_tokens = obs.usage.input or 0 + output_tokens = obs.usage.output or 0 + model_tokens[obs.model]["input_tokens"] += input_tokens + model_tokens[obs.model]["output_tokens"] += output_tokens + model_tokens[obs.model]["total_tokens"] += ( + input_tokens + output_tokens + ) + + # Track by step (using observation name as step indicator) + if obs.name: + step_name = obs.name + + if step_name not in step_costs: + step_costs[step_name] = 0.0 + step_tokens[step_name] = { + "input_tokens": 0, + "output_tokens": 0, + } + + if obs.calculated_total_cost: + step_costs[step_name] += obs.calculated_total_cost + + if obs.usage: + input_tokens = obs.usage.input or 0 + output_tokens = obs.usage.output or 0 + step_tokens[step_name]["input_tokens"] += input_tokens + step_tokens[step_name]["output_tokens"] += output_tokens + + metadata.cost_breakdown_by_model = model_costs + metadata.model_token_breakdown = model_tokens + metadata.step_costs = step_costs + metadata.step_tokens = step_tokens + + # Collect prompt-level metrics + try: + prompt_stats = get_prompt_type_statistics(trace_id=trace.id) + + # Convert to PromptTypeMetrics objects + prompt_metrics_list = [] + for prompt_type, stats in prompt_stats.items(): + prompt_metrics = PromptTypeMetrics( + prompt_type=prompt_type, + total_cost=stats["cost"], + 
input_tokens=stats["input_tokens"], + output_tokens=stats["output_tokens"], + call_count=stats["count"], + avg_cost_per_call=stats["avg_cost_per_call"], + percentage_of_total_cost=stats["percentage_of_total_cost"], + ) + prompt_metrics_list.append(prompt_metrics) + + # Sort by total cost descending + prompt_metrics_list.sort(key=lambda x: x.total_cost, reverse=True) + metadata.prompt_metrics = prompt_metrics_list + + logger.info( + f"Collected prompt-level metrics for {len(prompt_metrics_list)} prompt types" + ) + except Exception as e: + logger.warning(f"Failed to collect prompt-level metrics: {str(e)}") + + # Add search costs from the state + if hasattr(state, "search_costs") and state.search_costs: + metadata.search_costs = state.search_costs.copy() + logger.info(f"Added search costs: {metadata.search_costs}") + + if hasattr(state, "search_cost_details") and state.search_cost_details: + metadata.search_cost_details = state.search_cost_details.copy() + + # Count queries by provider + search_queries_count = {} + for detail in state.search_cost_details: + provider = detail.get("provider", "unknown") + search_queries_count[provider] = ( + search_queries_count.get(provider, 0) + 1 + ) + metadata.search_queries_count = search_queries_count + + logger.info( + f"Added {len(metadata.search_cost_details)} search cost detail entries" + ) + + total_search_cost = sum(metadata.search_costs.values()) + logger.info( + f"Successfully collected tracing metadata - " + f"LLM Cost: ${metadata.total_cost:.4f}, " + f"Search Cost: ${total_search_cost:.4f}, " + f"Total Cost: ${metadata.total_cost + total_search_cost:.4f}, " + f"Tokens: {metadata.total_tokens:,}, " + f"Models: {metadata.models_used}, " + f"Duration: {metadata.formatted_latency}" + ) + + except Exception as e: + logger.error( + f"Failed to collect tracing metadata for pipeline run {pipeline_run_name}: {str(e)}" + ) + # Return metadata with whatever we could collect + + # Still try to get search costs even if Langfuse failed + if hasattr(state, "search_costs") and state.search_costs: + metadata.search_costs = state.search_costs.copy() + if hasattr(state, "search_cost_details") and state.search_cost_details: + metadata.search_cost_details = state.search_cost_details.copy() + # Count queries by provider + search_queries_count = {} + for detail in state.search_cost_details: + provider = detail.get("provider", "unknown") + search_queries_count[provider] = ( + search_queries_count.get(provider, 0) + 1 + ) + metadata.search_queries_count = search_queries_count + + return state, metadata diff --git a/deep_research/steps/cross_viewpoint_step.py b/deep_research/steps/cross_viewpoint_step.py new file mode 100644 index 00000000..ad9c5ddb --- /dev/null +++ b/deep_research/steps/cross_viewpoint_step.py @@ -0,0 +1,228 @@ +import json +import logging +import time +from typing import Annotated, List + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.helper_functions import ( + safe_json_loads, +) +from utils.llm_utils import run_llm_completion +from utils.prompt_models import PromptsBundle +from utils.pydantic_models import ( + ResearchState, + ViewpointAnalysis, + ViewpointTension, +) +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ResearchStateMaterializer) +def cross_viewpoint_analysis_step( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + viewpoint_categories: List[str] = [ + "scientific", 
+ "political", + "economic", + "social", + "ethical", + "historical", + ], + langfuse_project_name: str = "deep-research", +) -> Annotated[ResearchState, "analyzed_state"]: + """Analyze synthesized information across different viewpoints. + + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for viewpoint analysis + viewpoint_categories: Categories of viewpoints to analyze + + Returns: + Updated research state with viewpoint analysis + """ + start_time = time.time() + logger.info( + f"Performing cross-viewpoint analysis on {len(state.synthesized_info)} sub-questions" + ) + + # Prepare input for viewpoint analysis + analysis_input = { + "main_query": state.main_query, + "sub_questions": state.sub_questions, + "synthesized_information": { + question: { + "synthesized_answer": info.synthesized_answer, + "key_sources": info.key_sources, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + } + for question, info in state.synthesized_info.items() + }, + "viewpoint_categories": viewpoint_categories, + } + + # Perform viewpoint analysis + try: + logger.info(f"Calling {llm_model} for viewpoint analysis") + # Get the prompt from the bundle + system_prompt = prompts_bundle.get_prompt_content( + "viewpoint_analysis_prompt" + ) + + # Use the run_llm_completion function from llm_utils + content = run_llm_completion( + prompt=json.dumps(analysis_input), + system_prompt=system_prompt, + model=llm_model, # Model name will be prefixed in the function + max_tokens=3000, # Further increased for more comprehensive viewpoint analysis + project=langfuse_project_name, + ) + + result = safe_json_loads(content) + + if not result: + logger.warning("Failed to parse viewpoint analysis result") + # Create a default viewpoint analysis + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Analysis failed to identify points of agreement." 
+ ], + perspective_gaps="Analysis failed to identify perspective gaps.", + integrative_insights="Analysis failed to provide integrative insights.", + ) + else: + # Create tension objects + tensions = [] + for tension_data in result.get("areas_of_tension", []): + tensions.append( + ViewpointTension( + topic=tension_data.get("topic", ""), + viewpoints=tension_data.get("viewpoints", {}), + ) + ) + + # Create the viewpoint analysis object + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=result.get( + "main_points_of_agreement", [] + ), + areas_of_tension=tensions, + perspective_gaps=result.get("perspective_gaps", ""), + integrative_insights=result.get("integrative_insights", ""), + ) + + logger.info("Completed viewpoint analysis") + + # Update the state with the viewpoint analysis + state.update_viewpoint_analysis(viewpoint_analysis) + + # Calculate execution time + execution_time = time.time() - start_time + + # Count viewpoint tensions by category + tension_categories = {} + for tension in viewpoint_analysis.areas_of_tension: + for category in tension.viewpoints.keys(): + tension_categories[category] = ( + tension_categories.get(category, 0) + 1 + ) + + # Log metadata + log_metadata( + metadata={ + "viewpoint_analysis": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "num_sub_questions_analyzed": len(state.synthesized_info), + "viewpoint_categories_requested": viewpoint_categories, + "num_agreement_points": len( + viewpoint_analysis.main_points_of_agreement + ), + "num_tension_areas": len( + viewpoint_analysis.areas_of_tension + ), + "tension_categories_distribution": tension_categories, + "has_perspective_gaps": bool( + viewpoint_analysis.perspective_gaps + and viewpoint_analysis.perspective_gaps != "" + ), + "has_integrative_insights": bool( + viewpoint_analysis.integrative_insights + and viewpoint_analysis.integrative_insights != "" + ), + "analysis_success": not viewpoint_analysis.main_points_of_agreement[ + 0 + ].startswith("Analysis failed"), + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_scope": { + "num_tension_areas": len( + viewpoint_analysis.areas_of_tension + ), + } + }, + infer_model=True, + ) + + # Log artifact metadata + log_metadata( + metadata={ + "state_with_viewpoint_analysis": { + "has_viewpoint_analysis": True, + "total_viewpoints_analyzed": sum( + tension_categories.values() + ), + "most_common_tension_category": max( + tension_categories, key=tension_categories.get + ) + if tension_categories + else None, + } + }, + infer_artifact=True, + ) + + return state + + except Exception as e: + logger.error(f"Error performing viewpoint analysis: {e}") + + # Create a fallback viewpoint analysis + fallback_analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Analysis failed due to technical error." 
+ ], + perspective_gaps=f"Analysis failed: {str(e)}", + integrative_insights="No insights available due to analysis failure.", + ) + + # Update the state with the fallback analysis + state.update_viewpoint_analysis(fallback_analysis) + + # Log error metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "viewpoint_analysis": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "num_sub_questions_analyzed": len(state.synthesized_info), + "viewpoint_categories_requested": viewpoint_categories, + "analysis_success": False, + "error_message": str(e), + "fallback_used": True, + } + } + ) + + return state diff --git a/deep_research/steps/execute_approved_searches_step.py b/deep_research/steps/execute_approved_searches_step.py new file mode 100644 index 00000000..90a15718 --- /dev/null +++ b/deep_research/steps/execute_approved_searches_step.py @@ -0,0 +1,423 @@ +import json +import logging +import time +from typing import Annotated + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.llm_utils import ( + find_most_relevant_string, + get_structured_llm_output, + is_text_relevant, +) +from utils.prompt_models import PromptsBundle +from utils.pydantic_models import ( + ApprovalDecision, + ReflectionMetadata, + ReflectionOutput, + ResearchState, + SynthesizedInfo, +) +from utils.search_utils import search_and_extract_results +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +def create_enhanced_info_copy(synthesized_info): + """Create a deep copy of synthesized info for enhancement.""" + return { + k: SynthesizedInfo( + synthesized_answer=v.synthesized_answer, + key_sources=v.key_sources.copy(), + confidence_level=v.confidence_level, + information_gaps=v.information_gaps, + improvements=v.improvements.copy() + if hasattr(v, "improvements") + else [], + ) + for k, v in synthesized_info.items() + } + + +@step(output_materializers=ResearchStateMaterializer) +def execute_approved_searches_step( + reflection_output: ReflectionOutput, + approval_decision: ApprovalDecision, + prompts_bundle: PromptsBundle, + num_results_per_search: int = 3, + cap_search_length: int = 20000, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + search_provider: str = "tavily", + search_mode: str = "auto", + langfuse_project_name: str = "deep-research", +) -> Annotated[ResearchState, "updated_state"]: + """Execute approved searches and enhance the research state. + + This step receives the approval decision and only executes + searches that were approved by the human reviewer (or auto-approved). 
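+
+    Gating behaviour, as a minimal sketch (the field values are illustrative
+    assumptions; the fields themselves are the ones this step reads from the
+    ApprovalDecision model):
+
+        decision = ApprovalDecision(
+            approved=True,
+            approval_method="interactive",
+            selected_queries=["follow-up query on topic X"],
+        )
+        # Only decision.selected_queries are searched. If approved is False
+        # or the list is empty, no new searches run and only reflection
+        # metadata is recorded on the state.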
+ + Args: + reflection_output: Output from the reflection generation step + approval_decision: Human approval decision + num_results_per_search: Number of results to fetch per search + cap_search_length: Maximum length of content to process from search results + llm_model: The model to use for synthesis enhancement + prompts_bundle: Bundle containing all prompts for the pipeline + search_provider: Search provider to use + search_mode: Search mode for the provider + + Returns: + Updated research state with enhanced information and reflection metadata + """ + start_time = time.time() + logger.info( + f"Processing approval decision: {approval_decision.approval_method}" + ) + + state = reflection_output.state + enhanced_info = create_enhanced_info_copy(state.synthesized_info) + + # Track improvements count + improvements_count = 0 + + # Check if we should execute searches + if ( + not approval_decision.approved + or not approval_decision.selected_queries + ): + logger.info("No additional searches approved") + + # Add any additional questions as new synthesized entries (from reflection) + for new_question in reflection_output.additional_questions: + if ( + new_question not in state.sub_questions + and new_question not in enhanced_info + ): + enhanced_info[new_question] = SynthesizedInfo( + synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", + key_sources=[], + confidence_level="low", + information_gaps="This question requires additional research.", + ) + + # Create metadata indicating no additional research + reflection_metadata = ReflectionMetadata( + critique_summary=[ + c.get("issue", "") for c in reflection_output.critique_summary + ], + additional_questions_identified=reflection_output.additional_questions, + searches_performed=[], + improvements_made=improvements_count, + ) + + # Add approval decision info to metadata + if hasattr(reflection_metadata, "__dict__"): + reflection_metadata.__dict__["user_decision"] = ( + approval_decision.approval_method + ) + reflection_metadata.__dict__["reviewer_notes"] = ( + approval_decision.reviewer_notes + ) + + state.update_after_reflection(enhanced_info, reflection_metadata) + + # Log metadata for no approved searches + execution_time = time.time() - start_time + log_metadata( + metadata={ + "execute_approved_searches": { + "execution_time_seconds": execution_time, + "approval_method": approval_decision.approval_method, + "approval_status": "not_approved" + if not approval_decision.approved + else "no_queries", + "num_queries_approved": 0, + "num_searches_executed": 0, + "num_additional_questions": len( + reflection_output.additional_questions + ), + "improvements_made": improvements_count, + "search_provider": search_provider, + "llm_model": llm_model, + } + } + ) + + return state + + # Execute approved searches + logger.info( + f"Executing {len(approval_decision.selected_queries)} approved searches" + ) + + try: + search_enhancements = [] # Track search results for metadata + + for query in approval_decision.selected_queries: + logger.info(f"Performing approved search: {query}") + + # Execute search using the utility function + search_results, search_cost = search_and_extract_results( + query=query, + max_results=num_results_per_search, + cap_content_length=cap_search_length, + provider=search_provider, + search_mode=search_mode, + ) + + # Track search costs if using Exa + if ( + search_provider + and search_provider.lower() in ["exa", "both"] + and search_cost > 0 + ): + # Update 
total costs + state.search_costs["exa"] = ( + state.search_costs.get("exa", 0.0) + search_cost + ) + + # Add detailed cost entry + state.search_cost_details.append( + { + "provider": "exa", + "query": query, + "cost": search_cost, + "timestamp": time.time(), + "step": "execute_approved_searches", + "purpose": "reflection_enhancement", + } + ) + logger.info( + f"Exa search cost for approved query: ${search_cost:.4f}" + ) + + # Extract raw contents + raw_contents = [result.content for result in search_results] + + # Find the most relevant sub-question for this query + most_relevant_question = find_most_relevant_string( + query, + state.sub_questions, + llm_model, + project=langfuse_project_name, + ) + + if ( + most_relevant_question + and most_relevant_question in enhanced_info + ): + # Enhance the synthesis with new information + enhancement_input = { + "original_synthesis": enhanced_info[ + most_relevant_question + ].synthesized_answer, + "new_information": raw_contents, + "critique": [ + item + for item in reflection_output.critique_summary + if is_text_relevant( + item.get("issue", ""), most_relevant_question + ) + ], + } + + # Get the prompt from the bundle + additional_synthesis_prompt = ( + prompts_bundle.get_prompt_content( + "additional_synthesis_prompt" + ) + ) + + # Use the utility function for enhancement + enhanced_synthesis = get_structured_llm_output( + prompt=json.dumps(enhancement_input), + system_prompt=additional_synthesis_prompt, + model=llm_model, + fallback_response={ + "enhanced_synthesis": enhanced_info[ + most_relevant_question + ].synthesized_answer, + "improvements_made": ["Failed to enhance synthesis"], + "remaining_limitations": "Enhancement process failed.", + }, + project=langfuse_project_name, + ) + + if ( + enhanced_synthesis + and "enhanced_synthesis" in enhanced_synthesis + ): + # Update the synthesized answer + enhanced_info[ + most_relevant_question + ].synthesized_answer = enhanced_synthesis[ + "enhanced_synthesis" + ] + + # Add improvements + improvements = enhanced_synthesis.get( + "improvements_made", [] + ) + enhanced_info[most_relevant_question].improvements.extend( + improvements + ) + improvements_count += len(improvements) + + # Track enhancement for metadata + search_enhancements.append( + { + "query": query, + "relevant_question": most_relevant_question, + "num_results": len(search_results), + "improvements": len(improvements), + "enhanced": True, + "search_cost": search_cost + if search_provider + and search_provider.lower() in ["exa", "both"] + else 0.0, + } + ) + + # Add any additional questions as new synthesized entries + for new_question in reflection_output.additional_questions: + if ( + new_question not in state.sub_questions + and new_question not in enhanced_info + ): + enhanced_info[new_question] = SynthesizedInfo( + synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", + key_sources=[], + confidence_level="low", + information_gaps="This question requires additional research.", + ) + + # Create final metadata with approval info + reflection_metadata = ReflectionMetadata( + critique_summary=[ + c.get("issue", "") for c in reflection_output.critique_summary + ], + additional_questions_identified=reflection_output.additional_questions, + searches_performed=approval_decision.selected_queries, + improvements_made=improvements_count, + ) + + # Add approval decision info to metadata + if hasattr(reflection_metadata, "__dict__"): + reflection_metadata.__dict__["user_decision"] = 
( + approval_decision.approval_method + ) + reflection_metadata.__dict__["reviewer_notes"] = ( + approval_decision.reviewer_notes + ) + + logger.info( + f"Completed approved searches with {improvements_count} improvements" + ) + + state.update_after_reflection(enhanced_info, reflection_metadata) + + # Calculate metrics for metadata + execution_time = time.time() - start_time + total_results = sum( + e.get("num_results", 0) for e in search_enhancements + ) + questions_enhanced = len( + set( + e.get("relevant_question") + for e in search_enhancements + if e.get("enhanced") + ) + ) + + # Log successful execution metadata + log_metadata( + metadata={ + "execute_approved_searches": { + "execution_time_seconds": execution_time, + "approval_method": approval_decision.approval_method, + "approval_status": "approved", + "num_queries_recommended": len( + reflection_output.recommended_queries + ), + "num_queries_approved": len( + approval_decision.selected_queries + ), + "num_searches_executed": len( + approval_decision.selected_queries + ), + "total_search_results": total_results, + "questions_enhanced": questions_enhanced, + "improvements_made": improvements_count, + "num_additional_questions": len( + reflection_output.additional_questions + ), + "search_provider": search_provider, + "search_mode": search_mode, + "llm_model": llm_model, + "success": True, + "total_search_cost": state.search_costs.get("exa", 0.0), + } + } + ) + + # Log artifact metadata + log_metadata( + metadata={ + "enhanced_state_after_approval": { + "total_questions": len(enhanced_info), + "questions_with_improvements": sum( + 1 + for info in enhanced_info.values() + if info.improvements + ), + "total_improvements": sum( + len(info.improvements) + for info in enhanced_info.values() + ), + "approval_method": approval_decision.approval_method, + } + }, + infer_artifact=True, + ) + + return state + + except Exception as e: + logger.error(f"Error during approved search execution: {e}") + + # Create error metadata + error_metadata = ReflectionMetadata( + error=f"Approved search execution failed: {str(e)}", + critique_summary=[ + c.get("issue", "") for c in reflection_output.critique_summary + ], + additional_questions_identified=reflection_output.additional_questions, + searches_performed=[], + improvements_made=0, + ) + + # Update the state with the original synthesized info as enhanced info + state.update_after_reflection(state.synthesized_info, error_metadata) + + # Log error metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "execute_approved_searches": { + "execution_time_seconds": execution_time, + "approval_method": approval_decision.approval_method, + "approval_status": "approved", + "num_queries_approved": len( + approval_decision.selected_queries + ), + "num_searches_executed": 0, + "improvements_made": 0, + "search_provider": search_provider, + "llm_model": llm_model, + "success": False, + "error_message": str(e), + } + } + ) + + return state diff --git a/deep_research/steps/generate_reflection_step.py b/deep_research/steps/generate_reflection_step.py new file mode 100644 index 00000000..e8547af6 --- /dev/null +++ b/deep_research/steps/generate_reflection_step.py @@ -0,0 +1,167 @@ +import json +import logging +import time +from typing import Annotated + +from materializers.reflection_output_materializer import ( + ReflectionOutputMaterializer, +) +from utils.llm_utils import get_structured_llm_output +from utils.prompt_models import PromptsBundle +from utils.pydantic_models import 
ReflectionOutput, ResearchState +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ReflectionOutputMaterializer) +def generate_reflection_step( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> Annotated[ReflectionOutput, "reflection_output"]: + """ + Generate reflection and recommendations WITHOUT executing searches. + + This step only analyzes the current state and produces recommendations + for additional research that could improve the quality of the results. + + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for reflection + + Returns: + ReflectionOutput containing the state, recommendations, and critique + """ + start_time = time.time() + logger.info("Generating reflection on research") + + # Prepare input for reflection + synthesized_info_dict = { + question: { + "synthesized_answer": info.synthesized_answer, + "key_sources": info.key_sources, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + } + for question, info in state.synthesized_info.items() + } + + viewpoint_analysis_dict = None + if state.viewpoint_analysis: + # Convert the viewpoint analysis to a dict for the LLM + tension_list = [] + for tension in state.viewpoint_analysis.areas_of_tension: + tension_list.append( + {"topic": tension.topic, "viewpoints": tension.viewpoints} + ) + + viewpoint_analysis_dict = { + "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, + "areas_of_tension": tension_list, + "perspective_gaps": state.viewpoint_analysis.perspective_gaps, + "integrative_insights": state.viewpoint_analysis.integrative_insights, + } + + reflection_input = { + "main_query": state.main_query, + "sub_questions": state.sub_questions, + "synthesized_information": synthesized_info_dict, + } + + if viewpoint_analysis_dict: + reflection_input["viewpoint_analysis"] = viewpoint_analysis_dict + + # Get reflection critique + logger.info(f"Generating self-critique via {llm_model}") + + # Get the prompt from the bundle + reflection_prompt = prompts_bundle.get_prompt_content("reflection_prompt") + + # Define fallback for reflection result + fallback_reflection = { + "critique": [], + "additional_questions": [], + "recommended_search_queries": [], + } + + # Use utility function to get structured output + reflection_result = get_structured_llm_output( + prompt=json.dumps(reflection_input), + system_prompt=reflection_prompt, + model=llm_model, + fallback_response=fallback_reflection, + project=langfuse_project_name, + ) + + # Prepare return value + reflection_output = ReflectionOutput( + state=state, + recommended_queries=reflection_result.get( + "recommended_search_queries", [] + ), + critique_summary=reflection_result.get("critique", []), + additional_questions=reflection_result.get("additional_questions", []), + ) + + # Calculate execution time + execution_time = time.time() - start_time + + # Count confidence levels in synthesized info + confidence_levels = [ + info.confidence_level for info in state.synthesized_info.values() + ] + confidence_distribution = { + "high": confidence_levels.count("high"), + "medium": confidence_levels.count("medium"), + "low": confidence_levels.count("low"), + } + + # Log step metadata + log_metadata( + metadata={ + "reflection_generation": { + 
"execution_time_seconds": execution_time, + "llm_model": llm_model, + "num_sub_questions_analyzed": len(state.sub_questions), + "num_synthesized_answers": len(state.synthesized_info), + "viewpoint_analysis_included": bool(viewpoint_analysis_dict), + "num_critique_points": len(reflection_output.critique_summary), + "num_additional_questions": len( + reflection_output.additional_questions + ), + "num_recommended_queries": len( + reflection_output.recommended_queries + ), + "confidence_distribution": confidence_distribution, + "has_information_gaps": any( + info.information_gaps + for info in state.synthesized_info.values() + ), + } + } + ) + + # Log artifact metadata + log_metadata( + metadata={ + "reflection_output_characteristics": { + "has_recommendations": bool( + reflection_output.recommended_queries + ), + "has_critique": bool(reflection_output.critique_summary), + "has_additional_questions": bool( + reflection_output.additional_questions + ), + "total_recommendations": len( + reflection_output.recommended_queries + ) + + len(reflection_output.additional_questions), + } + }, + infer_artifact=True, + ) + + return reflection_output diff --git a/deep_research/steps/initialize_prompts_step.py b/deep_research/steps/initialize_prompts_step.py new file mode 100644 index 00000000..377c6df4 --- /dev/null +++ b/deep_research/steps/initialize_prompts_step.py @@ -0,0 +1,45 @@ +"""Step to initialize and track prompts as artifacts. + +This step creates a PromptsBundle artifact at the beginning of the pipeline, +making all prompts trackable and versioned in ZenML. +""" + +import logging +from typing import Annotated + +from materializers.prompts_materializer import PromptsBundleMaterializer +from utils.prompt_loader import load_prompts_bundle +from utils.prompt_models import PromptsBundle +from zenml import step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=PromptsBundleMaterializer) +def initialize_prompts_step( + pipeline_version: str = "1.1.0", +) -> Annotated[PromptsBundle, "prompts_bundle"]: + """Initialize the prompts bundle for the pipeline. + + This step loads all prompts from the prompts.py module and creates + a PromptsBundle artifact that can be tracked and visualized in ZenML. 
+ + Args: + pipeline_version: Version of the pipeline using these prompts + + Returns: + PromptsBundle containing all prompts used in the pipeline + """ + logger.info( + f"Initializing prompts bundle for pipeline version {pipeline_version}" + ) + + # Load all prompts into a bundle + prompts_bundle = load_prompts_bundle(pipeline_version=pipeline_version) + + # Log some statistics + all_prompts = prompts_bundle.list_all_prompts() + logger.info(f"Loaded {len(all_prompts)} prompts into bundle") + logger.info(f"Prompts: {', '.join(all_prompts.keys())}") + + return prompts_bundle diff --git a/deep_research/steps/iterative_reflection_step.py b/deep_research/steps/iterative_reflection_step.py new file mode 100644 index 00000000..0ed38800 --- /dev/null +++ b/deep_research/steps/iterative_reflection_step.py @@ -0,0 +1,385 @@ +import json +import logging +import time +from typing import Annotated + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.llm_utils import ( + find_most_relevant_string, + get_structured_llm_output, + is_text_relevant, +) +from utils.prompts import ADDITIONAL_SYNTHESIS_PROMPT, REFLECTION_PROMPT +from utils.pydantic_models import ( + ReflectionMetadata, + ResearchState, + SynthesizedInfo, +) +from utils.search_utils import search_and_extract_results +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ResearchStateMaterializer) +def iterative_reflection_step( + state: ResearchState, + max_additional_searches: int = 2, + num_results_per_search: int = 3, + cap_search_length: int = 20000, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + reflection_prompt: str = REFLECTION_PROMPT, + additional_synthesis_prompt: str = ADDITIONAL_SYNTHESIS_PROMPT, +) -> Annotated[ResearchState, "reflected_state"]: + """Perform iterative reflection on the research, identifying gaps and improving it. 
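+
+    The step runs in three phases: an LLM self-critique of the synthesized
+    answers (and the viewpoint analysis, if present), up to
+    max_additional_searches follow-up searches driven by the critique's
+    recommended queries, and a final LLM pass that folds the new results back
+    into the affected syntheses.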
+ + Args: + state: The current research state + max_additional_searches: Maximum number of additional searches to perform + num_results_per_search: Number of results to fetch per search + cap_search_length: Maximum length of content to process from search results + llm_model: The model to use for reflection + reflection_prompt: System prompt for the reflection + additional_synthesis_prompt: System prompt for incorporating additional information + + Returns: + Updated research state with enhanced information and reflection metadata + """ + start_time = time.time() + logger.info("Starting iterative reflection on research") + + # Prepare input for reflection + synthesized_info_dict = { + question: { + "synthesized_answer": info.synthesized_answer, + "key_sources": info.key_sources, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + } + for question, info in state.synthesized_info.items() + } + + viewpoint_analysis_dict = None + if state.viewpoint_analysis: + # Convert the viewpoint analysis to a dict for the LLM + tension_list = [] + for tension in state.viewpoint_analysis.areas_of_tension: + tension_list.append( + {"topic": tension.topic, "viewpoints": tension.viewpoints} + ) + + viewpoint_analysis_dict = { + "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, + "areas_of_tension": tension_list, + "perspective_gaps": state.viewpoint_analysis.perspective_gaps, + "integrative_insights": state.viewpoint_analysis.integrative_insights, + } + + reflection_input = { + "main_query": state.main_query, + "sub_questions": state.sub_questions, + "synthesized_information": synthesized_info_dict, + } + + if viewpoint_analysis_dict: + reflection_input["viewpoint_analysis"] = viewpoint_analysis_dict + + # Get reflection critique + try: + logger.info(f"Generating self-critique via {llm_model}") + + # Define fallback for reflection result + fallback_reflection = { + "critique": [], + "additional_questions": [], + "recommended_search_queries": [], + } + + # Use utility function to get structured output + reflection_result = get_structured_llm_output( + prompt=json.dumps(reflection_input), + system_prompt=reflection_prompt, + model=llm_model, + fallback_response=fallback_reflection, + ) + + # Make a deep copy of the synthesized info to create enhanced_info + enhanced_info = { + k: SynthesizedInfo( + synthesized_answer=v.synthesized_answer, + key_sources=v.key_sources.copy(), + confidence_level=v.confidence_level, + information_gaps=v.information_gaps, + improvements=v.improvements.copy() + if hasattr(v, "improvements") + else [], + ) + for k, v in state.synthesized_info.items() + } + + # Perform additional searches based on recommendations + search_queries = reflection_result.get( + "recommended_search_queries", [] + ) + if max_additional_searches > 0 and search_queries: + # Limit to max_additional_searches + search_queries = search_queries[:max_additional_searches] + + for query in search_queries: + logger.info(f"Performing additional search: {query}") + # Execute the search using the utility function + search_results, search_cost = search_and_extract_results( + query=query, + max_results=num_results_per_search, + cap_content_length=cap_search_length, + ) + + # Extract raw contents + raw_contents = [result.content for result in search_results] + + # Find the most relevant sub-question for this query + most_relevant_question = find_most_relevant_string( + query, state.sub_questions, llm_model + ) + + # Track search costs if using Exa (default 
provider) + # Note: This step doesn't have a search_provider parameter, so we check the default + from utils.search_utils import SearchEngineConfig + + config = SearchEngineConfig() + if ( + config.default_provider.lower() in ["exa", "both"] + and search_cost > 0 + ): + # Update total costs + state.search_costs["exa"] = ( + state.search_costs.get("exa", 0.0) + search_cost + ) + + # Add detailed cost entry + state.search_cost_details.append( + { + "provider": "exa", + "query": query, + "cost": search_cost, + "timestamp": time.time(), + "step": "iterative_reflection", + "purpose": "gap_filling", + "relevant_question": most_relevant_question, + } + ) + logger.info( + f"Exa search cost for reflection query: ${search_cost:.4f}" + ) + + if ( + most_relevant_question + and most_relevant_question in enhanced_info + ): + # Enhance the synthesis with new information + enhancement_input = { + "original_synthesis": enhanced_info[ + most_relevant_question + ].synthesized_answer, + "new_information": raw_contents, + "critique": [ + item + for item in reflection_result.get("critique", []) + if is_text_relevant( + item.get("issue", ""), most_relevant_question + ) + ], + } + + # Use the utility function for enhancement + enhanced_synthesis = get_structured_llm_output( + prompt=json.dumps(enhancement_input), + system_prompt=additional_synthesis_prompt, + model=llm_model, + fallback_response={ + "enhanced_synthesis": enhanced_info[ + most_relevant_question + ].synthesized_answer, + "improvements_made": [ + "Failed to enhance synthesis" + ], + "remaining_limitations": "Enhancement process failed.", + }, + ) + + if ( + enhanced_synthesis + and "enhanced_synthesis" in enhanced_synthesis + ): + # Update the synthesized answer + enhanced_info[ + most_relevant_question + ].synthesized_answer = enhanced_synthesis[ + "enhanced_synthesis" + ] + + # Add improvements + improvements = enhanced_synthesis.get( + "improvements_made", [] + ) + enhanced_info[ + most_relevant_question + ].improvements.extend(improvements) + + # Add any additional questions as new synthesized entries + for new_question in reflection_result.get("additional_questions", []): + if ( + new_question not in state.sub_questions + and new_question not in enhanced_info + ): + enhanced_info[new_question] = SynthesizedInfo( + synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", + key_sources=[], + confidence_level="low", + information_gaps="This question requires additional research.", + ) + + # Prepare metadata about the reflection process + reflection_metadata = ReflectionMetadata( + critique_summary=[ + item.get("issue", "") + for item in reflection_result.get("critique", []) + ], + additional_questions_identified=reflection_result.get( + "additional_questions", [] + ), + searches_performed=search_queries, + improvements_made=sum( + [len(info.improvements) for info in enhanced_info.values()] + ), + ) + + logger.info( + f"Completed iterative reflection with {reflection_metadata.improvements_made} improvements" + ) + + # Update the state with enhanced info and metadata + state.update_after_reflection(enhanced_info, reflection_metadata) + + # Calculate execution time + execution_time = time.time() - start_time + + # Count questions that were enhanced + questions_enhanced = 0 + for question, enhanced in enhanced_info.items(): + if question in state.synthesized_info: + original = state.synthesized_info[question] + if enhanced.synthesized_answer != original.synthesized_answer: + 
questions_enhanced += 1 + + # Calculate confidence level changes + confidence_improvements = {"improved": 0, "unchanged": 0, "new": 0} + for question, enhanced in enhanced_info.items(): + if question in state.synthesized_info: + original = state.synthesized_info[question] + original_level = original.confidence_level.lower() + enhanced_level = enhanced.confidence_level.lower() + + level_map = {"low": 0, "medium": 1, "high": 2} + if enhanced_level in level_map and original_level in level_map: + if level_map[enhanced_level] > level_map[original_level]: + confidence_improvements["improved"] += 1 + else: + confidence_improvements["unchanged"] += 1 + else: + confidence_improvements["new"] += 1 + + # Log metadata + log_metadata( + metadata={ + "iterative_reflection": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "max_additional_searches": max_additional_searches, + "searches_performed": len(search_queries), + "num_critique_points": len( + reflection_result.get("critique", []) + ), + "num_additional_questions": len( + reflection_result.get("additional_questions", []) + ), + "questions_enhanced": questions_enhanced, + "total_improvements": reflection_metadata.improvements_made, + "confidence_improvements": confidence_improvements, + "has_viewpoint_analysis": bool(viewpoint_analysis_dict), + "total_search_cost": state.search_costs.get("exa", 0.0), + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "improvement_metrics": { + "confidence_improvements": confidence_improvements, + "total_improvements": reflection_metadata.improvements_made, + } + }, + infer_model=True, + ) + + # Log artifact metadata + log_metadata( + metadata={ + "enhanced_state_characteristics": { + "total_questions": len(enhanced_info), + "questions_with_improvements": sum( + 1 + for info in enhanced_info.values() + if info.improvements + ), + "high_confidence_count": sum( + 1 + for info in enhanced_info.values() + if info.confidence_level.lower() == "high" + ), + "medium_confidence_count": sum( + 1 + for info in enhanced_info.values() + if info.confidence_level.lower() == "medium" + ), + "low_confidence_count": sum( + 1 + for info in enhanced_info.values() + if info.confidence_level.lower() == "low" + ), + } + }, + infer_artifact=True, + ) + + return state + + except Exception as e: + logger.error(f"Error during iterative reflection: {e}") + + # Create error metadata + error_metadata = ReflectionMetadata( + error=f"Reflection failed: {str(e)}" + ) + + # Update the state with the original synthesized info as enhanced info + # and the error metadata + state.update_after_reflection(state.synthesized_info, error_metadata) + + # Log error metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "iterative_reflection": { + "execution_time_seconds": execution_time, + "llm_model": llm_model, + "max_additional_searches": max_additional_searches, + "searches_performed": 0, + "status": "failed", + "error_message": str(e), + } + } + ) + + return state diff --git a/deep_research/steps/merge_results_step.py b/deep_research/steps/merge_results_step.py new file mode 100644 index 00000000..a8c8cd53 --- /dev/null +++ b/deep_research/steps/merge_results_step.py @@ -0,0 +1,265 @@ +import copy +import logging +import time +from typing import Annotated + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.pydantic_models import ResearchState +from zenml import get_step_context, log_metadata, step +from zenml.client import 
Client + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ResearchStateMaterializer) +def merge_sub_question_results_step( + original_state: ResearchState, + step_prefix: str = "process_question_", + output_name: str = "output", +) -> Annotated[ResearchState, "merged_state"]: + """Merge results from individual sub-question processing steps. + + This step collects the results from the parallel sub-question processing steps + and combines them into a single, comprehensive state object. + + Args: + original_state: The original research state with all sub-questions + step_prefix: The prefix used in step IDs for the parallel processing steps + output_name: The name of the output artifact from the processing steps + + Returns: + Annotated[ResearchState, "merged_state"]: A merged ResearchState with combined + results from all sub-questions + + Note: + This step is typically configured with the 'after' parameter in the pipeline + definition to ensure it runs after all parallel sub-question processing steps + have completed. + """ + start_time = time.time() + + # Start with the original state that has all sub-questions + merged_state = copy.deepcopy(original_state) + + # Initialize empty dictionaries for the results + merged_state.search_results = {} + merged_state.synthesized_info = {} + + # Initialize search cost tracking + merged_state.search_costs = {} + merged_state.search_cost_details = [] + + # Get pipeline run information to access outputs + try: + ctx = get_step_context() + if not ctx or not ctx.pipeline_run: + logger.error("Could not get pipeline run context") + return merged_state + + run_name = ctx.pipeline_run.name + client = Client() + run = client.get_pipeline_run(run_name) + + logger.info( + f"Merging results from parallel sub-question processing steps in run: {run_name}" + ) + + # Track which sub-questions were successfully processed + processed_questions = set() + parallel_steps_processed = 0 + + # Process each step in the run + for step_name, step_info in run.steps.items(): + # Only process steps with the specified prefix + if step_name.startswith(step_prefix): + try: + # Extract the sub-question index from the step name + if "_" in step_name: + index = int(step_name.split("_")[-1]) + logger.info( + f"Processing results from step: {step_name} (index: {index})" + ) + + # Get the output artifact + if output_name in step_info.outputs: + output_artifacts = step_info.outputs[output_name] + if output_artifacts: + output_artifact = output_artifacts[0] + sub_state = output_artifact.load() + + # Check if the sub-state has valid data + if ( + hasattr(sub_state, "sub_questions") + and sub_state.sub_questions + ): + sub_question = sub_state.sub_questions[0] + logger.info( + f"Found results for sub-question: {sub_question}" + ) + parallel_steps_processed += 1 + processed_questions.add(sub_question) + + # Merge search results + if ( + hasattr(sub_state, "search_results") + and sub_question + in sub_state.search_results + ): + merged_state.search_results[ + sub_question + ] = sub_state.search_results[ + sub_question + ] + logger.info( + f"Added search results for: {sub_question}" + ) + + # Merge synthesized info + if ( + hasattr(sub_state, "synthesized_info") + and sub_question + in sub_state.synthesized_info + ): + merged_state.synthesized_info[ + sub_question + ] = sub_state.synthesized_info[ + sub_question + ] + logger.info( + f"Added synthesized info for: {sub_question}" + ) + + # Merge search costs + if hasattr(sub_state, "search_costs"): + for ( + provider, + cost, + 
) in sub_state.search_costs.items(): + merged_state.search_costs[ + provider + ] = ( + merged_state.search_costs.get( + provider, 0.0 + ) + + cost + ) + + # Merge search cost details + if hasattr( + sub_state, "search_cost_details" + ): + merged_state.search_cost_details.extend( + sub_state.search_cost_details + ) + except (ValueError, IndexError, KeyError, AttributeError) as e: + logger.warning(f"Error processing step {step_name}: {e}") + continue + + # Log summary + logger.info( + f"Merged results from {parallel_steps_processed} parallel steps" + ) + logger.info( + f"Successfully processed {len(processed_questions)} sub-questions" + ) + + # Log search cost summary + if merged_state.search_costs: + total_cost = sum(merged_state.search_costs.values()) + logger.info( + f"Total search costs merged: ${total_cost:.4f} across {len(merged_state.search_cost_details)} queries" + ) + for provider, cost in merged_state.search_costs.items(): + logger.info(f" {provider}: ${cost:.4f}") + + # Check for any missing sub-questions + for sub_q in merged_state.sub_questions: + if sub_q not in processed_questions: + logger.warning(f"Missing results for sub-question: {sub_q}") + + except Exception as e: + logger.error(f"Error during merge step: {e}") + + # Final check for empty results + if not merged_state.search_results or not merged_state.synthesized_info: + logger.warning( + "No results were found or merged from parallel processing steps!" + ) + + # Calculate execution time + execution_time = time.time() - start_time + + # Calculate metrics + missing_questions = [ + q for q in merged_state.sub_questions if q not in processed_questions + ] + + # Count total search results across all questions + total_search_results = sum( + len(results) for results in merged_state.search_results.values() + ) + + # Get confidence distribution for merged results + confidence_distribution = {"high": 0, "medium": 0, "low": 0} + for info in merged_state.synthesized_info.values(): + level = info.confidence_level.lower() + if level in confidence_distribution: + confidence_distribution[level] += 1 + + # Calculate completeness ratio + completeness_ratio = ( + len(processed_questions) / len(merged_state.sub_questions) + if merged_state.sub_questions + else 0 + ) + + # Log metadata + log_metadata( + metadata={ + "merge_results": { + "execution_time_seconds": execution_time, + "total_sub_questions": len(merged_state.sub_questions), + "parallel_steps_processed": parallel_steps_processed, + "questions_successfully_merged": len(processed_questions), + "missing_questions_count": len(missing_questions), + "missing_questions": missing_questions[:5] + if missing_questions + else [], # Limit to 5 for metadata + "total_search_results": total_search_results, + "confidence_distribution": confidence_distribution, + "merge_success": bool( + merged_state.search_results + and merged_state.synthesized_info + ), + "total_search_costs": merged_state.search_costs, + "total_search_queries": len(merged_state.search_cost_details), + "total_exa_cost": merged_state.search_costs.get("exa", 0.0), + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_quality": { + "completeness_ratio": completeness_ratio, + } + }, + infer_model=True, + ) + + # Log artifact metadata + log_metadata( + metadata={ + "merged_state_characteristics": { + "has_search_results": bool(merged_state.search_results), + "has_synthesized_info": bool(merged_state.synthesized_info), + "search_results_count": len(merged_state.search_results), + 
"synthesized_info_count": len(merged_state.synthesized_info), + "completeness_ratio": completeness_ratio, + } + }, + infer_artifact=True, + ) + + return merged_state diff --git a/deep_research/steps/process_sub_question_step.py b/deep_research/steps/process_sub_question_step.py new file mode 100644 index 00000000..f6b077c7 --- /dev/null +++ b/deep_research/steps/process_sub_question_step.py @@ -0,0 +1,289 @@ +import copy +import logging +import time +import warnings +from typing import Annotated + +# Suppress Pydantic serialization warnings from ZenML artifact metadata +# These occur when ZenML stores timestamp metadata as floats but models expect ints +warnings.filterwarnings( + "ignore", message=".*PydanticSerializationUnexpectedValue.*" +) + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.llm_utils import synthesize_information +from utils.prompt_models import PromptsBundle +from utils.pydantic_models import ResearchState, SynthesizedInfo +from utils.search_utils import ( + generate_search_query, + search_and_extract_results, +) +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ResearchStateMaterializer) +def process_sub_question_step( + state: ResearchState, + prompts_bundle: PromptsBundle, + question_index: int, + llm_model_search: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + llm_model_synthesis: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + num_results_per_search: int = 3, + cap_search_length: int = 20000, + search_provider: str = "tavily", + search_mode: str = "auto", + langfuse_project_name: str = "deep-research", +) -> Annotated[ResearchState, "output"]: + """Process a single sub-question if it exists at the given index. + + This step combines the gathering and synthesis steps for a single sub-question. + It's designed to be run in parallel for each sub-question. 
+ + Args: + state: The original research state with all sub-questions + prompts_bundle: Bundle containing all prompts for the pipeline + question_index: The index of the sub-question to process + llm_model_search: Model to use for search query generation + llm_model_synthesis: Model to use for synthesis + num_results_per_search: Number of results to fetch per search + cap_search_length: Maximum length of content to process from search results + search_provider: Search provider to use (tavily, exa, or both) + search_mode: Search mode for Exa provider (neural, keyword, or auto) + + Returns: + A new ResearchState containing only the processed sub-question's results + """ + start_time = time.time() + + # Create a copy of the state to avoid modifying the original + sub_state = copy.deepcopy(state) + + # Clear all existing data except the main query + sub_state.search_results = {} + sub_state.synthesized_info = {} + sub_state.enhanced_info = {} + sub_state.viewpoint_analysis = None + sub_state.reflection_metadata = None + sub_state.final_report_html = "" + + # Check if this index exists in sub-questions + if question_index >= len(state.sub_questions): + logger.info( + f"No sub-question at index {question_index}, skipping processing" + ) + # Log metadata for skipped processing + log_metadata( + metadata={ + "sub_question_processing": { + "question_index": question_index, + "status": "skipped", + "reason": "index_out_of_range", + "total_sub_questions": len(state.sub_questions), + } + } + ) + # Return an empty state since there's no question to process + sub_state.sub_questions = [] + return sub_state + + # Get the target sub-question + sub_question = state.sub_questions[question_index] + logger.info( + f"Processing sub-question {question_index + 1}: {sub_question}" + ) + + # Store only this sub-question in the sub-state + sub_state.sub_questions = [sub_question] + + # === INFORMATION GATHERING === + search_phase_start = time.time() + + # Generate search query with prompt from bundle + search_query_prompt = prompts_bundle.get_prompt_content( + "search_query_prompt" + ) + search_query_data = generate_search_query( + sub_question=sub_question, + model=llm_model_search, + system_prompt=search_query_prompt, + project=langfuse_project_name, + ) + search_query = search_query_data.get( + "search_query", f"research about {sub_question}" + ) + + # Perform search + logger.info(f"Performing search with query: {search_query}") + if search_provider: + logger.info(f"Using search provider: {search_provider}") + results_list, search_cost = search_and_extract_results( + query=search_query, + max_results=num_results_per_search, + cap_content_length=cap_search_length, + provider=search_provider, + search_mode=search_mode, + ) + + # Track search costs if using Exa + if ( + search_provider + and search_provider.lower() in ["exa", "both"] + and search_cost > 0 + ): + # Update total costs + sub_state.search_costs["exa"] = ( + sub_state.search_costs.get("exa", 0.0) + search_cost + ) + + # Add detailed cost entry + sub_state.search_cost_details.append( + { + "provider": "exa", + "query": search_query, + "cost": search_cost, + "timestamp": time.time(), + "step": "process_sub_question", + "sub_question": sub_question, + "question_index": question_index, + } + ) + logger.info( + f"Exa search cost for sub-question {question_index}: ${search_cost:.4f}" + ) + + search_results = {sub_question: results_list} + sub_state.update_search_results(search_results) + + search_phase_time = time.time() - search_phase_start + + # === 
INFORMATION SYNTHESIS === + synthesis_phase_start = time.time() + + # Extract raw contents and URLs + raw_contents = [] + sources = [] + for result in results_list: + raw_contents.append(result.content) + sources.append(result.url) + + # Prepare input for synthesis + synthesis_input = { + "sub_question": sub_question, + "search_results": raw_contents, + "sources": sources, + } + + # Synthesize information with prompt from bundle + synthesis_prompt = prompts_bundle.get_prompt_content("synthesis_prompt") + synthesis_result = synthesize_information( + synthesis_input=synthesis_input, + model=llm_model_synthesis, + system_prompt=synthesis_prompt, + project=langfuse_project_name, + ) + + # Create SynthesizedInfo object + synthesized_info = { + sub_question: SynthesizedInfo( + synthesized_answer=synthesis_result.get( + "synthesized_answer", f"Synthesis for '{sub_question}' failed." + ), + key_sources=synthesis_result.get("key_sources", sources[:1]), + confidence_level=synthesis_result.get("confidence_level", "low"), + information_gaps=synthesis_result.get( + "information_gaps", + "Synthesis process encountered technical difficulties.", + ), + improvements=synthesis_result.get("improvements", []), + ) + } + + # Update the state with synthesized information + sub_state.update_synthesized_info(synthesized_info) + + synthesis_phase_time = time.time() - synthesis_phase_start + total_execution_time = time.time() - start_time + + # Calculate total content length processed + total_content_length = sum(len(content) for content in raw_contents) + + # Get unique domains from sources + unique_domains = set() + for url in sources: + try: + from urllib.parse import urlparse + + domain = urlparse(url).netloc + unique_domains.add(domain) + except: + pass + + # Log comprehensive metadata + log_metadata( + metadata={ + "sub_question_processing": { + "question_index": question_index, + "status": "completed", + "sub_question": sub_question, + "execution_time_seconds": total_execution_time, + "search_phase_time_seconds": search_phase_time, + "synthesis_phase_time_seconds": synthesis_phase_time, + "search_query": search_query, + "search_provider": search_provider, + "search_mode": search_mode, + "num_results_requested": num_results_per_search, + "num_results_retrieved": len(results_list), + "total_content_length": total_content_length, + "cap_search_length": cap_search_length, + "unique_domains": list(unique_domains), + "llm_model_search": llm_model_search, + "llm_model_synthesis": llm_model_synthesis, + "confidence_level": synthesis_result.get( + "confidence_level", "low" + ), + "information_gaps": synthesis_result.get( + "information_gaps", "" + ), + "key_sources_count": len( + synthesis_result.get("key_sources", []) + ), + "search_cost": search_cost, + "search_cost_provider": "exa" + if search_provider + and search_provider.lower() in ["exa", "both"] + else None, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "search_metrics": { + "confidence_level": synthesis_result.get( + "confidence_level", "low" + ), + "search_provider": search_provider, + } + }, + infer_model=True, + ) + + # Log artifact metadata for the output state + log_metadata( + metadata={ + "sub_state_characteristics": { + "has_search_results": bool(sub_state.search_results), + "has_synthesized_info": bool(sub_state.synthesized_info), + "sub_question_processed": sub_question, + "confidence_level": synthesis_result.get( + "confidence_level", "low" + ), + } + }, + infer_artifact=True, + ) + + return 
sub_state diff --git a/deep_research/steps/pydantic_final_report_step.py b/deep_research/steps/pydantic_final_report_step.py new file mode 100644 index 00000000..dc05b1bb --- /dev/null +++ b/deep_research/steps/pydantic_final_report_step.py @@ -0,0 +1,1250 @@ +"""Final report generation step using Pydantic models and materializers. + +This module provides a ZenML pipeline step for generating the final HTML research report +using Pydantic models and improved materializers. +""" + +import html +import json +import logging +import re +import time +from typing import Annotated, Tuple + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.helper_functions import ( + extract_html_from_content, + remove_reasoning_from_output, +) +from utils.llm_utils import run_llm_completion +from utils.prompt_models import PromptsBundle +from utils.prompts import ( + STATIC_HTML_TEMPLATE, + SUB_QUESTION_TEMPLATE, + VIEWPOINT_ANALYSIS_TEMPLATE, +) +from utils.pydantic_models import ResearchState +from zenml import log_metadata, step +from zenml.types import HTMLString + +logger = logging.getLogger(__name__) + + +def clean_html_output(html_content: str) -> str: + """Clean HTML output from LLM to ensure proper rendering. + + This function removes markdown code blocks, fixes common issues with LLM HTML output, + and ensures we have proper HTML structure for rendering. + + Args: + html_content: Raw HTML content from LLM + + Returns: + Cleaned HTML content ready for rendering + """ + # Remove markdown code block markers (```html and ```) + html_content = re.sub(r"```html\s*", "", html_content) + html_content = re.sub(r"```\s*$", "", html_content) + html_content = re.sub(r"```", "", html_content) + + # Remove any CSS code block markers + html_content = re.sub(r"```css\s*", "", html_content) + + # Ensure HTML content is properly wrapped in HTML tags if not already + if not html_content.strip().startswith( + "{html_content}' + + html_content = re.sub(r"\[CSS STYLESHEET GOES HERE\]", "", html_content) + html_content = re.sub(r"\[SUB-QUESTIONS LINKS\]", "", html_content) + html_content = re.sub(r"\[ADDITIONAL SECTIONS LINKS\]", "", html_content) + html_content = re.sub(r"\[FOR EACH SUB-QUESTION\]:", "", html_content) + html_content = re.sub(r"\[FOR EACH TENSION\]:", "", html_content) + + # Replace content placeholders with appropriate defaults + html_content = re.sub( + r"\[CONCISE SUMMARY OF KEY FINDINGS\]", + "Summary of findings from the research query.", + html_content, + ) + html_content = re.sub( + r"\[INTRODUCTION TO THE RESEARCH QUERY\]", + "Introduction to the research topic.", + html_content, + ) + html_content = re.sub( + r"\[OVERVIEW OF THE APPROACH AND SUB-QUESTIONS\]", + "Overview of the research approach.", + html_content, + ) + html_content = re.sub( + r"\[CONCLUSION TEXT\]", + "Conclusion of the research findings.", + html_content, + ) + + return html_content + + +def format_text_with_code_blocks(text: str) -> str: + """Format text with proper handling of code blocks and markdown formatting. + + Args: + text: The raw text to format + + Returns: + str: HTML-formatted text + """ + if not text: + return "" + + # First escape HTML + escaped_text = html.escape(text) + + # Handle code blocks (wrap content in ``` or ```) + pattern = r"```(?:\w*\n)?(.*?)```" + + def code_block_replace(match): + code_content = match.group(1) + # Strip extra newlines at beginning and end + code_content = code_content.strip("\n") + return f"
<pre><code>{code_content}</code></pre>"
+
+    # Replace code blocks
+    formatted_text = re.sub(
+        pattern, code_block_replace, escaped_text, flags=re.DOTALL
+    )
+
+    # Convert regular newlines to <br/> tags (but not inside <pre> blocks)
+    parts = []
+    in_pre = False
+    for line in formatted_text.split("\n"):
+        if "<pre>" in line:
+            in_pre = True
+            parts.append(line)
+        elif "</pre>" in line:
+            in_pre = False
+            parts.append(line)
+        elif in_pre:
+            # Inside a code block, preserve newlines
+            parts.append(line)
+        else:
+            # Outside code blocks, convert newlines to <br/>
+            parts.append(line + "<br/>
") + + return "".join(parts) + + +def generate_executive_summary( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate an executive summary using LLM based on research findings. + + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for generation + langfuse_project_name: Name of the Langfuse project for tracking + + Returns: + HTML formatted executive summary + """ + logger.info("Generating executive summary using LLM") + + # Prepare the context with all research findings + context = f"Main Research Query: {state.main_query}\n\n" + + # Add synthesized findings for each sub-question + for i, sub_question in enumerate(state.sub_questions, 1): + info = state.enhanced_info.get( + sub_question + ) or state.synthesized_info.get(sub_question) + if info: + context += f"Sub-question {i}: {sub_question}\n" + context += f"Answer Summary: {info.synthesized_answer[:500]}...\n" + context += f"Confidence: {info.confidence_level}\n" + context += f"Key Sources: {', '.join(info.key_sources[:3]) if info.key_sources else 'N/A'}\n\n" + + # Add viewpoint analysis insights if available + if state.viewpoint_analysis: + context += "Key Areas of Agreement:\n" + for agreement in state.viewpoint_analysis.main_points_of_agreement[:3]: + context += f"- {agreement}\n" + context += "\nKey Tensions:\n" + for tension in state.viewpoint_analysis.areas_of_tension[:2]: + context += f"- {tension.topic}\n" + + # Get the executive summary prompt + try: + executive_summary_prompt = prompts_bundle.get_prompt_content( + "executive_summary_prompt" + ) + logger.info("Successfully retrieved executive_summary_prompt") + except Exception as e: + logger.error(f"Failed to get executive_summary_prompt: {e}") + logger.info( + f"Available prompts: {list(prompts_bundle.list_all_prompts().keys())}" + ) + return generate_fallback_executive_summary(state) + + try: + # Call LLM to generate executive summary + result = run_llm_completion( + prompt=context, + system_prompt=executive_summary_prompt, + model=llm_model, + temperature=0.7, + max_tokens=800, + project=langfuse_project_name, + tags=["executive_summary_generation"], + ) + + if result: + content = remove_reasoning_from_output(result) + # Clean up the HTML + content = extract_html_from_content(content) + logger.info("Successfully generated LLM-based executive summary") + return content + else: + logger.warning("Failed to generate executive summary via LLM") + return generate_fallback_executive_summary(state) + + except Exception as e: + logger.error(f"Error generating executive summary: {e}") + return generate_fallback_executive_summary(state) + + +def generate_introduction( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate an introduction using LLM based on research query and sub-questions. 
+ + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for generation + langfuse_project_name: Name of the Langfuse project for tracking + + Returns: + HTML formatted introduction + """ + logger.info("Generating introduction using LLM") + + # Prepare the context + context = f"Main Research Query: {state.main_query}\n\n" + context += "Sub-questions being explored:\n" + for i, sub_question in enumerate(state.sub_questions, 1): + context += f"{i}. {sub_question}\n" + + # Get the introduction prompt + try: + introduction_prompt = prompts_bundle.get_prompt_content( + "introduction_prompt" + ) + logger.info("Successfully retrieved introduction_prompt") + except Exception as e: + logger.error(f"Failed to get introduction_prompt: {e}") + logger.info( + f"Available prompts: {list(prompts_bundle.list_all_prompts().keys())}" + ) + return generate_fallback_introduction(state) + + try: + # Call LLM to generate introduction + result = run_llm_completion( + prompt=context, + system_prompt=introduction_prompt, + model=llm_model, + temperature=0.7, + max_tokens=600, + project=langfuse_project_name, + tags=["introduction_generation"], + ) + + if result: + content = remove_reasoning_from_output(result) + # Clean up the HTML + content = extract_html_from_content(content) + logger.info("Successfully generated LLM-based introduction") + return content + else: + logger.warning("Failed to generate introduction via LLM") + return generate_fallback_introduction(state) + + except Exception as e: + logger.error(f"Error generating introduction: {e}") + return generate_fallback_introduction(state) + + +def generate_fallback_executive_summary(state: ResearchState) -> str: + """Generate a fallback executive summary when LLM fails.""" + summary = f"

<p>This report examines the question: {html.escape(state.main_query)}</p>"
+    summary += f"<p>The research explored {len(state.sub_questions)} key dimensions of this topic, "
+    summary += "synthesizing findings from multiple sources to provide a comprehensive analysis.</p>"
+
+    # Add confidence overview
+    confidence_counts = {"high": 0, "medium": 0, "low": 0}
+    for info in state.enhanced_info.values():
+        level = info.confidence_level.lower()
+        if level in confidence_counts:
+            confidence_counts[level] += 1
+
+    summary += f"<p>Overall confidence in findings: {confidence_counts['high']} high, "
+    summary += f"{confidence_counts['medium']} medium, {confidence_counts['low']} low.</p>"
+
+    return summary
+
+
+def generate_fallback_introduction(state: ResearchState) -> str:
+    """Generate a fallback introduction when LLM fails."""
+    intro = f"<p>This report addresses the research query: {html.escape(state.main_query)}</p>"
+    intro += f"<p>The research was conducted by breaking down the main query into {len(state.sub_questions)} "
+    intro += (
+        "sub-questions to explore different aspects of the topic in depth. "
+    )
+    intro += "Each sub-question was researched independently, with findings synthesized from various sources.</p>

" + return intro + + +def generate_conclusion( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate a comprehensive conclusion using LLM based on all research findings. + + Args: + state: The ResearchState containing all research findings + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for conclusion generation + + Returns: + str: HTML-formatted conclusion content + """ + logger.info("Generating comprehensive conclusion using LLM") + + # Prepare input data for conclusion generation + conclusion_input = { + "main_query": state.main_query, + "sub_questions": state.sub_questions, + "enhanced_info": {}, + } + + # Include enhanced information for each sub-question + for question in state.sub_questions: + if question in state.enhanced_info: + info = state.enhanced_info[question] + conclusion_input["enhanced_info"][question] = { + "synthesized_answer": info.synthesized_answer, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + "key_sources": info.key_sources, + "improvements": getattr(info, "improvements", []), + } + elif question in state.synthesized_info: + # Fallback to synthesized info if enhanced info not available + info = state.synthesized_info[question] + conclusion_input["enhanced_info"][question] = { + "synthesized_answer": info.synthesized_answer, + "confidence_level": info.confidence_level, + "information_gaps": info.information_gaps, + "key_sources": info.key_sources, + "improvements": [], + } + + # Include viewpoint analysis if available + if state.viewpoint_analysis: + conclusion_input["viewpoint_analysis"] = { + "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, + "areas_of_tension": [ + {"topic": tension.topic, "viewpoints": tension.viewpoints} + for tension in state.viewpoint_analysis.areas_of_tension + ], + "perspective_gaps": state.viewpoint_analysis.perspective_gaps, + "integrative_insights": state.viewpoint_analysis.integrative_insights, + } + + # Include reflection metadata if available + if state.reflection_metadata: + conclusion_input["reflection_metadata"] = { + "critique_summary": state.reflection_metadata.critique_summary, + "additional_questions_identified": state.reflection_metadata.additional_questions_identified, + "improvements_made": state.reflection_metadata.improvements_made, + } + + try: + # Get the prompt from the bundle + conclusion_prompt = prompts_bundle.get_prompt_content( + "conclusion_generation_prompt" + ) + + # Generate conclusion using LLM + conclusion_html = run_llm_completion( + prompt=json.dumps(conclusion_input, indent=2), + system_prompt=conclusion_prompt, + model=llm_model, + clean_output=True, + max_tokens=1500, # Sufficient for comprehensive conclusion + project=langfuse_project_name, + ) + + # Clean up any formatting issues + conclusion_html = conclusion_html.strip() + + # Remove any h2 tags with "Conclusion" text that LLM might have added + # Since we already have a Conclusion header in the template + conclusion_html = re.sub( + r"]*>\s*Conclusion\s*\s*", + "", + conclusion_html, + flags=re.IGNORECASE, + ) + conclusion_html = re.sub( + r"]*>\s*Conclusion\s*\s*", + "", + conclusion_html, + flags=re.IGNORECASE, + ) + + # Also remove plain text "Conclusion" at the start if it exists + conclusion_html = re.sub( + r"^Conclusion\s*\n*", + "", + conclusion_html.strip(), + flags=re.IGNORECASE, + ) 
+
+        if not conclusion_html.startswith("<p>"):
+            # Wrap in paragraph tags if not already formatted
+            conclusion_html = f"<p>{conclusion_html}</p>"
+
+        logger.info("Successfully generated LLM-based conclusion")
+        return conclusion_html
+
+    except Exception as e:
+        logger.warning(f"Failed to generate LLM conclusion: {e}")
+        # Return a basic fallback conclusion
+        return f"""<p>This report has explored {html.escape(state.main_query)} through a structured research approach, examining {len(state.sub_questions)} focused sub-questions and synthesizing information from diverse sources. The findings provide a comprehensive understanding of the topic, highlighting key aspects, perspectives, and current knowledge.</p>
+
+<p>While some information gaps remain, as noted in the respective sections, this research provides a solid foundation for understanding the topic and its implications.</p>
""" + + +def generate_report_from_template( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> str: + """Generate a final HTML report from a static template. + + Instead of using an LLM to generate HTML, this function uses predefined HTML + templates and populates them with data from the research state. + + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The model to use for conclusion generation + + Returns: + str: The HTML content of the report + """ + logger.info( + f"Generating templated HTML report for query: {state.main_query}" + ) + + # Generate table of contents for sub-questions + sub_questions_toc = "" + for i, question in enumerate(state.sub_questions, 1): + safe_id = f"question-{i}" + sub_questions_toc += ( + f'
  • {html.escape(question)}
  • \n' + ) + + # Add viewpoint analysis to TOC if available + additional_sections_toc = "" + if state.viewpoint_analysis: + additional_sections_toc += ( + '
  • Viewpoint Analysis
  • \n' + ) + + # Generate HTML for sub-questions + sub_questions_html = "" + all_sources = set() + + for i, question in enumerate(state.sub_questions, 1): + info = state.enhanced_info.get(question, None) + + # Skip if no information is available + if not info: + continue + + # Process confidence level + confidence = info.confidence_level.lower() + confidence_upper = info.confidence_level.upper() + + # Process key sources + key_sources_html = "" + if info.key_sources: + all_sources.update(info.key_sources) + sources_list = "\n".join( + [ + f'
  • {html.escape(source)}
  • ' + if source.startswith(("http://", "https://")) + else f"
  • {html.escape(source)}
  • " + for source in info.key_sources + ] + ) + key_sources_html = f""" +
    +

    📚 Key Sources

    +
      + {sources_list} +
    +
    + """ + + # Process information gaps + info_gaps_html = "" + if info.information_gaps: + info_gaps_html = f""" +
    +

    🧩 Information Gaps

    +

    {format_text_with_code_blocks(info.information_gaps)}

    +
    + """ + + # Determine confidence icon based on level + confidence_icon = "🔴" # Default (low) + if confidence_upper == "HIGH": + confidence_icon = "🟢" + elif confidence_upper == "MEDIUM": + confidence_icon = "🟡" + + # Format the subquestion section using the template + sub_question_html = SUB_QUESTION_TEMPLATE.format( + index=i, + question=html.escape(question), + confidence=confidence, + confidence_upper=confidence_upper, + confidence_icon=confidence_icon, + answer=format_text_with_code_blocks(info.synthesized_answer), + info_gaps_html=info_gaps_html, + key_sources_html=key_sources_html, + ) + + sub_questions_html += sub_question_html + + # Generate viewpoint analysis HTML if available + viewpoint_analysis_html = "" + if state.viewpoint_analysis: + # Format points of agreement + agreements_html = "" + for point in state.viewpoint_analysis.main_points_of_agreement: + agreements_html += f"
  • {html.escape(point)}
  • \n" + + # Format areas of tension + tensions_html = "" + for tension in state.viewpoint_analysis.areas_of_tension: + viewpoints_html = "" + for title, content in tension.viewpoints.items(): + # Create category-specific styling + category_class = f"category-{title.lower()}" + category_title = title.capitalize() + + viewpoints_html += f""" +
    + {category_title} +

    {html.escape(content)}

    +
    + """ + + tensions_html += f""" +
    +

    {html.escape(tension.topic)}

    +
    + {viewpoints_html} +
    +
    + """ + + # Format the viewpoint analysis section using the template + viewpoint_analysis_html = VIEWPOINT_ANALYSIS_TEMPLATE.format( + agreements_html=agreements_html, + tensions_html=tensions_html, + perspective_gaps=format_text_with_code_blocks( + state.viewpoint_analysis.perspective_gaps + ), + integrative_insights=format_text_with_code_blocks( + state.viewpoint_analysis.integrative_insights + ), + ) + + # Generate references HTML + references_html = "
      " + if all_sources: + for source in sorted(all_sources): + if source.startswith(("http://", "https://")): + references_html += f'
    • {html.escape(source)}
    • \n' + else: + references_html += f"
    • {html.escape(source)}
    • \n" + else: + references_html += ( + "
    • No external sources were referenced in this research.
    • " + ) + references_html += "
    " + + # Generate dynamic executive summary using LLM + logger.info("Generating dynamic executive summary...") + executive_summary = generate_executive_summary( + state, prompts_bundle, llm_model, langfuse_project_name + ) + logger.info( + f"Executive summary generated: {len(executive_summary)} characters" + ) + + # Generate dynamic introduction using LLM + logger.info("Generating dynamic introduction...") + introduction_html = generate_introduction( + state, prompts_bundle, llm_model, langfuse_project_name + ) + logger.info(f"Introduction generated: {len(introduction_html)} characters") + + # Generate comprehensive conclusion using LLM + conclusion_html = generate_conclusion( + state, prompts_bundle, llm_model, langfuse_project_name + ) + + # Generate complete HTML report + html_content = STATIC_HTML_TEMPLATE.format( + main_query=html.escape(state.main_query), + sub_questions_toc=sub_questions_toc, + additional_sections_toc=additional_sections_toc, + executive_summary=executive_summary, + introduction_html=introduction_html, + num_sub_questions=len(state.sub_questions), + sub_questions_html=sub_questions_html, + viewpoint_analysis_html=viewpoint_analysis_html, + conclusion_html=conclusion_html, + references_html=references_html, + ) + + return html_content + + +def _generate_fallback_report(state: ResearchState) -> str: + """Generate a minimal fallback report when the main report generation fails. + + This function creates a simplified HTML report with a consistent structure when + the main report generation process encounters an error. The HTML includes: + - A header section with the main research query + - An error notice + - Introduction section + - Individual sections for each sub-question with available answers + - A references section if sources are available + + Args: + state: The current research state containing query and answer information + + Returns: + str: A basic HTML report with a standard research report structure + """ + # Create a simple HTML structure with embedded CSS for styling + html = f""" + + + + + + + +
    +

    Research Report: {state.main_query}

    + +
    +

    Note: This is a fallback report generated due to an error in the report generation process.

    +
    + + +
    +

    Table of Contents

    +
      +
    • Introduction
    • +""" + + # Add TOC entries for each sub-question + for i, sub_question in enumerate(state.sub_questions): + safe_id = f"question-{i + 1}" + html += f'
    • {sub_question}
    • \n' + + html += """
    • References
    • +
    +
    + + +
    +

    Executive Summary

    +

    This report presents findings related to the main research query. It explores multiple aspects of the topic through structured sub-questions and synthesizes information from various sources.

    +
    + +
    +

    Introduction

    +

    This report addresses the research query: "{state.main_query}"

    +

    The analysis is structured around {len(state.sub_questions)} sub-questions that explore different dimensions of this topic.

    +
    +""" + + # Add each sub-question and its synthesized information + for i, sub_question in enumerate(state.sub_questions): + safe_id = f"question-{i + 1}" + info = state.enhanced_info.get(sub_question, None) + + if not info: + # Try to get from synthesized info if not in enhanced info + info = state.synthesized_info.get(sub_question, None) + + if info: + answer = info.synthesized_answer + confidence = info.confidence_level + + # Add appropriate confidence class + confidence_class = "" + if confidence == "high": + confidence_class = "confidence-high" + elif confidence == "medium": + confidence_class = "confidence-medium" + elif confidence == "low": + confidence_class = "confidence-low" + + html += f""" +
    +

    {i + 1}. {sub_question}

    +

    Confidence Level: {confidence.upper()}

    +
    +

    {answer}

    +
    + """ + + # Add information gaps if available + if hasattr(info, "information_gaps") and info.information_gaps: + html += f""" +
    +

    Information Gaps

    +

    {info.information_gaps}

    +
    + """ + + # Add improvements if available + if hasattr(info, "improvements") and info.improvements: + html += """ +
    +

    Improvements Made

    +
      + """ + + for improvement in info.improvements: + html += f"
    • {improvement}
    • \n" + + html += """ +
    +
    + """ + + # Add key sources if available + if hasattr(info, "key_sources") and info.key_sources: + html += """ +
    +

    Key Sources

    +
      + """ + + for source in info.key_sources: + html += f"
    • {source}
    • \n" + + html += """ +
    +
    + """ + + html += """ +
    + """ + else: + html += f""" +
    +

    {i + 1}. {sub_question}

    +

    No information available for this question.

    +
    + """ + + # Add conclusion section + html += """ +
    +

    Conclusion

    +

    This report has explored the research query through multiple sub-questions, providing synthesized information based on available sources. While limitations exist in some areas, the report provides a structured analysis of the topic.

    +
    + """ + + # Add sources if available + sources_set = set() + for info in state.enhanced_info.values(): + if info.key_sources: + sources_set.update(info.key_sources) + + if sources_set: + html += """ +
    +

    References

    +
      + """ + + for source in sorted(sources_set): + html += f"
    • {source}
    • \n" + + html += """ +
    +
    + """ + else: + html += """ +
    +

    References

    +

    No references available.

    +
    + """ + + # Close the HTML structure + html += """ +
    + + + """ + + return html + + +@step( + output_materializers={ + "state": ResearchStateMaterializer, + } +) +def pydantic_final_report_step( + state: ResearchState, + prompts_bundle: PromptsBundle, + use_static_template: bool = True, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + langfuse_project_name: str = "deep-research", +) -> Tuple[ + Annotated[ResearchState, "state"], + Annotated[HTMLString, "report_html"], +]: + """Generate the final research report in HTML format using Pydantic models. + + This step uses the Pydantic models and materializers to generate a final + HTML report and return both the updated state and the HTML report as + separate artifacts. + + Args: + state: The current research state (Pydantic model) + prompts_bundle: Bundle containing all prompts for the pipeline + use_static_template: Whether to use a static template instead of LLM generation + llm_model: The model to use for report generation with provider prefix + + Returns: + A tuple containing the updated research state and the HTML report + """ + start_time = time.time() + logger.info("Generating final research report using Pydantic models") + + if use_static_template: + # Use the static HTML template approach + logger.info("Using static HTML template for report generation") + html_content = generate_report_from_template( + state, prompts_bundle, llm_model, langfuse_project_name + ) + + # Update the state with the final report HTML + state.set_final_report(html_content) + + # Collect metadata about the report + execution_time = time.time() - start_time + + # Count sources + all_sources = set() + for info in state.enhanced_info.values(): + if info.key_sources: + all_sources.update(info.key_sources) + + # Count confidence levels + confidence_distribution = {"high": 0, "medium": 0, "low": 0} + for info in state.enhanced_info.values(): + level = info.confidence_level.lower() + if level in confidence_distribution: + confidence_distribution[level] += 1 + + # Log metadata + log_metadata( + metadata={ + "report_generation": { + "execution_time_seconds": execution_time, + "generation_method": "static_template", + "llm_model": llm_model, + "report_length_chars": len(html_content), + "num_sub_questions": len(state.sub_questions), + "num_sources": len(all_sources), + "has_viewpoint_analysis": bool(state.viewpoint_analysis), + "has_reflection": bool(state.reflection_metadata), + "confidence_distribution": confidence_distribution, + "fallback_report": False, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_quality": { + "confidence_distribution": confidence_distribution, + } + }, + infer_model=True, + ) + + # Log artifact metadata for the HTML report + log_metadata( + metadata={ + "html_report_characteristics": { + "size_bytes": len(html_content.encode("utf-8")), + "has_toc": "toc" in html_content.lower(), + "has_executive_summary": "executive summary" + in html_content.lower(), + "has_conclusion": "conclusion" in html_content.lower(), + "has_references": "references" in html_content.lower(), + } + }, + infer_artifact=True, + artifact_name="report_html", + ) + + logger.info( + "Final research report generated successfully with static template" + ) + return state, HTMLString(html_content) + + # Otherwise use the LLM-generated approach + # Convert Pydantic model to dict for LLM input + report_input = { + "main_query": state.main_query, + "sub_questions": state.sub_questions, + "synthesized_information": state.enhanced_info, + } + + if 
state.viewpoint_analysis: + report_input["viewpoint_analysis"] = state.viewpoint_analysis + + if state.reflection_metadata: + report_input["reflection_metadata"] = state.reflection_metadata + + # Generate the report + try: + logger.info(f"Calling {llm_model} to generate final report") + + # Get the prompt from the bundle + report_prompt = prompts_bundle.get_prompt_content( + "report_generation_prompt" + ) + + # Use the utility function to run LLM completion + html_content = run_llm_completion( + prompt=json.dumps(report_input), + system_prompt=report_prompt, + model=llm_model, + clean_output=False, # Don't clean in case of breaking HTML formatting + max_tokens=4000, # Increased token limit for detailed report generation + project=langfuse_project_name, + ) + + # Clean up any JSON wrapper or other artifacts + html_content = remove_reasoning_from_output(html_content) + + # Process the HTML content to remove code block markers and fix common issues + html_content = clean_html_output(html_content) + + # Basic validation of HTML content + if not html_content.strip().startswith("<"): + logger.warning( + "Generated content does not appear to be valid HTML" + ) + # Try to extract HTML if it might be wrapped in code blocks or JSON + html_content = extract_html_from_content(html_content) + + # Update the state with the final report HTML + state.set_final_report(html_content) + + # Collect metadata about the report + execution_time = time.time() - start_time + + # Count sources + all_sources = set() + for info in state.enhanced_info.values(): + if info.key_sources: + all_sources.update(info.key_sources) + + # Count confidence levels + confidence_distribution = {"high": 0, "medium": 0, "low": 0} + for info in state.enhanced_info.values(): + level = info.confidence_level.lower() + if level in confidence_distribution: + confidence_distribution[level] += 1 + + # Log metadata + log_metadata( + metadata={ + "report_generation": { + "execution_time_seconds": execution_time, + "generation_method": "llm_generated", + "llm_model": llm_model, + "report_length_chars": len(html_content), + "num_sub_questions": len(state.sub_questions), + "num_sources": len(all_sources), + "has_viewpoint_analysis": bool(state.viewpoint_analysis), + "has_reflection": bool(state.reflection_metadata), + "confidence_distribution": confidence_distribution, + "fallback_report": False, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_quality": { + "confidence_distribution": confidence_distribution, + } + }, + infer_model=True, + ) + + logger.info("Final research report generated successfully") + return state, HTMLString(html_content) + + except Exception as e: + logger.error(f"Error generating final report: {e}") + # Generate a minimal fallback report + fallback_html = _generate_fallback_report(state) + + # Process the fallback HTML to ensure it's clean + fallback_html = clean_html_output(fallback_html) + + # Update the state with the fallback report + state.set_final_report(fallback_html) + + # Collect metadata about the fallback report + execution_time = time.time() - start_time + + # Count sources + all_sources = set() + for info in state.enhanced_info.values(): + if info.key_sources: + all_sources.update(info.key_sources) + + # Count confidence levels + confidence_distribution = {"high": 0, "medium": 0, "low": 0} + for info in state.enhanced_info.values(): + level = info.confidence_level.lower() + if level in confidence_distribution: + confidence_distribution[level] += 1 + + # 
Log metadata for fallback report + log_metadata( + metadata={ + "report_generation": { + "execution_time_seconds": execution_time, + "generation_method": "fallback", + "llm_model": llm_model, + "report_length_chars": len(fallback_html), + "num_sub_questions": len(state.sub_questions), + "num_sources": len(all_sources), + "has_viewpoint_analysis": bool(state.viewpoint_analysis), + "has_reflection": bool(state.reflection_metadata), + "confidence_distribution": confidence_distribution, + "fallback_report": True, + "error_message": str(e), + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_quality": { + "confidence_distribution": confidence_distribution, + } + }, + infer_model=True, + ) + + return state, HTMLString(fallback_html) diff --git a/deep_research/steps/query_decomposition_step.py b/deep_research/steps/query_decomposition_step.py new file mode 100644 index 00000000..78e50b9f --- /dev/null +++ b/deep_research/steps/query_decomposition_step.py @@ -0,0 +1,174 @@ +import logging +import time +from typing import Annotated + +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.llm_utils import get_structured_llm_output +from utils.prompt_models import PromptsBundle +from utils.pydantic_models import ResearchState +from zenml import log_metadata, step + +logger = logging.getLogger(__name__) + + +@step(output_materializers=ResearchStateMaterializer) +def initial_query_decomposition_step( + state: ResearchState, + prompts_bundle: PromptsBundle, + llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", + max_sub_questions: int = 8, + langfuse_project_name: str = "deep-research", +) -> Annotated[ResearchState, "updated_state"]: + """Break down a complex research query into specific sub-questions. + + Args: + state: The current research state + prompts_bundle: Bundle containing all prompts for the pipeline + llm_model: The reasoning model to use with provider prefix + max_sub_questions: Maximum number of sub-questions to generate + + Returns: + Updated research state with sub-questions + """ + start_time = time.time() + logger.info(f"Decomposing research query: {state.main_query}") + + # Get the prompt from the bundle + system_prompt = prompts_bundle.get_prompt_content( + "query_decomposition_prompt" + ) + + try: + # Call OpenAI API to decompose the query + updated_system_prompt = ( + system_prompt + + f"\nPlease generate at most {max_sub_questions} sub-questions." 
+ ) + logger.info( + f"Calling {llm_model} to decompose query into max {max_sub_questions} sub-questions" + ) + + # Define fallback questions + fallback_questions = [ + { + "sub_question": f"What is {state.main_query}?", + "reasoning": "Basic understanding of the topic", + }, + { + "sub_question": f"What are the key aspects of {state.main_query}?", + "reasoning": "Exploring important dimensions", + }, + { + "sub_question": f"What are the implications of {state.main_query}?", + "reasoning": "Understanding broader impact", + }, + ] + + # Use utility function to get structured output + decomposed_questions = get_structured_llm_output( + prompt=state.main_query, + system_prompt=updated_system_prompt, + model=llm_model, + fallback_response=fallback_questions, + project=langfuse_project_name, + ) + + # Extract just the sub-questions + sub_questions = [ + item.get("sub_question") + for item in decomposed_questions + if "sub_question" in item + ] + + # Limit to max_sub_questions + sub_questions = sub_questions[:max_sub_questions] + + logger.info(f"Generated {len(sub_questions)} sub-questions") + for i, question in enumerate(sub_questions, 1): + logger.info(f" {i}. {question}") + + # Update the state with the new sub-questions + state.update_sub_questions(sub_questions) + + # Log step metadata + execution_time = time.time() - start_time + log_metadata( + metadata={ + "query_decomposition": { + "execution_time_seconds": execution_time, + "num_sub_questions": len(sub_questions), + "llm_model": llm_model, + "max_sub_questions_requested": max_sub_questions, + "fallback_used": False, + "main_query_length": len(state.main_query), + "sub_questions": sub_questions, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_scope": { + "num_sub_questions": len(sub_questions), + } + }, + infer_model=True, + ) + + # Log artifact metadata for the output state + log_metadata( + metadata={ + "state_characteristics": { + "total_sub_questions": len(state.sub_questions), + "has_search_results": bool(state.search_results), + "has_synthesized_info": bool(state.synthesized_info), + } + }, + infer_artifact=True, + ) + + return state + + except Exception as e: + logger.error(f"Error decomposing query: {e}") + # Return fallback questions in the state + fallback_questions = [ + f"What is {state.main_query}?", + f"What are the key aspects of {state.main_query}?", + f"What are the implications of {state.main_query}?", + ] + fallback_questions = fallback_questions[:max_sub_questions] + logger.info(f"Using {len(fallback_questions)} fallback questions:") + for i, question in enumerate(fallback_questions, 1): + logger.info(f" {i}. 
{question}") + state.update_sub_questions(fallback_questions) + + # Log metadata for fallback scenario + execution_time = time.time() - start_time + log_metadata( + metadata={ + "query_decomposition": { + "execution_time_seconds": execution_time, + "num_sub_questions": len(fallback_questions), + "llm_model": llm_model, + "max_sub_questions_requested": max_sub_questions, + "fallback_used": True, + "error_message": str(e), + "main_query_length": len(state.main_query), + "sub_questions": fallback_questions, + } + } + ) + + # Log model metadata for cross-pipeline tracking + log_metadata( + metadata={ + "research_scope": { + "num_sub_questions": len(fallback_questions), + } + }, + infer_model=True, + ) + + return state diff --git a/deep_research/tests/__init__.py b/deep_research/tests/__init__.py new file mode 100644 index 00000000..6206856b --- /dev/null +++ b/deep_research/tests/__init__.py @@ -0,0 +1 @@ +"""Test package for ZenML Deep Research project.""" diff --git a/deep_research/tests/conftest.py b/deep_research/tests/conftest.py new file mode 100644 index 00000000..b972a5e1 --- /dev/null +++ b/deep_research/tests/conftest.py @@ -0,0 +1,11 @@ +"""Test configuration for pytest. + +This file sets up the proper Python path for importing modules in tests. +""" + +import os +import sys + +# Add the project root directory to the Python path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, project_root) diff --git a/deep_research/tests/test_approval_utils.py b/deep_research/tests/test_approval_utils.py new file mode 100644 index 00000000..fe859b0f --- /dev/null +++ b/deep_research/tests/test_approval_utils.py @@ -0,0 +1,150 @@ +"""Unit tests for approval utility functions.""" + +from utils.approval_utils import ( + calculate_estimated_cost, + format_approval_request, + format_critique_summary, + format_query_list, + parse_approval_response, + summarize_research_progress, +) +from utils.pydantic_models import ResearchState, SynthesizedInfo + + +def test_parse_approval_responses(): + """Test parsing different approval responses.""" + queries = ["query1", "query2", "query3"] + + # Test approve all + decision = parse_approval_response("APPROVE ALL", queries) + assert decision.approved == True + assert decision.selected_queries == queries + assert decision.approval_method == "APPROVE_ALL" + + # Test skip + decision = parse_approval_response( + "skip", queries + ) # Test case insensitive + assert decision.approved == False + assert decision.selected_queries == [] + assert decision.approval_method == "SKIP" + + # Test selection + decision = parse_approval_response("SELECT 1,3", queries) + assert decision.approved == True + assert decision.selected_queries == ["query1", "query3"] + assert decision.approval_method == "SELECT_SPECIFIC" + + # Test invalid selection + decision = parse_approval_response("SELECT invalid", queries) + assert decision.approved == False + assert decision.approval_method == "PARSE_ERROR" + + # Test out of range indices + decision = parse_approval_response("SELECT 1,5,10", queries) + assert decision.approved == True + assert decision.selected_queries == ["query1"] # Only valid indices + assert decision.approval_method == "SELECT_SPECIFIC" + + # Test unknown response + decision = parse_approval_response("maybe later", queries) + assert decision.approved == False + assert decision.approval_method == "UNKNOWN_RESPONSE" + + +def test_format_approval_request(): + """Test formatting of approval request messages.""" + message = 
format_approval_request( + main_query="Test query", + progress_summary={ + "completed_count": 5, + "avg_confidence": 0.75, + "low_confidence_count": 2, + }, + critique_points=[ + {"issue": "Missing data", "importance": "high"}, + {"issue": "Minor gap", "importance": "low"}, + ], + proposed_queries=["query1", "query2"], + ) + + assert "Test query" in message + assert "5" in message + assert "0.75" in message + assert "2 queries" in message + assert "approve" in message.lower() + assert "reject" in message.lower() + assert "Missing data" in message + + +def test_summarize_research_progress(): + """Test research progress summarization.""" + state = ResearchState( + main_query="test", + synthesized_info={ + "q1": SynthesizedInfo( + synthesized_answer="a1", confidence_level="high" + ), + "q2": SynthesizedInfo( + synthesized_answer="a2", confidence_level="medium" + ), + "q3": SynthesizedInfo( + synthesized_answer="a3", confidence_level="low" + ), + "q4": SynthesizedInfo( + synthesized_answer="a4", confidence_level="low" + ), + }, + ) + + summary = summarize_research_progress(state) + + assert summary["completed_count"] == 4 + # (1.0 + 0.5 + 0.0 + 0.0) / 4 = 1.5 / 4 = 0.375, rounded to 0.38 + assert summary["avg_confidence"] == 0.38 + assert summary["low_confidence_count"] == 2 + + +def test_format_critique_summary(): + """Test critique summary formatting.""" + # Test with no critiques + result = format_critique_summary([]) + assert result == "No critical issues identified." + + # Test with few critiques + critiques = [{"issue": "Issue 1"}, {"issue": "Issue 2"}] + result = format_critique_summary(critiques) + assert "- Issue 1" in result + assert "- Issue 2" in result + assert "more issues" not in result + + # Test with many critiques + critiques = [{"issue": f"Issue {i}"} for i in range(5)] + result = format_critique_summary(critiques) + assert "- Issue 0" in result + assert "- Issue 1" in result + assert "- Issue 2" in result + assert "- Issue 3" not in result + assert "... and 2 more issues" in result + + +def test_format_query_list(): + """Test query list formatting.""" + # Test empty list + result = format_query_list([]) + assert result == "No queries proposed." + + # Test with queries + queries = ["Query A", "Query B", "Query C"] + result = format_query_list(queries) + assert "1. Query A" in result + assert "2. Query B" in result + assert "3. 
Query C" in result + + +def test_calculate_estimated_cost(): + """Test cost estimation.""" + assert calculate_estimated_cost([]) == 0.0 + assert calculate_estimated_cost(["q1"]) == 0.01 + assert calculate_estimated_cost(["q1", "q2", "q3"]) == 0.03 + assert calculate_estimated_cost(["q1"] * 10) == 0.10 diff --git a/deep_research/tests/test_prompt_loader.py b/deep_research/tests/test_prompt_loader.py new file mode 100644 index 00000000..fdc7b7da --- /dev/null +++ b/deep_research/tests/test_prompt_loader.py @@ -0,0 +1,108 @@ +"""Unit tests for prompt loader utilities.""" + +import pytest +from utils.prompt_loader import get_prompt_for_step, load_prompts_bundle +from utils.prompt_models import PromptsBundle + + +class TestPromptLoader: + """Test cases for prompt loader functions.""" + + def test_load_prompts_bundle(self): + """Test loading prompts bundle from prompts.py.""" + bundle = load_prompts_bundle(pipeline_version="2.0.0") + + # Check it returns a PromptsBundle + assert isinstance(bundle, PromptsBundle) + + # Check pipeline version + assert bundle.pipeline_version == "2.0.0" + + # Check all core prompts are loaded + assert bundle.search_query_prompt is not None + assert bundle.query_decomposition_prompt is not None + assert bundle.synthesis_prompt is not None + assert bundle.viewpoint_analysis_prompt is not None + assert bundle.reflection_prompt is not None + assert bundle.additional_synthesis_prompt is not None + assert bundle.conclusion_generation_prompt is not None + + # Check prompts have correct metadata + assert bundle.search_query_prompt.name == "search_query_prompt" + assert ( + bundle.search_query_prompt.description + == "Generates effective search queries from sub-questions" + ) + assert bundle.search_query_prompt.version == "1.0.0" + assert "search" in bundle.search_query_prompt.tags + + # Check that actual prompt content is loaded + assert "search query" in bundle.search_query_prompt.content.lower() + assert "json schema" in bundle.search_query_prompt.content.lower() + + def test_load_prompts_bundle_default_version(self): + """Test loading prompts bundle with default version.""" + bundle = load_prompts_bundle() + assert bundle.pipeline_version == "1.0.0" + + def test_get_prompt_for_step(self): + """Test getting prompt content for specific steps.""" + bundle = load_prompts_bundle() + + # Test valid step names + test_cases = [ + ("query_decomposition", "query_decomposition_prompt"), + ("search_query_generation", "search_query_prompt"), + ("synthesis", "synthesis_prompt"), + ("viewpoint_analysis", "viewpoint_analysis_prompt"), + ("reflection", "reflection_prompt"), + ("additional_synthesis", "additional_synthesis_prompt"), + ("conclusion_generation", "conclusion_generation_prompt"), + ] + + for step_name, expected_prompt_attr in test_cases: + content = get_prompt_for_step(bundle, step_name) + expected_content = getattr(bundle, expected_prompt_attr).content + assert content == expected_content + + def test_get_prompt_for_step_invalid(self): + """Test getting prompt for invalid step name.""" + bundle = load_prompts_bundle() + + with pytest.raises( + ValueError, match="No prompt mapping found for step: invalid_step" + ): + get_prompt_for_step(bundle, "invalid_step") + + def test_all_prompts_have_content(self): + """Test that all loaded prompts have non-empty content.""" + bundle = load_prompts_bundle() + + all_prompts = bundle.list_all_prompts() + for name, prompt in all_prompts.items(): + assert prompt.content, f"Prompt {name} has empty content" + assert ( + len(prompt.content) > 
50 + ), f"Prompt {name} content seems too short" + + def test_all_prompts_have_descriptions(self): + """Test that all loaded prompts have descriptions.""" + bundle = load_prompts_bundle() + + all_prompts = bundle.list_all_prompts() + for name, prompt in all_prompts.items(): + assert prompt.description, f"Prompt {name} has no description" + assert ( + len(prompt.description) > 10 + ), f"Prompt {name} description seems too short" + + def test_all_prompts_have_tags(self): + """Test that all loaded prompts have at least one tag.""" + bundle = load_prompts_bundle() + + all_prompts = bundle.list_all_prompts() + for name, prompt in all_prompts.items(): + assert prompt.tags, f"Prompt {name} has no tags" + assert ( + len(prompt.tags) >= 1 + ), f"Prompt {name} should have at least one tag" diff --git a/deep_research/tests/test_prompt_models.py b/deep_research/tests/test_prompt_models.py new file mode 100644 index 00000000..f8b0519f --- /dev/null +++ b/deep_research/tests/test_prompt_models.py @@ -0,0 +1,183 @@ +"""Unit tests for prompt models and utilities.""" + +import pytest +from utils.prompt_models import PromptsBundle, PromptTemplate + + +class TestPromptTemplate: + """Test cases for PromptTemplate model.""" + + def test_prompt_template_creation(self): + """Test creating a prompt template with all fields.""" + prompt = PromptTemplate( + name="test_prompt", + content="This is a test prompt", + description="A test prompt for unit testing", + version="1.0.0", + tags=["test", "unit"], + ) + + assert prompt.name == "test_prompt" + assert prompt.content == "This is a test prompt" + assert prompt.description == "A test prompt for unit testing" + assert prompt.version == "1.0.0" + assert prompt.tags == ["test", "unit"] + + def test_prompt_template_minimal(self): + """Test creating a prompt template with minimal fields.""" + prompt = PromptTemplate( + name="minimal_prompt", content="Minimal content" + ) + + assert prompt.name == "minimal_prompt" + assert prompt.content == "Minimal content" + assert prompt.description == "" + assert prompt.version == "1.0.0" + assert prompt.tags == [] + + +class TestPromptsBundle: + """Test cases for PromptsBundle model.""" + + @pytest.fixture + def sample_prompts(self): + """Create sample prompts for testing.""" + return { + "search_query_prompt": PromptTemplate( + name="search_query_prompt", + content="Search query content", + description="Generates search queries", + tags=["search"], + ), + "query_decomposition_prompt": PromptTemplate( + name="query_decomposition_prompt", + content="Query decomposition content", + description="Decomposes queries", + tags=["analysis"], + ), + "synthesis_prompt": PromptTemplate( + name="synthesis_prompt", + content="Synthesis content", + description="Synthesizes information", + tags=["synthesis"], + ), + "viewpoint_analysis_prompt": PromptTemplate( + name="viewpoint_analysis_prompt", + content="Viewpoint analysis content", + description="Analyzes viewpoints", + tags=["analysis"], + ), + "reflection_prompt": PromptTemplate( + name="reflection_prompt", + content="Reflection content", + description="Reflects on research", + tags=["reflection"], + ), + "additional_synthesis_prompt": PromptTemplate( + name="additional_synthesis_prompt", + content="Additional synthesis content", + description="Additional synthesis", + tags=["synthesis"], + ), + "conclusion_generation_prompt": PromptTemplate( + name="conclusion_generation_prompt", + content="Conclusion generation content", + description="Generates conclusions", + tags=["report"], + ), + } + + 
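+    # Usage sketch (values taken from the fixture above, mirroring how the
+    # pipeline steps call get_prompt_content on the bundle):
+    #     bundle = PromptsBundle(**sample_prompts)
+    #     bundle.get_prompt_content("synthesis_prompt")  # -> "Synthesis content"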
def test_prompts_bundle_creation(self, sample_prompts): + """Test creating a prompts bundle.""" + bundle = PromptsBundle(**sample_prompts) + + assert bundle.search_query_prompt.name == "search_query_prompt" + assert ( + bundle.query_decomposition_prompt.name + == "query_decomposition_prompt" + ) + assert bundle.pipeline_version == "1.0.0" + assert isinstance(bundle.created_at, str) + assert bundle.custom_prompts == {} + + def test_prompts_bundle_with_custom_prompts(self, sample_prompts): + """Test creating a prompts bundle with custom prompts.""" + custom_prompt = PromptTemplate( + name="custom_prompt", + content="Custom prompt content", + description="A custom prompt", + ) + + bundle = PromptsBundle( + **sample_prompts, custom_prompts={"custom_prompt": custom_prompt} + ) + + assert "custom_prompt" in bundle.custom_prompts + assert bundle.custom_prompts["custom_prompt"].name == "custom_prompt" + + def test_get_prompt_by_name(self, sample_prompts): + """Test retrieving prompts by name.""" + bundle = PromptsBundle(**sample_prompts) + + # Test getting a core prompt + prompt = bundle.get_prompt_by_name("search_query_prompt") + assert prompt is not None + assert prompt.name == "search_query_prompt" + + # Test getting a non-existent prompt + prompt = bundle.get_prompt_by_name("non_existent") + assert prompt is None + + def test_get_prompt_by_name_custom(self, sample_prompts): + """Test retrieving custom prompts by name.""" + custom_prompt = PromptTemplate( + name="custom_prompt", content="Custom content" + ) + + bundle = PromptsBundle( + **sample_prompts, custom_prompts={"custom_prompt": custom_prompt} + ) + + prompt = bundle.get_prompt_by_name("custom_prompt") + assert prompt is not None + assert prompt.name == "custom_prompt" + + def test_list_all_prompts(self, sample_prompts): + """Test listing all prompts.""" + bundle = PromptsBundle(**sample_prompts) + + all_prompts = bundle.list_all_prompts() + assert len(all_prompts) == 7 # 7 core prompts + assert "search_query_prompt" in all_prompts + assert "conclusion_generation_prompt" in all_prompts + + def test_list_all_prompts_with_custom(self, sample_prompts): + """Test listing all prompts including custom ones.""" + custom_prompt = PromptTemplate( + name="custom_prompt", content="Custom content" + ) + + bundle = PromptsBundle( + **sample_prompts, custom_prompts={"custom_prompt": custom_prompt} + ) + + all_prompts = bundle.list_all_prompts() + assert len(all_prompts) == 8 # 7 core + 1 custom + assert "custom_prompt" in all_prompts + + def test_get_prompt_content(self, sample_prompts): + """Test getting prompt content by type.""" + bundle = PromptsBundle(**sample_prompts) + + content = bundle.get_prompt_content("search_query_prompt") + assert content == "Search query content" + + content = bundle.get_prompt_content("synthesis_prompt") + assert content == "Synthesis content" + + def test_get_prompt_content_invalid(self, sample_prompts): + """Test getting prompt content with invalid type.""" + bundle = PromptsBundle(**sample_prompts) + + with pytest.raises(AttributeError): + bundle.get_prompt_content("invalid_prompt_type") diff --git a/deep_research/tests/test_pydantic_final_report_step.py b/deep_research/tests/test_pydantic_final_report_step.py new file mode 100644 index 00000000..c0f13530 --- /dev/null +++ b/deep_research/tests/test_pydantic_final_report_step.py @@ -0,0 +1,167 @@ +"""Tests for the Pydantic-based final report step. 
+ +This module contains tests for the Pydantic-based implementation of +final_report_step, which uses the new Pydantic models and materializers. +""" + +from typing import Dict, List + +import pytest +from steps.pydantic_final_report_step import pydantic_final_report_step +from utils.pydantic_models import ( + ReflectionMetadata, + ResearchState, + SearchResult, + SynthesizedInfo, + ViewpointAnalysis, + ViewpointTension, +) +from zenml.types import HTMLString + + +@pytest.fixture +def sample_research_state() -> ResearchState: + """Create a sample research state for testing.""" + # Create a basic research state + state = ResearchState(main_query="What are the impacts of climate change?") + + # Add sub-questions + state.update_sub_questions(["Economic impacts", "Environmental impacts"]) + + # Add search results + search_results: Dict[str, List[SearchResult]] = { + "Economic impacts": [ + SearchResult( + url="https://example.com/economy", + title="Economic Impacts of Climate Change", + snippet="Overview of economic impacts", + content="Detailed content about economic impacts of climate change", + ) + ] + } + state.update_search_results(search_results) + + # Add synthesized info + synthesized_info: Dict[str, SynthesizedInfo] = { + "Economic impacts": SynthesizedInfo( + synthesized_answer="Climate change will have significant economic impacts...", + key_sources=["https://example.com/economy"], + confidence_level="high", + ), + "Environmental impacts": SynthesizedInfo( + synthesized_answer="Environmental impacts include rising sea levels...", + key_sources=["https://example.com/environment"], + confidence_level="high", + ), + } + state.update_synthesized_info(synthesized_info) + + # Add enhanced info (same as synthesized for this test) + state.enhanced_info = state.synthesized_info + + # Add viewpoint analysis + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Climate change is happening", + "Action is needed", + ], + areas_of_tension=[ + ViewpointTension( + topic="Economic policy", + viewpoints={ + "Progressive": "Support carbon taxes and regulations", + "Conservative": "Prefer market-based solutions", + }, + ) + ], + perspective_gaps="Indigenous perspectives are underrepresented", + integrative_insights="A balanced approach combining regulations and market incentives may be most effective", + ) + state.update_viewpoint_analysis(viewpoint_analysis) + + # Add reflection metadata + reflection_metadata = ReflectionMetadata( + critique_summary=["Need more sources for economic impacts"], + additional_questions_identified=[ + "How will climate change affect different regions?" 
+ ], + searches_performed=[ + "economic impacts of climate change", + "regional climate impacts", + ], + improvements_made=2, + ) + state.reflection_metadata = reflection_metadata + + return state + + +def test_pydantic_final_report_step_returns_tuple(): + """Test that the step returns a tuple with state and HTML.""" + # Create a simple state + state = ResearchState(main_query="What is climate change?") + state.update_sub_questions(["What causes climate change?"]) + + # Run the step + result = pydantic_final_report_step(state=state) + + # Assert that result is a tuple with 2 elements + assert isinstance(result, tuple) + assert len(result) == 2 + + # Assert first element is ResearchState + assert isinstance(result[0], ResearchState) + + # Assert second element is HTMLString + assert isinstance(result[1], HTMLString) + + +def test_pydantic_final_report_step_with_complex_state(sample_research_state): + """Test that the step handles a complex state properly.""" + # Run the step with a complex state + result = pydantic_final_report_step(state=sample_research_state) + + # Unpack the results + updated_state, html_report = result + + # Assert state contains final report HTML + assert updated_state.final_report_html != "" + + # Assert HTML report contains key elements + html_str = str(html_report) + assert "Economic impacts" in html_str + assert "Environmental impacts" in html_str + assert "Viewpoint Analysis" in html_str + assert "Progressive" in html_str + assert "Conservative" in html_str + + +def test_pydantic_final_report_step_updates_state(): + """Test that the step properly updates the state.""" + # Create an initial state without a final report + state = ResearchState( + main_query="What is climate change?", + sub_questions=["What causes climate change?"], + synthesized_info={ + "What causes climate change?": SynthesizedInfo( + synthesized_answer="Climate change is caused by greenhouse gases.", + confidence_level="high", + ) + }, + enhanced_info={ + "What causes climate change?": SynthesizedInfo( + synthesized_answer="Climate change is caused by greenhouse gases.", + confidence_level="high", + ) + }, + ) + + # Verify initial state has no report + assert state.final_report_html == "" + + # Run the step + updated_state, _ = pydantic_final_report_step(state=state) + + # Verify state was updated with a report + assert updated_state.final_report_html != "" + assert "climate change" in updated_state.final_report_html.lower() diff --git a/deep_research/tests/test_pydantic_materializer.py b/deep_research/tests/test_pydantic_materializer.py new file mode 100644 index 00000000..49cb17f8 --- /dev/null +++ b/deep_research/tests/test_pydantic_materializer.py @@ -0,0 +1,161 @@ +"""Tests for Pydantic-based materializer. + +This module contains tests for the Pydantic-based implementation of +ResearchStateMaterializer, verifying that it correctly serializes and +visualizes ResearchState objects. 
+""" + +import os +import tempfile +from typing import Dict, List + +import pytest +from materializers.pydantic_materializer import ResearchStateMaterializer +from utils.pydantic_models import ( + ResearchState, + SearchResult, + SynthesizedInfo, + ViewpointAnalysis, + ViewpointTension, +) + + +@pytest.fixture +def sample_state() -> ResearchState: + """Create a sample research state for testing.""" + # Create a basic research state + state = ResearchState(main_query="What are the impacts of climate change?") + + # Add sub-questions + state.update_sub_questions(["Economic impacts", "Environmental impacts"]) + + # Add search results + search_results: Dict[str, List[SearchResult]] = { + "Economic impacts": [ + SearchResult( + url="https://example.com/economy", + title="Economic Impacts of Climate Change", + snippet="Overview of economic impacts", + content="Detailed content about economic impacts of climate change", + ) + ] + } + state.update_search_results(search_results) + + # Add synthesized info + synthesized_info: Dict[str, SynthesizedInfo] = { + "Economic impacts": SynthesizedInfo( + synthesized_answer="Climate change will have significant economic impacts...", + key_sources=["https://example.com/economy"], + confidence_level="high", + ) + } + state.update_synthesized_info(synthesized_info) + + return state + + +def test_materializer_initialization(): + """Test that the materializer can be initialized.""" + # Create a temporary directory for artifact storage + with tempfile.TemporaryDirectory() as tmpdirname: + materializer = ResearchStateMaterializer(uri=tmpdirname) + assert materializer is not None + + +def test_materializer_save_and_load(sample_state: ResearchState): + """Test saving and loading a state using the materializer.""" + # Create a temporary directory for artifact storage + with tempfile.TemporaryDirectory() as tmpdirname: + # Initialize materializer with temporary artifact URI + materializer = ResearchStateMaterializer(uri=tmpdirname) + + # Save the state + materializer.save(sample_state) + + # Load the state + loaded_state = materializer.load(ResearchState) + + # Verify that the loaded state matches the original + assert loaded_state.main_query == sample_state.main_query + assert loaded_state.sub_questions == sample_state.sub_questions + assert len(loaded_state.search_results) == len( + sample_state.search_results + ) + assert ( + loaded_state.get_current_stage() + == sample_state.get_current_stage() + ) + + # Check that key fields were preserved + question = "Economic impacts" + assert ( + loaded_state.synthesized_info[question].synthesized_answer + == sample_state.synthesized_info[question].synthesized_answer + ) + assert ( + loaded_state.synthesized_info[question].confidence_level + == sample_state.synthesized_info[question].confidence_level + ) + + +def test_materializer_save_visualizations(sample_state: ResearchState): + """Test generating and saving visualizations.""" + # Create a temporary directory for artifact storage + with tempfile.TemporaryDirectory() as tmpdirname: + # Initialize materializer with temporary artifact URI + materializer = ResearchStateMaterializer(uri=tmpdirname) + + # Generate and save visualizations + viz_paths = materializer.save_visualizations(sample_state) + + # Verify visualization file exists + html_path = list(viz_paths.keys())[0] + assert os.path.exists(html_path) + + # Verify the file has content + with open(html_path, "r") as f: + content = f.read() + # Check for expected elements in the HTML + assert "Research State" in content + 
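+            # The generated HTML should also surface the main query text and each
+            # sub-question section, not just the page title.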
assert sample_state.main_query in content + assert "Economic impacts" in content + + +def test_html_generation_stages(sample_state: ResearchState): + """Test that HTML visualization reflects the correct research stage.""" + # Create the materializer + with tempfile.TemporaryDirectory() as tmpdirname: + materializer = ResearchStateMaterializer(uri=tmpdirname) + + # Generate visualization at initial state + html = materializer._generate_visualization_html(sample_state) + # Verify stage by checking for expected elements in the HTML + assert ( + "Synthesized Information" in html + ) # Should show synthesized info + + # Add viewpoint analysis + state_with_viewpoints = sample_state.model_copy(deep=True) + viewpoint_analysis = ViewpointAnalysis( + main_points_of_agreement=["There will be economic impacts"], + areas_of_tension=[ + ViewpointTension( + topic="Job impacts", + viewpoints={ + "Positive": "New green jobs", + "Negative": "Job losses", + }, + ) + ], + ) + state_with_viewpoints.update_viewpoint_analysis(viewpoint_analysis) + html = materializer._generate_visualization_html(state_with_viewpoints) + assert "Viewpoint Analysis" in html + assert "Points of Agreement" in html + + # Add final report + state_with_report = state_with_viewpoints.model_copy(deep=True) + state_with_report.set_final_report("Final report content") + html = materializer._generate_visualization_html(state_with_report) + assert "Final Report" in html diff --git a/deep_research/tests/test_pydantic_models.py b/deep_research/tests/test_pydantic_models.py new file mode 100644 index 00000000..b900f8d8 --- /dev/null +++ b/deep_research/tests/test_pydantic_models.py @@ -0,0 +1,303 @@ +"""Tests for Pydantic model implementations. + +This module contains tests for the Pydantic models that validate: +1. Basic model instantiation +2. Default values +3. Serialization and deserialization +4. 
Method functionality +""" + +import json +from typing import Dict, List + +from utils.pydantic_models import ( + ReflectionMetadata, + ResearchState, + SearchResult, + SynthesizedInfo, + ViewpointAnalysis, + ViewpointTension, +) + + +def test_search_result_creation(): + """Test creating a SearchResult model.""" + # Create with defaults + result = SearchResult() + assert result.url == "" + assert result.content == "" + assert result.title == "" + assert result.snippet == "" + + # Create with values + result = SearchResult( + url="https://example.com", + content="Example content", + title="Example Title", + snippet="This is a snippet", + ) + assert result.url == "https://example.com" + assert result.content == "Example content" + assert result.title == "Example Title" + assert result.snippet == "This is a snippet" + + +def test_search_result_serialization(): + """Test serializing and deserializing a SearchResult.""" + result = SearchResult( + url="https://example.com", + content="Example content", + title="Example Title", + snippet="This is a snippet", + ) + + # Serialize to dict + result_dict = result.model_dump() + assert result_dict["url"] == "https://example.com" + assert result_dict["content"] == "Example content" + + # Serialize to JSON + result_json = result.model_dump_json() + result_dict_from_json = json.loads(result_json) + assert result_dict_from_json["url"] == "https://example.com" + + # Deserialize from dict + new_result = SearchResult.model_validate(result_dict) + assert new_result.url == "https://example.com" + assert new_result.content == "Example content" + + # Deserialize from JSON + new_result_from_json = SearchResult.model_validate_json(result_json) + assert new_result_from_json.url == "https://example.com" + + +def test_viewpoint_tension_model(): + """Test the ViewpointTension model.""" + # Empty model + tension = ViewpointTension() + assert tension.topic == "" + assert tension.viewpoints == {} + + # With data + tension = ViewpointTension( + topic="Climate Change Impacts", + viewpoints={ + "Economic": "Focuses on financial costs and benefits", + "Environmental": "Emphasizes ecosystem impacts", + }, + ) + assert tension.topic == "Climate Change Impacts" + assert len(tension.viewpoints) == 2 + assert "Economic" in tension.viewpoints + + # Serialization + tension_dict = tension.model_dump() + assert tension_dict["topic"] == "Climate Change Impacts" + assert len(tension_dict["viewpoints"]) == 2 + + # Deserialization + new_tension = ViewpointTension.model_validate(tension_dict) + assert new_tension.topic == tension.topic + assert new_tension.viewpoints == tension.viewpoints + + +def test_synthesized_info_model(): + """Test the SynthesizedInfo model.""" + # Default values + info = SynthesizedInfo() + assert info.synthesized_answer == "" + assert info.key_sources == [] + assert info.confidence_level == "medium" + assert info.information_gaps == "" + assert info.improvements == [] + + # With values + info = SynthesizedInfo( + synthesized_answer="This is a synthesized answer", + key_sources=["https://source1.com", "https://source2.com"], + confidence_level="high", + information_gaps="Missing some context", + improvements=["Add more detail", "Check more sources"], + ) + assert info.synthesized_answer == "This is a synthesized answer" + assert len(info.key_sources) == 2 + assert info.confidence_level == "high" + + # Serialization and deserialization + info_dict = info.model_dump() + new_info = SynthesizedInfo.model_validate(info_dict) + assert new_info.synthesized_answer == 
info.synthesized_answer + assert new_info.key_sources == info.key_sources + + +def test_viewpoint_analysis_model(): + """Test the ViewpointAnalysis model.""" + # Create tensions for the analysis + tension1 = ViewpointTension( + topic="Economic Impact", + viewpoints={ + "Positive": "Creates jobs", + "Negative": "Increases inequality", + }, + ) + tension2 = ViewpointTension( + topic="Environmental Impact", + viewpoints={ + "Positive": "Reduces emissions", + "Negative": "Land use changes", + }, + ) + + # Create the analysis + analysis = ViewpointAnalysis( + main_points_of_agreement=[ + "Need for action", + "Technological innovation", + ], + areas_of_tension=[tension1, tension2], + perspective_gaps="Missing indigenous perspectives", + integrative_insights="Combined economic and environmental approach needed", + ) + + assert len(analysis.main_points_of_agreement) == 2 + assert len(analysis.areas_of_tension) == 2 + assert analysis.areas_of_tension[0].topic == "Economic Impact" + + # Test serialization + analysis_dict = analysis.model_dump() + assert len(analysis_dict["areas_of_tension"]) == 2 + assert analysis_dict["areas_of_tension"][0]["topic"] == "Economic Impact" + + # Test deserialization + new_analysis = ViewpointAnalysis.model_validate(analysis_dict) + assert len(new_analysis.areas_of_tension) == 2 + assert new_analysis.areas_of_tension[0].topic == "Economic Impact" + assert new_analysis.perspective_gaps == analysis.perspective_gaps + + +def test_reflection_metadata_model(): + """Test the ReflectionMetadata model.""" + metadata = ReflectionMetadata( + critique_summary=["Need more sources", "Missing detailed analysis"], + additional_questions_identified=["What about future trends?"], + searches_performed=["future climate trends", "economic impacts"], + improvements_made=3, + error=None, + ) + + assert len(metadata.critique_summary) == 2 + assert len(metadata.additional_questions_identified) == 1 + assert metadata.improvements_made == 3 + assert metadata.error is None + + # Serialization + metadata_dict = metadata.model_dump() + assert len(metadata_dict["critique_summary"]) == 2 + assert metadata_dict["improvements_made"] == 3 + + # Deserialization + new_metadata = ReflectionMetadata.model_validate(metadata_dict) + assert new_metadata.improvements_made == metadata.improvements_made + assert new_metadata.critique_summary == metadata.critique_summary + + +def test_research_state_model(): + """Test the main ResearchState model.""" + # Create with defaults + state = ResearchState() + assert state.main_query == "" + assert state.sub_questions == [] + assert state.search_results == {} + assert state.get_current_stage() == "empty" + + # Set main query + state.main_query = "What are the impacts of climate change?" 
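+    # get_current_stage() is assumed to infer the stage from which fields are
+    # populated, so setting only the main query should move the state from
+    # "empty" to "initial".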
+ assert state.get_current_stage() == "initial" + + # Test update methods + state.update_sub_questions( + ["What are economic impacts?", "What are environmental impacts?"] + ) + assert len(state.sub_questions) == 2 + assert state.get_current_stage() == "after_query_decomposition" + + # Add search results + search_results: Dict[str, List[SearchResult]] = { + "What are economic impacts?": [ + SearchResult( + url="https://example.com/economy", + title="Economic Impacts", + snippet="Overview of economic impacts", + content="Detailed content about economic impacts", + ) + ] + } + state.update_search_results(search_results) + assert state.get_current_stage() == "after_search" + assert len(state.search_results["What are economic impacts?"]) == 1 + + # Add synthesized info + synthesized_info: Dict[str, SynthesizedInfo] = { + "What are economic impacts?": SynthesizedInfo( + synthesized_answer="Economic impacts include job losses and growth opportunities", + key_sources=["https://example.com/economy"], + confidence_level="high", + ) + } + state.update_synthesized_info(synthesized_info) + assert state.get_current_stage() == "after_synthesis" + + # Add viewpoint analysis + analysis = ViewpointAnalysis( + main_points_of_agreement=["Economic changes are happening"], + areas_of_tension=[ + ViewpointTension( + topic="Job impacts", + viewpoints={ + "Positive": "New green jobs", + "Negative": "Fossil fuel job losses", + }, + ) + ], + ) + state.update_viewpoint_analysis(analysis) + assert state.get_current_stage() == "after_viewpoint_analysis" + + # Add reflection results + enhanced_info = { + "What are economic impacts?": SynthesizedInfo( + synthesized_answer="Enhanced answer with more details", + key_sources=[ + "https://example.com/economy", + "https://example.com/new-source", + ], + confidence_level="high", + improvements=["Added more context", "Added more sources"], + ) + } + metadata = ReflectionMetadata( + critique_summary=["Needed more sources"], + improvements_made=2, + ) + state.update_after_reflection(enhanced_info, metadata) + assert state.get_current_stage() == "after_reflection" + + # Set final report + state.set_final_report("Final report content") + assert state.get_current_stage() == "final_report" + assert state.final_report_html == "Final report content" + + # Test serialization and deserialization + state_dict = state.model_dump() + new_state = ResearchState.model_validate(state_dict) + + # Verify key properties were preserved + assert new_state.main_query == state.main_query + assert len(new_state.sub_questions) == len(state.sub_questions) + assert new_state.get_current_stage() == state.get_current_stage() + assert new_state.viewpoint_analysis is not None + assert len(new_state.viewpoint_analysis.areas_of_tension) == 1 + assert ( + new_state.viewpoint_analysis.areas_of_tension[0].topic == "Job impacts" + ) + assert new_state.final_report_html == state.final_report_html diff --git a/deep_research/utils/__init__.py b/deep_research/utils/__init__.py new file mode 100644 index 00000000..395e1d67 --- /dev/null +++ b/deep_research/utils/__init__.py @@ -0,0 +1,7 @@ +""" +Utilities package for the ZenML Deep Research project. + +This package contains various utility functions and helpers used throughout the project, +including data models, LLM interaction utilities, search functionality, and common helper +functions for text processing and state management. 
+""" diff --git a/deep_research/utils/approval_utils.py b/deep_research/utils/approval_utils.py new file mode 100644 index 00000000..56a8ff91 --- /dev/null +++ b/deep_research/utils/approval_utils.py @@ -0,0 +1,159 @@ +"""Utility functions for the human approval process.""" + +from typing import Any, Dict, List + +from utils.pydantic_models import ApprovalDecision + + +def summarize_research_progress(state) -> Dict[str, Any]: + """Summarize the current research progress.""" + completed_count = len(state.synthesized_info) + confidence_levels = [ + info.confidence_level for info in state.synthesized_info.values() + ] + + # Calculate average confidence (high=1.0, medium=0.5, low=0.0) + confidence_map = {"high": 1.0, "medium": 0.5, "low": 0.0} + avg_confidence = sum( + confidence_map.get(c, 0.5) for c in confidence_levels + ) / max(len(confidence_levels), 1) + + low_confidence_count = sum(1 for c in confidence_levels if c == "low") + + return { + "completed_count": completed_count, + "avg_confidence": round(avg_confidence, 2), + "low_confidence_count": low_confidence_count, + } + + +def format_critique_summary(critique_points: List[Dict[str, Any]]) -> str: + """Format critique points for display.""" + if not critique_points: + return "No critical issues identified." + + formatted = [] + for point in critique_points[:3]: # Show top 3 + issue = point.get("issue", "Unknown issue") + formatted.append(f"- {issue}") + + if len(critique_points) > 3: + formatted.append(f"- ... and {len(critique_points) - 3} more issues") + + return "\n".join(formatted) + + +def format_query_list(queries: List[str]) -> str: + """Format query list for display.""" + if not queries: + return "No queries proposed." + + formatted = [] + for i, query in enumerate(queries, 1): + formatted.append(f"{i}. 
{query}") + + return "\n".join(formatted) + + +def calculate_estimated_cost(queries: List[str]) -> float: + """Calculate estimated cost for additional queries.""" + # Rough estimate: ~$0.01 per query (including search API + LLM costs) + return round(len(queries) * 0.01, 2) + + +def format_approval_request( + main_query: str, + progress_summary: Dict[str, Any], + critique_points: List[Dict[str, Any]], + proposed_queries: List[str], + timeout: int = 3600, +) -> str: + """Format the approval request message.""" + + # High-priority critiques + high_priority = [ + c for c in critique_points if c.get("importance") == "high" + ] + + message = f"""📊 **Research Progress Update** + +**Main Query:** {main_query} + +**Current Status:** +- Sub-questions analyzed: {progress_summary["completed_count"]} +- Average confidence: {progress_summary["avg_confidence"]} +- Low confidence areas: {progress_summary["low_confidence_count"]} + +**Key Issues Identified:** +{format_critique_summary(high_priority or critique_points)} + +**Proposed Additional Research** ({len(proposed_queries)} queries): +{format_query_list(proposed_queries)} + +**Estimated Additional Time:** ~{len(proposed_queries) * 2} minutes +**Estimated Additional Cost:** ~${calculate_estimated_cost(proposed_queries)} + +**Response Options:** +- Reply with `approve`, `yes`, `ok`, or `LGTM` to proceed with all queries +- Reply with `reject`, `no`, `skip`, or `decline` to finish with current findings + +**Timeout:** Response required within {timeout // 60} minutes""" + + return message + + +def parse_approval_response( + response: str, proposed_queries: List[str] +) -> ApprovalDecision: + """Parse the approval response from user.""" + + response_upper = response.strip().upper() + + if response_upper == "APPROVE ALL": + return ApprovalDecision( + approved=True, + selected_queries=proposed_queries, + approval_method="APPROVE_ALL", + reviewer_notes=response, + ) + + elif response_upper == "SKIP": + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="SKIP", + reviewer_notes=response, + ) + + elif response_upper.startswith("SELECT"): + # Parse selection like "SELECT 1,3,5" + try: + # Extract the part after "SELECT" + selection_part = response_upper[6:].strip() + indices = [int(x.strip()) - 1 for x in selection_part.split(",")] + selected = [ + proposed_queries[i] + for i in indices + if 0 <= i < len(proposed_queries) + ] + return ApprovalDecision( + approved=True, + selected_queries=selected, + approval_method="SELECT_SPECIFIC", + reviewer_notes=response, + ) + except Exception as e: + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="PARSE_ERROR", + reviewer_notes=f"Failed to parse: {response} - {str(e)}", + ) + + else: + return ApprovalDecision( + approved=False, + selected_queries=[], + approval_method="UNKNOWN_RESPONSE", + reviewer_notes=f"Unknown response: {response}", + ) diff --git a/deep_research/utils/helper_functions.py b/deep_research/utils/helper_functions.py new file mode 100644 index 00000000..256cdd28 --- /dev/null +++ b/deep_research/utils/helper_functions.py @@ -0,0 +1,192 @@ +import json +import logging +import os +from json.decoder import JSONDecodeError +from typing import Any, Dict, Optional + +import yaml + +logger = logging.getLogger(__name__) + + +def remove_reasoning_from_output(output: str) -> str: + """Remove the reasoning portion from LLM output. 
+
+    Args:
+        output: Raw output from LLM that may contain reasoning
+
+    Returns:
+        Cleaned output without the reasoning section
+    """
+    if not output:
+        return ""
+
+    if "</think>" in output:
+        return output.split("</think>")[-1].strip()
+    return output.strip()
+
+
+def clean_json_tags(text: str) -> str:
+    """Clean JSON markdown tags from text.
+
+    Args:
+        text: Text with potential JSON markdown tags
+
+    Returns:
+        Cleaned text without JSON markdown tags
+    """
+    if not text:
+        return ""
+
+    cleaned = text.replace("```json\n", "").replace("\n```", "")
+    cleaned = cleaned.replace("```json", "").replace("```", "")
+    return cleaned
+
+
+def clean_markdown_tags(text: str) -> str:
+    """Clean Markdown tags from text.
+
+    Args:
+        text: Text with potential markdown tags
+
+    Returns:
+        Cleaned text without markdown tags
+    """
+    if not text:
+        return ""
+
+    cleaned = text.replace("```markdown\n", "").replace("\n```", "")
+    cleaned = cleaned.replace("```markdown", "").replace("```", "")
+    return cleaned
+
+
+def extract_html_from_content(content: str) -> str:
+    """Attempt to extract HTML content from a response that might be wrapped in other formats.
+
+    Args:
+        content: The content to extract HTML from
+
+    Returns:
+        The extracted HTML, or a basic fallback if extraction fails
+    """
+    if not content:
+        return ""
+
+    # Try to find HTML between <html> tags
+    if "<html" in content and "</html>" in content:
+        start = content.find("<html")
+        end = content.find("</html>") + 7  # Include the closing tag
+        return content[start:end]
+
+    # Try to find div class="research-report"
+    if '<div class="research-report">' in content:
+        start = content.find('<div class="research-report">')
+        last_div = content.rfind("</div>
    ") + if last_div > start: + return content[start : last_div + 6] # Include the closing tag + + # Look for code blocks + if "```html" in content and "```" in content: + start = content.find("```html") + 7 + end = content.find("```", start) + if end > start: + return content[start:end].strip() + + # Look for JSON with an "html" field + try: + parsed = json.loads(content) + if isinstance(parsed, dict) and "html" in parsed: + return parsed["html"] + except: + pass + + # If all extraction attempts fail, return the original content + return content + + +def safe_json_loads(json_str: Optional[str]) -> Dict[str, Any]: + """Safely parse JSON string. + + Args: + json_str: JSON string to parse, can be None. + + Returns: + Dict[str, Any]: Parsed JSON as dictionary or empty dict if parsing fails or input is None. + """ + if json_str is None: + # Optionally, log a warning here if None input is unexpected for certain call sites + # logger.warning("safe_json_loads received None input.") + return {} + try: + return json.loads(json_str) + except ( + JSONDecodeError, + TypeError, + ): # Catch TypeError if json_str is not a valid type for json.loads + # Optionally, log the error and the problematic string (or its beginning) + # logger.warning(f"Failed to decode JSON string: '{str(json_str)[:200]}...'", exc_info=True) + return {} + + +def load_pipeline_config(config_path: str) -> Dict[str, Any]: + """Load pipeline configuration from YAML file. + + This is used only for pipeline-level configuration, not for step parameters. + Step parameters should be defined directly in the step functions. + + Args: + config_path: Path to the configuration YAML file + + Returns: + Pipeline configuration dictionary + """ + # Get absolute path if relative + if not os.path.isabs(config_path): + base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + config_path = os.path.join(base_dir, config_path) + + # Load YAML configuration + try: + with open(config_path, "r") as f: + config = yaml.safe_load(f) + return config + except Exception as e: + logger.error(f"Error loading pipeline configuration: {e}") + # Return a minimal default configuration in case of loading error + return { + "pipeline": { + "name": "deep_research_pipeline", + "enable_cache": True, + }, + "environment": { + "docker": { + "requirements": [ + "openai>=1.0.0", + "tavily-python>=0.2.8", + "PyYAML>=6.0", + "click>=8.0.0", + "pydantic>=2.0.0", + "typing_extensions>=4.0.0", + ] + } + }, + "resources": {"cpu": 1, "memory": "4Gi"}, + "timeout": 3600, + } + + +def check_required_env_vars(env_vars: list[str]) -> list[str]: + """Check if required environment variables are set. 
+ + Args: + env_vars: List of environment variable names to check + + Returns: + List of missing environment variables + """ + missing_vars = [] + for var in env_vars: + if not os.environ.get(var): + missing_vars.append(var) + return missing_vars diff --git a/deep_research/utils/llm_utils.py b/deep_research/utils/llm_utils.py new file mode 100644 index 00000000..1f4dc194 --- /dev/null +++ b/deep_research/utils/llm_utils.py @@ -0,0 +1,387 @@ +import contextlib +import json +import logging +from typing import Any, Dict, List, Optional + +import litellm +from litellm import completion +from utils.helper_functions import ( + clean_json_tags, + remove_reasoning_from_output, + safe_json_loads, +) +from utils.prompts import SYNTHESIS_PROMPT +from zenml import get_step_context + +logger = logging.getLogger(__name__) + +# This module uses litellm for all LLM interactions +# Models are specified with a provider prefix (e.g., "sambanova/DeepSeek-R1-Distill-Llama-70B") +# ALL model names require a provider prefix (e.g., "sambanova/", "openai/", "anthropic/") + +litellm.callbacks = ["langfuse"] + + +def run_llm_completion( + prompt: str, + system_prompt: str, + model: str = "sambanova/Llama-4-Maverick-17B-128E-Instruct", + clean_output: bool = True, + max_tokens: int = 2000, # Increased default token limit + temperature: float = 0.2, + top_p: float = 0.9, + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> str: + """Run an LLM completion with standard error handling and output cleaning. + + Uses litellm for model inference. + + Args: + prompt: User prompt for the LLM + system_prompt: System prompt for the LLM + model: Model to use for completion (with provider prefix) + clean_output: Whether to clean reasoning and JSON tags from output. When True, + this removes any reasoning sections marked with tags and strips JSON + code block markers. + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling value + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. If provided, also converted to trace_metadata format. + + Returns: + str: Processed LLM output with optional cleaning applied + """ + try: + # Ensure model name has provider prefix + if not any( + model.startswith(prefix + "/") + for prefix in [ + "sambanova", + "openai", + "anthropic", + "meta", + "google", + "aws", + "openrouter", + ] + ): + # Raise an error if no provider prefix is specified + error_msg = f"Model '{model}' does not have a provider prefix. 
Please specify provider (e.g., 'sambanova/{model}')" + logger.error(error_msg) + raise ValueError(error_msg) + + # Get pipeline run name and id for trace_name and trace_id if running in a step + trace_name = None + trace_id = None + with contextlib.suppress(RuntimeError): + context = get_step_context() + trace_name = context.pipeline_run.name + trace_id = str(context.pipeline_run.id) + # Build metadata dict + metadata = {"project": project} + if tags is not None: + metadata["tags"] = tags + # Convert tags to trace_metadata format + metadata["trace_metadata"] = {tag: True for tag in tags} + if trace_name: + metadata["trace_name"] = trace_name + if trace_id: + metadata["trace_id"] = trace_id + + response = completion( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + metadata=metadata, + ) + + # Defensive access to content + content = None + if response and response.choices and len(response.choices) > 0: + choice = response.choices[0] + if choice and choice.message: + content = choice.message.content + + if content is None: + logger.warning("LLM response content is missing or empty.") + return "" + + if clean_output: + content = remove_reasoning_from_output(content) + content = clean_json_tags(content) + + return content + except Exception as e: + logger.error(f"Error in LLM completion: {e}") + return "" + + +def get_structured_llm_output( + prompt: str, + system_prompt: str, + model: str = "sambanova/Llama-4-Maverick-17B-128E-Instruct", + fallback_response: Optional[Dict[str, Any]] = None, + max_tokens: int = 2000, # Increased default token limit for structured outputs + temperature: float = 0.2, + top_p: float = 0.9, + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> Dict[str, Any]: + """Get structured JSON output from an LLM with error handling. + + Uses litellm for model inference. + + Args: + prompt: User prompt for the LLM + system_prompt: System prompt for the LLM + model: Model to use for completion (with provider prefix) + fallback_response: Fallback response if parsing fails + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + top_p: Top-p sampling value + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. Defaults to ["structured_llm_output"] if None. + + Returns: + Parsed JSON response or fallback + """ + try: + # Use provided tags or default to ["structured_llm_output"] + if tags is None: + tags = ["structured_llm_output"] + + content = run_llm_completion( + prompt=prompt, + system_prompt=system_prompt, + model=model, + clean_output=True, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + project=project, + tags=tags, + ) + + if not content: + logger.warning("Empty content returned from LLM") + return fallback_response if fallback_response is not None else {} + + result = safe_json_loads(content) + + if not result and fallback_response is not None: + return fallback_response + + return result + except Exception as e: + logger.error(f"Error processing structured LLM output: {e}") + return fallback_response if fallback_response is not None else {} + + +def is_text_relevant(text1: str, text2: str, min_word_length: int = 4) -> bool: + """Determine if two pieces of text are relevant to each other. 
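+
+    Example (illustrative):
+        is_text_relevant("climate change policy", "policy responses to climate change")
+        # -> True, because the texts share significant words such as "climate" and "policy"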
+ + Relevance is determined by checking if one text is contained within the other, + or if they share significant words (words longer than min_word_length). + This is a simple heuristic approach that checks for: + 1. Complete containment (one text string inside the other) + 2. Shared significant words (words longer than min_word_length) + + Args: + text1: First text to compare + text2: Second text to compare + min_word_length: Minimum length of words to check for shared content + + Returns: + bool: True if the texts are deemed relevant to each other based on the criteria + """ + if not text1 or not text2: + return False + + return ( + text1.lower() in text2.lower() + or text2.lower() in text1.lower() + or any( + word + for word in text1.lower().split() + if len(word) > min_word_length and word in text2.lower() + ) + ) + + +def find_most_relevant_string( + target: str, + options: List[str], + model: Optional[str] = "sambanova/Llama-4-Maverick-17B-128E-Instruct", + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> Optional[str]: + """Find the most relevant string from a list of options using simple text matching. + + If model is provided, uses litellm to determine relevance. + + Args: + target: The target string to find relevance for + options: List of string options to check against + model: Model to use for matching (with provider prefix) + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. Defaults to ["find_most_relevant_string"] if None. + + Returns: + The most relevant string, or None if no relevant options + """ + if not options: + return None + + if len(options) == 1: + return options[0] + + # If model is provided, use litellm for more accurate matching + if model: + try: + # Ensure model name has provider prefix + if not any( + model.startswith(prefix + "/") + for prefix in [ + "sambanova", + "openai", + "anthropic", + "meta", + "google", + "aws", + "openrouter", + ] + ): + # Raise an error if no provider prefix is specified + error_msg = f"Model '{model}' does not have a provider prefix. Please specify provider (e.g., 'sambanova/{model}')" + logger.error(error_msg) + raise ValueError(error_msg) + + system_prompt = "You are a research assistant." + prompt = f"""Given the text: "{target}" +Which of the following options is most relevant to this text? 
+{options} + +Respond with only the exact text of the most relevant option.""" + + # Get pipeline run name and id for trace_name and trace_id if running in a step + trace_name = None + trace_id = None + try: + context = get_step_context() + trace_name = context.pipeline_run.name + trace_id = str(context.pipeline_run.id) + except RuntimeError: + # Not running in a step context + pass + + # Use provided tags or default to ["find_most_relevant_string"] + if tags is None: + tags = ["find_most_relevant_string"] + + # Build metadata dict + metadata = {"project": project, "tags": tags} + # Convert tags to trace_metadata format + metadata["trace_metadata"] = {tag: True for tag in tags} + if trace_name: + metadata["trace_name"] = trace_name + if trace_id: + metadata["trace_id"] = trace_id + + response = completion( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + max_tokens=100, + temperature=0.2, + metadata=metadata, + ) + + answer = response.choices[0].message.content.strip() + + # Check if the answer is one of the options + if answer in options: + return answer + + # If not an exact match, find the closest one + for option in options: + if option in answer or answer in option: + return option + + except Exception as e: + logger.error(f"Error finding relevant string with LLM: {e}") + + # Simple relevance check - find exact matches first + for option in options: + if target.lower() == option.lower(): + return option + + # Then check partial matches + for option in options: + if is_text_relevant(target, option): + return option + + # Return the first option as a fallback + return options[0] + + +def synthesize_information( + synthesis_input: Dict[str, Any], + model: str = "sambanova/Llama-4-Maverick-17B-128E-Instruct", + system_prompt: Optional[str] = None, + project: str = "deep-research", + tags: Optional[List[str]] = None, +) -> Dict[str, Any]: + """Synthesize information from search results for a sub-question. + + Uses litellm for model inference. + + Args: + synthesis_input: Dictionary with sub-question, search results, and sources + model: Model to use (with provider prefix) + system_prompt: System prompt for the LLM + project: Langfuse project name for LLM tracking + tags: Optional list of tags for Langfuse tracking. Defaults to ["information_synthesis"] if None. 
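+
+    Example (illustrative; "sub_question" and "sources" are the keys this
+        function reads directly, while "search_results" is an assumed key):
+        synthesize_information({
+            "sub_question": "What are the economic impacts of climate change?",
+            "search_results": ["...raw search content..."],
+            "sources": ["https://example.com/economy"],
+        })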
+ + Returns: + Dictionary with synthesized information + """ + if system_prompt is None: + system_prompt = SYNTHESIS_PROMPT + + sub_question_for_log = synthesis_input.get( + "sub_question", "unknown question" + ) + + # Define the fallback response + fallback_response = { + "synthesized_answer": f"Synthesis failed for '{sub_question_for_log}'.", + "key_sources": synthesis_input.get("sources", [])[:1], + "confidence_level": "low", + "information_gaps": "An error occurred during the synthesis process.", + } + + # Use provided tags or default to ["information_synthesis"] + if tags is None: + tags = ["information_synthesis"] + + # Use the utility function to get structured output + result = get_structured_llm_output( + prompt=json.dumps(synthesis_input), + system_prompt=system_prompt, + model=model, + fallback_response=fallback_response, + max_tokens=3000, # Increased for more detailed synthesis + project=project, + tags=tags, + ) + + return result diff --git a/deep_research/utils/prompt_loader.py b/deep_research/utils/prompt_loader.py new file mode 100644 index 00000000..c2b3d2f7 --- /dev/null +++ b/deep_research/utils/prompt_loader.py @@ -0,0 +1,136 @@ +"""Utility functions for loading prompts into the PromptsBundle model. + +This module provides functions to create PromptsBundle instances from +the existing prompt definitions in prompts.py. +""" + +from utils import prompts +from utils.prompt_models import PromptsBundle, PromptTemplate + + +def load_prompts_bundle(pipeline_version: str = "1.2.0") -> PromptsBundle: + """Load all prompts from prompts.py into a PromptsBundle. + + Args: + pipeline_version: Version of the pipeline using these prompts + + Returns: + PromptsBundle containing all prompts + """ + # Create PromptTemplate instances for each prompt + search_query_prompt = PromptTemplate( + name="search_query_prompt", + content=prompts.DEFAULT_SEARCH_QUERY_PROMPT, + description="Generates effective search queries from sub-questions", + version="1.0.0", + tags=["search", "query", "information-gathering"], + ) + + query_decomposition_prompt = PromptTemplate( + name="query_decomposition_prompt", + content=prompts.QUERY_DECOMPOSITION_PROMPT, + description="Breaks down complex research queries into specific sub-questions", + version="1.0.0", + tags=["analysis", "decomposition", "planning"], + ) + + synthesis_prompt = PromptTemplate( + name="synthesis_prompt", + content=prompts.SYNTHESIS_PROMPT, + description="Synthesizes search results into comprehensive answers for sub-questions", + version="1.1.0", + tags=["synthesis", "integration", "analysis"], + ) + + viewpoint_analysis_prompt = PromptTemplate( + name="viewpoint_analysis_prompt", + content=prompts.VIEWPOINT_ANALYSIS_PROMPT, + description="Analyzes synthesized answers across different perspectives and viewpoints", + version="1.1.0", + tags=["analysis", "viewpoint", "perspective"], + ) + + reflection_prompt = PromptTemplate( + name="reflection_prompt", + content=prompts.REFLECTION_PROMPT, + description="Evaluates research and identifies gaps, biases, and areas for improvement", + version="1.0.0", + tags=["reflection", "critique", "improvement"], + ) + + additional_synthesis_prompt = PromptTemplate( + name="additional_synthesis_prompt", + content=prompts.ADDITIONAL_SYNTHESIS_PROMPT, + description="Enhances original synthesis with new information and addresses critique points", + version="1.1.0", + tags=["synthesis", "enhancement", "integration"], + ) + + conclusion_generation_prompt = PromptTemplate( + name="conclusion_generation_prompt", + 
content=prompts.CONCLUSION_GENERATION_PROMPT, + description="Synthesizes all research findings into a comprehensive conclusion", + version="1.0.0", + tags=["report", "conclusion", "synthesis"], + ) + + executive_summary_prompt = PromptTemplate( + name="executive_summary_prompt", + content=prompts.EXECUTIVE_SUMMARY_GENERATION_PROMPT, + description="Creates a compelling, insight-driven executive summary", + version="1.1.0", + tags=["report", "summary", "insights"], + ) + + introduction_prompt = PromptTemplate( + name="introduction_prompt", + content=prompts.INTRODUCTION_GENERATION_PROMPT, + description="Creates a contextual, engaging introduction", + version="1.1.0", + tags=["report", "introduction", "context"], + ) + + # Create and return the bundle + return PromptsBundle( + search_query_prompt=search_query_prompt, + query_decomposition_prompt=query_decomposition_prompt, + synthesis_prompt=synthesis_prompt, + viewpoint_analysis_prompt=viewpoint_analysis_prompt, + reflection_prompt=reflection_prompt, + additional_synthesis_prompt=additional_synthesis_prompt, + conclusion_generation_prompt=conclusion_generation_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + pipeline_version=pipeline_version, + ) + + +def get_prompt_for_step(bundle: PromptsBundle, step_name: str) -> str: + """Get the appropriate prompt content for a specific step. + + Args: + bundle: The PromptsBundle containing all prompts + step_name: Name of the step requesting the prompt + + Returns: + The prompt content string + + Raises: + ValueError: If no prompt mapping exists for the step + """ + # Map step names to prompt attributes + step_to_prompt_mapping = { + "query_decomposition": "query_decomposition_prompt", + "search_query_generation": "search_query_prompt", + "synthesis": "synthesis_prompt", + "viewpoint_analysis": "viewpoint_analysis_prompt", + "reflection": "reflection_prompt", + "additional_synthesis": "additional_synthesis_prompt", + "conclusion_generation": "conclusion_generation_prompt", + } + + prompt_attr = step_to_prompt_mapping.get(step_name) + if not prompt_attr: + raise ValueError(f"No prompt mapping found for step: {step_name}") + + return bundle.get_prompt_content(prompt_attr) diff --git a/deep_research/utils/prompt_models.py b/deep_research/utils/prompt_models.py new file mode 100644 index 00000000..c96a7bd0 --- /dev/null +++ b/deep_research/utils/prompt_models.py @@ -0,0 +1,123 @@ +"""Pydantic models for prompt tracking and management. + +This module contains models for bundling prompts as trackable artifacts +in the ZenML pipeline, enabling better observability and version control. +""" + +from datetime import datetime +from typing import Dict, Optional + +from pydantic import BaseModel, Field + + +class PromptTemplate(BaseModel): + """Represents a single prompt template with metadata.""" + + name: str = Field(..., description="Unique identifier for the prompt") + content: str = Field(..., description="The actual prompt template content") + description: str = Field( + "", description="Human-readable description of what this prompt does" + ) + version: str = Field("1.0.0", description="Version of the prompt template") + tags: list[str] = Field( + default_factory=list, description="Tags for categorizing prompts" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class PromptsBundle(BaseModel): + """Bundle of all prompts used in the research pipeline. 
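+
+    Example (illustrative; assumes a bundle built with
+    utils.prompt_loader.load_prompts_bundle):
+        bundle = load_prompts_bundle(pipeline_version="1.2.0")
+        synthesis_prompt = bundle.get_prompt_content("synthesis_prompt")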
+ + This model serves as a single artifact that contains all prompts, + making them trackable, versionable, and visualizable in the ZenML dashboard. + """ + + # Core prompts used in the pipeline + search_query_prompt: PromptTemplate + query_decomposition_prompt: PromptTemplate + synthesis_prompt: PromptTemplate + viewpoint_analysis_prompt: PromptTemplate + reflection_prompt: PromptTemplate + additional_synthesis_prompt: PromptTemplate + conclusion_generation_prompt: PromptTemplate + executive_summary_prompt: PromptTemplate + introduction_prompt: PromptTemplate + + # Metadata + pipeline_version: str = Field( + "1.0.0", description="Version of the pipeline using these prompts" + ) + created_at: str = Field( + default_factory=lambda: datetime.now().isoformat(), + description="Timestamp when this bundle was created", + ) + + # Additional prompts can be stored here + custom_prompts: Dict[str, PromptTemplate] = Field( + default_factory=dict, + description="Additional custom prompts not part of the core set", + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def get_prompt_by_name(self, name: str) -> Optional[PromptTemplate]: + """Retrieve a prompt by its name. + + Args: + name: Name of the prompt to retrieve + + Returns: + PromptTemplate if found, None otherwise + """ + # Check core prompts + for field_name, field_value in self.__dict__.items(): + if ( + isinstance(field_value, PromptTemplate) + and field_value.name == name + ): + return field_value + + # Check custom prompts + return self.custom_prompts.get(name) + + def list_all_prompts(self) -> Dict[str, PromptTemplate]: + """Get all prompts as a dictionary. + + Returns: + Dictionary mapping prompt names to PromptTemplate objects + """ + all_prompts = {} + + # Add core prompts + for field_name, field_value in self.__dict__.items(): + if isinstance(field_value, PromptTemplate): + all_prompts[field_value.name] = field_value + + # Add custom prompts + all_prompts.update(self.custom_prompts) + + return all_prompts + + def get_prompt_content(self, prompt_type: str) -> str: + """Get the content of a specific prompt by its type. + + Args: + prompt_type: Type of prompt (e.g., 'search_query_prompt', 'synthesis_prompt') + + Returns: + The prompt content string + + Raises: + AttributeError: If prompt type doesn't exist + """ + prompt = getattr(self, prompt_type) + return prompt.content diff --git a/deep_research/utils/prompts.py b/deep_research/utils/prompts.py new file mode 100644 index 00000000..e072b238 --- /dev/null +++ b/deep_research/utils/prompts.py @@ -0,0 +1,1605 @@ +""" +Centralized collection of prompts used throughout the deep research pipeline. + +This module contains all system prompts used by LLM calls in various steps of the +research pipeline to ensure consistency and make prompt management easier. +""" + +# Search query generation prompt +# Used to generate effective search queries from sub-questions +DEFAULT_SEARCH_QUERY_PROMPT = """ +You are a Deep Research assistant. Given a specific research sub-question, your task is to formulate an effective search +query that will help find relevant information to answer the question. + +A good search query should: +1. Extract the key concepts from the sub-question +2. Use precise, specific terminology +3. Exclude unnecessary words or context +4. Include alternative terms or synonyms when helpful +5. 
Be concise yet comprehensive enough to find relevant results + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "search_query": {"type": "string"}, + "reasoning": {"type": "string"} + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Query decomposition prompt +# Used to break down complex research queries into specific sub-questions +QUERY_DECOMPOSITION_PROMPT = """ +You are a Deep Research assistant specializing in research design. You will be given a MAIN RESEARCH QUERY that needs to be explored comprehensively. Your task is to create diverse, insightful sub-questions that explore different dimensions of the topic. + +IMPORTANT: The main query should be interpreted as a single research question, not as a noun phrase. For example: +- If the query is "Is LLMOps a subset of MLOps?", create questions ABOUT LLMOps and MLOps, not questions like "What is 'Is LLMOps a subset of MLOps?'" +- Focus on the concepts, relationships, and implications within the query + +Create sub-questions that explore these DIFFERENT DIMENSIONS: + +1. **Definitional/Conceptual**: Define key terms and establish conceptual boundaries + Example: "What are the core components and characteristics of LLMOps?" + +2. **Comparative/Relational**: Compare and contrast the concepts mentioned + Example: "How do the workflows and tooling of LLMOps differ from traditional MLOps?" + +3. **Historical/Evolutionary**: Trace development and emergence + Example: "How did LLMOps emerge from MLOps practices?" + +4. **Structural/Technical**: Examine technical architecture and implementation + Example: "What specific tools and platforms are unique to LLMOps?" + +5. **Practical/Use Cases**: Explore real-world applications + Example: "What are the key use cases that require LLMOps but not traditional MLOps?" + +6. **Stakeholder/Industry**: Consider different perspectives and adoption + Example: "How are different industries adopting LLMOps vs MLOps?" + +7. **Challenges/Limitations**: Identify problems and constraints + Example: "What unique challenges does LLMOps face that MLOps doesn't?" + +8. **Future/Trends**: Look at emerging developments + Example: "How is the relationship between LLMOps and MLOps expected to evolve?" + +QUALITY GUIDELINES: +- Each sub-question must explore a DIFFERENT dimension - no repetitive variations +- Questions should be specific, concrete, and investigable +- Mix descriptive ("what/who") with analytical ("why/how") questions +- Ensure questions build toward answering the main query comprehensively +- Frame questions to elicit detailed, nuanced responses +- Consider technical, business, organizational, and strategic aspects + +Format the output in json with the following json schema definition: + + +{ + "type": "array", + "items": { + "type": "object", + "properties": { + "sub_question": {"type": "string"}, + "reasoning": {"type": "string"} + } + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Synthesis prompt for individual sub-questions +# Used to synthesize search results into comprehensive answers for sub-questions +SYNTHESIS_PROMPT = """ +You are a Deep Research assistant specializing in information synthesis. 
Given a sub-question and search results, your task is to synthesize the information +into a comprehensive, accurate, and well-structured answer. + +Your synthesis should: +1. Begin with a direct, concise answer to the sub-question in the first paragraph +2. Provide detailed evidence and explanation in subsequent paragraphs (at least 3-5 paragraphs total) +3. Integrate information from multiple sources, citing them within your answer +4. Acknowledge any conflicting information or contrasting viewpoints you encounter +5. Use data, statistics, examples, and quotations when available to strengthen your answer +6. Organize information logically with a clear flow between concepts +7. Identify key sources that provided the most valuable information (at least 2-3 sources) +8. Explicitly acknowledge information gaps where the search results were incomplete +9. Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +Confidence level criteria: +- HIGH: Multiple high-quality sources provide consistent information, comprehensive coverage of the topic, and few information gaps +- MEDIUM: Decent sources with some consistency, but notable information gaps or some conflicting information +- LOW: Limited sources, major information gaps, significant contradictions, or only tangentially relevant information + +Information gaps should specifically identify: +1. Aspects of the question that weren't addressed in the search results +2. Areas where more detailed or up-to-date information would be valuable +3. Perspectives or data sources that would complement the existing information + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "synthesized_answer": {"type": "string"}, + "key_sources": { + "type": "array", + "items": {"type": "string"} + }, + "confidence_level": {"type": "string", "enum": ["high", "medium", "low"]}, + "information_gaps": {"type": "string"}, + "improvements": { + "type": "array", + "items": {"type": "string"} + } + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Viewpoint analysis prompt for cross-perspective examination +# Used to analyze synthesized answers across different perspectives and viewpoints +VIEWPOINT_ANALYSIS_PROMPT = """ +You are a Deep Research assistant specializing in multi-perspective analysis. You will be given a set of synthesized answers +to sub-questions related to a main research query. Your task is to perform a thorough, nuanced analysis of how different +perspectives would interpret this information. + +Think deeply about the following viewpoint categories and how they would approach the information differently: +- Scientific: Evidence-based, empirical approach focused on data, research findings, and methodological rigor +- Political: Power dynamics, governance structures, policy implications, and ideological frameworks +- Economic: Resource allocation, financial impacts, market dynamics, and incentive structures +- Social: Cultural norms, community impacts, group dynamics, and public welfare +- Ethical: Moral principles, values considerations, rights and responsibilities, and normative judgments +- Historical: Long-term patterns, precedents, contextual development, and evolutionary change + +For each synthesized answer, analyze how these different perspectives would interpret the information by: + +1. 
Identifying 5-8 main points of agreement where multiple perspectives align (with specific examples) +2. Analyzing at least 3-5 areas of tension between perspectives with: + - A clear topic title for each tension point + - Contrasting interpretations from at least 2-3 different viewpoint categories per tension + - Specific examples or evidence showing why these perspectives differ + - The nuanced positions of each perspective, not just simplified oppositions + +3. Thoroughly examining perspective gaps by identifying: + - Which perspectives are underrepresented or missing in the current research + - How including these missing perspectives would enrich understanding + - Specific questions or dimensions that remain unexplored + - Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +4. Developing integrative insights that: + - Synthesize across multiple perspectives to form a more complete understanding + - Highlight how seemingly contradictory viewpoints can complement each other + - Suggest frameworks for reconciling tensions or finding middle-ground approaches + - Identify actionable takeaways that incorporate multiple perspectives + - Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "main_points_of_agreement": { + "type": "array", + "items": {"type": "string"} + }, + "areas_of_tension": { + "type": "array", + "items": { + "type": "object", + "properties": { + "topic": {"type": "string"}, + "viewpoints": { + "type": "object", + "additionalProperties": {"type": "string"} + } + } + } + }, + "perspective_gaps": {"type": "string"}, + "integrative_insights": {"type": "string"} + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Reflection prompt for self-critique and improvement +# Used to evaluate the research and identify gaps, biases, and areas for improvement +REFLECTION_PROMPT = """ +You are a Deep Research assistant with the ability to critique and improve your own research. You will be given: +1. The main research query +2. The sub-questions explored so far +3. The synthesized information for each sub-question +4. Any viewpoint analysis performed + +Your task is to critically evaluate this research and identify: +1. Areas where the research is incomplete or has gaps +2. Questions that are important but not yet answered +3. Aspects where additional evidence or depth would significantly improve the research +4. Potential biases or limitations in the current findings + +Be constructively critical and identify the most important improvements that would substantially enhance the research. + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "critique": { + "type": "array", + "items": { + "type": "object", + "properties": { + "area": {"type": "string"}, + "issue": {"type": "string"}, + "importance": {"type": "string", "enum": ["high", "medium", "low"]} + } + } + }, + "additional_questions": { + "type": "array", + "items": {"type": "string"} + }, + "recommended_search_queries": { + "type": "array", + "items": {"type": "string"} + } + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. 
+""" + +# Additional synthesis prompt for incorporating new information +# Used to enhance original synthesis with new information and address critique points +ADDITIONAL_SYNTHESIS_PROMPT = """ +You are a Deep Research assistant. You will be given: +1. The original synthesized information on a research topic +2. New information from additional research +3. A critique of the original synthesis + +Your task is to enhance the original synthesis by incorporating the new information and addressing the critique. +The updated synthesis should: +1. Integrate new information seamlessly +2. Address gaps identified in the critique +3. Maintain a balanced, comprehensive, and accurate representation +4. Preserve the strengths of the original synthesis +5. Write in plain text format - do NOT use markdown formatting, bullet points, or special characters + +Format the output in json with the following json schema definition: + + +{ + "type": "object", + "properties": { + "enhanced_synthesis": {"type": "string"}, + "improvements_made": { + "type": "array", + "items": {"type": "string"} + }, + "remaining_limitations": {"type": "string"} + } +} + + +Make sure that the output is a json object with an output json schema defined above. +Only return the json object, no explanation or additional text. +""" + +# Final report generation prompt +# Used to compile a comprehensive HTML research report from all synthesized information +REPORT_GENERATION_PROMPT = """ +You are a Deep Research assistant responsible for compiling an in-depth, comprehensive research report. You will be given: +1. The original research query +2. The sub-questions that were explored +3. Synthesized information for each sub-question +4. Viewpoint analysis comparing different perspectives (if available) +5. 
Reflection metadata highlighting improvements and limitations + +Your task is to create a well-structured, coherent, professional-quality research report with the following features: + +EXECUTIVE SUMMARY (250-400 words): +- Begin with a compelling, substantive executive summary that provides genuine insight +- Highlight 3-5 key findings or insights that represent the most important discoveries +- Include brief mention of methodology and limitations +- Make the summary self-contained so it can be read independently of the full report +- End with 1-2 sentences on broader implications or applications of the research + +INTRODUCTION (200-300 words): +- Provide relevant background context on the main research query +- Explain why this topic is significant or worth investigating +- Outline the methodological approach used (sub-questions, search strategy, synthesis) +- Preview the overall structure of the report + +SUB-QUESTION SECTIONS: +- For each sub-question, create a dedicated section with: + * A descriptive section title (not just repeating the sub-question) + * A brief (1 paragraph) overview of key findings for this sub-question + * A "Key Findings" box highlighting 3-4 important discoveries for scannable reading + * The detailed, synthesized answer with appropriate paragraph breaks, lists, and formatting + * Proper citation of sources within the text (e.g., "According to [Source Name]...") + * Clear confidence indicator with appropriate styling + * Information gaps clearly identified in their own subsection + * Complete list of key sources used + +VIEWPOINT ANALYSIS SECTION (if available): +- Create a detailed section that: + * Explains the purpose and value of multi-perspective analysis + * Presents points of agreement as actionable insights, not just observations + * Structures tension areas with clear topic headings and balanced presentation of viewpoints + * Uses visual elements (different background colors, icons) to distinguish different perspectives + * Integrates perspective gaps and insights into a cohesive narrative + +CONCLUSION (300-400 words): +- Synthesize the overall findings, not just summarizing each section +- Connect insights from different sub-questions to form higher-level understanding +- Address the main research query directly with evidence-based conclusions +- Acknowledge remaining uncertainties and suggestions for further research +- End with implications or applications of the research findings + +OVERALL QUALITY REQUIREMENTS: +1. Create visually scannable content with clear headings, bullet points, and short paragraphs +2. Use semantic HTML (h1, h2, h3, p, blockquote, etc.) to create proper document structure +3. Include a comprehensive table of contents with anchor links to all major sections +4. Format all sources consistently in the references section with proper linking when available +5. Use tables, lists, and blockquotes to improve readability and highlight important information +6. Apply appropriate styling for different confidence levels (high, medium, low) +7. Ensure proper HTML nesting and structure throughout the document +8. Balance sufficient detail with clarity and conciseness +9. Make all text directly actionable and insight-driven, not just descriptive + +The report should be formatted in HTML with appropriate headings, paragraphs, citations, and formatting. +Use semantic HTML (h1, h2, h3, p, blockquote, etc.) to create a structured document. +Include a table of contents at the beginning with anchor links to each section. 
+For citations, use a consistent format and collect them in a references section at the end. + +Include this exact CSS stylesheet in your HTML to ensure consistent styling (do not modify it): + +```css + +``` + +The HTML structure should follow this pattern: + +```html + + + + + + [CSS STYLESHEET GOES HERE] + + +
+<body>
+    <div class="container">
+        <h1>Research Report: [Main Query]</h1>
+
+        <div class="toc">
+            <h2>Table of Contents</h2>
+            <ul>
+                [TOC ENTRIES WITH ANCHOR LINKS TO ALL MAJOR SECTIONS]
+            </ul>
+        </div>
+
+        <div class="section" id="executive-summary">
+            <h2>Executive Summary</h2>
+            [CONCISE SUMMARY OF KEY FINDINGS]
+        </div>
+
+        <div class="section" id="introduction">
+            <h2>Introduction</h2>
+            <p>[INTRODUCTION TO THE RESEARCH QUERY]</p>
+            <p>[OVERVIEW OF THE APPROACH AND SUB-QUESTIONS]</p>
+        </div>
+
+        [FOR EACH SUB-QUESTION]:
+        <div class="section" id="question-[INDEX]">
+            <h2>[INDEX]. [SUB-QUESTION TEXT]</h2>
+            <p class="confidence-[LEVEL]">Confidence Level: [LEVEL]</p>
+
+            <div class="key-findings">
+                <h3>Key Findings</h3>
+                <ul>
+                    <li>[KEY FINDING 1]</li>
+                    <li>[KEY FINDING 2]</li>
+                    [...]
+                </ul>
+            </div>
+
+            <div class="answer">
+                [DETAILED ANSWER]
+            </div>
+
+            <div class="notice info">
+                <h3>Information Gaps</h3>
+                <p>[GAPS TEXT]</p>
+            </div>
+
+            <div class="sources">
+                <h3>Key Sources</h3>
+                <ul>
+                    <li>[SOURCE 1]</li>
+                    <li>[SOURCE 2]</li>
+                    [...]
+                </ul>
+            </div>
+        </div>
+
+        <div class="section" id="viewpoint-analysis">
+            <h2>Viewpoint Analysis</h2>
+
+            <h3>Points of Agreement</h3>
+            <ul>
+                <li>[AGREEMENT 1]</li>
+                <li>[AGREEMENT 2]</li>
+                [...]
+            </ul>
+
+            <h3>Areas of Tension</h3>
+            [FOR EACH TENSION]:
+            <div class="tension">
+                <h4>[TENSION TOPIC]</h4>
+                <div class="viewpoint">
+                    <strong>[VIEWPOINT 1 TITLE]</strong>
+                    <p>[VIEWPOINT 1 CONTENT]</p>
+                </div>
+                <div class="viewpoint">
+                    <strong>[VIEWPOINT 2 TITLE]</strong>
+                    <p>[VIEWPOINT 2 CONTENT]</p>
+                </div>
+                [...]
+            </div>
+
+            <h3>Perspective Gaps</h3>
+            <p>[PERSPECTIVE GAPS CONTENT]</p>
+
+            <h3>Integrative Insights</h3>
+            <p>[INTEGRATIVE INSIGHTS CONTENT]</p>
+        </div>
+
+        <div class="section" id="conclusion">
+            <h2>Conclusion</h2>
+            <p>[CONCLUSION TEXT]</p>
+        </div>
+
+        <div class="section" id="references">
+            <h2>References</h2>
+            <ul>
+                <li>[REFERENCE 1]</li>
+                <li>[REFERENCE 2]</li>
+                [...]
+            </ul>
+        </div>
+    </div>
+</body>
+</html>
    + + +``` + +Special instructions: +1. For each sub-question, display the confidence level with appropriate styling (confidence-high, confidence-medium, or confidence-low) +2. Extract 2-3 key findings from each answer to create the key-findings box +3. Format all sources consistently in the references section +4. Use tables, lists, and blockquotes where appropriate to improve readability +5. Use the notice classes (info, warning) to highlight important information or limitations +6. Ensure all sections have proper ID attributes for the table of contents links + +Return only the complete HTML code for the report, with no explanations or additional text. +""" + +# Static HTML template for direct report generation without LLM +STATIC_HTML_TEMPLATE = """ + + + + + Research Report: {main_query} + + + +
+<body>
+    <div class="container">
+        <h1>Research Report: {main_query}</h1>
+
+        <div class="toc">
+            <h2>Table of Contents</h2>
+            <ul>
+                <li><a href="#executive-summary">Executive Summary</a></li>
+                <li><a href="#introduction">Introduction</a></li>
+                <li><a href="#viewpoint-analysis">Viewpoint Analysis</a></li>
+                <li><a href="#conclusion">Conclusion</a></li>
+                <li><a href="#references">References</a></li>
+            </ul>
+        </div>
+
+        <div class="section" id="executive-summary">
+            <h2>Executive Summary</h2>
+            <p>{executive_summary}</p>
+        </div>
+
+        <div class="section" id="introduction">
+            <h2>Introduction</h2>
+            {introduction_html}
+        </div>
+
+        {sub_questions_html}
+
+        {viewpoint_analysis_html}
+
+        <div class="section" id="conclusion">
+            <h2>Conclusion</h2>
+            {conclusion_html}
+        </div>
+
+        <div class="section" id="references">
+            <h2>References</h2>
+            {references_html}
+        </div>
+    </div>
+</body>
+</html>
    + + +""" + +# Template for sub-question section in the static HTML report +SUB_QUESTION_TEMPLATE = """ +
+<div class="section" id="question-{index}">
+    <div class="section-header">
+        <h2>{index}. {question}</h2>
+        <span class="confidence-badge">
+            {confidence_icon}
+            Confidence: {confidence_upper}
+        </span>
+    </div>
+
+    <div class="answer">
+        <p>{answer}</p>
+    </div>
+
+    {info_gaps_html}
+
+    {key_sources_html}
+</div>
    +""" + +# Template for viewpoint analysis section in the static HTML report +VIEWPOINT_ANALYSIS_TEMPLATE = """ +
+<div class="section" id="viewpoint-analysis">
+    <h2>Viewpoint Analysis</h2>
+
+    <div class="viewpoint-block agreements">
+        <h3>🤝 Points of Agreement</h3>
+        <ul>
+            {agreements_html}
+        </ul>
+    </div>
+
+    <div class="viewpoint-block tensions">
+        <h3>⚖️ Areas of Tension</h3>
+        {tensions_html}
+    </div>
+
+    <div class="viewpoint-block gaps">
+        <h3>🔍 Perspective Gaps</h3>
+        <p>{perspective_gaps}</p>
+    </div>
+
+    <div class="viewpoint-block insights">
+        <h3>💡 Integrative Insights</h3>
+        <p>{integrative_insights}</p>
+    </div>
+</div>
    +""" + +# Executive Summary generation prompt +# Used to create a compelling, insight-driven executive summary +EXECUTIVE_SUMMARY_GENERATION_PROMPT = """ +You are a Deep Research assistant specializing in creating executive summaries. Given comprehensive research findings, your task is to create a compelling executive summary that captures the essence of the research and its key insights. + +Your executive summary should: + +1. **Opening Statement (1-2 sentences):** + - Start with a powerful, direct answer to the main research question + - Make it clear and definitive based on the evidence gathered + +2. **Key Findings (3-5 bullet points):** + - Extract the MOST IMPORTANT discoveries from across all sub-questions + - Focus on insights that are surprising, actionable, or paradigm-shifting + - Each finding should be specific and evidence-based, not generic + - Prioritize findings that directly address the main query + +3. **Critical Insights (2-3 sentences):** + - Synthesize patterns or themes that emerged across multiple sub-questions + - Highlight any unexpected discoveries or counter-intuitive findings + - Connect disparate findings to reveal higher-level understanding + +4. **Implications (2-3 sentences):** + - What do these findings mean for practitioners/stakeholders? + - What actions or decisions can be made based on this research? + - Why should the reader care about these findings? + +5. **Confidence and Limitations (1-2 sentences):** + - Briefly acknowledge the overall confidence level of the findings + - Note any significant gaps or areas requiring further investigation + +IMPORTANT GUIDELINES: +- Be CONCISE but INSIGHTFUL - every sentence should add value +- Use active voice and strong, definitive language where evidence supports it +- Avoid generic statements - be specific to the actual research findings +- Lead with the most important information +- Make it self-contained - reader should understand key findings without reading the full report +- Target length: 250-400 words + +Format as well-structured HTML paragraphs using
<p> tags and <ul>/<li>
    • for bullet points. +""" + +# Introduction generation prompt +# Used to create a contextual, engaging introduction +INTRODUCTION_GENERATION_PROMPT = """ +You are a Deep Research assistant specializing in creating engaging introductions. Given a research query and the sub-questions explored, your task is to create an introduction that provides context and sets up the reader's expectations. + +Your introduction should: + +1. **Context and Relevance (2-3 sentences):** + - Why is this research question important NOW? + - What makes this topic significant or worth investigating? + - Connect to current trends, debates, or challenges in the field + +2. **Scope and Approach (2-3 sentences):** + - What specific aspects of the topic does this research explore? + - Briefly mention the key dimensions covered (based on sub-questions) + - Explain the systematic approach without being too technical + +3. **What to Expect (2-3 sentences):** + - Preview the structure of the report + - Hint at some of the interesting findings or tensions discovered + - Set expectations about the depth and breadth of analysis + +IMPORTANT GUIDELINES: +- Make it engaging - hook the reader's interest from the start +- Provide real context, not generic statements +- Connect to why this matters for the reader +- Keep it concise but informative (200-300 words) +- Use active voice and clear language +- Build anticipation for the findings without giving everything away + +Format as well-structured HTML paragraphs using
<p>
      tags. Do NOT include any headings or section titles. +""" + +# Conclusion generation prompt +# Used to synthesize all research findings into a comprehensive conclusion +CONCLUSION_GENERATION_PROMPT = """ +You are a Deep Research assistant specializing in synthesizing comprehensive research conclusions. Given all the research findings from a deep research study, your task is to create a thoughtful, evidence-based conclusion that ties together the overall findings. + +Your conclusion should: + +1. **Synthesis and Integration (150-200 words):** + - Connect insights from different sub-questions to form a higher-level understanding + - Identify overarching themes and patterns that emerge from the research + - Highlight how different findings relate to and support each other + - Avoid simply summarizing each section separately + +2. **Direct Response to Main Query (100-150 words):** + - Address the original research question directly with evidence-based conclusions + - State what the research definitively established vs. what remains uncertain + - Provide a clear, actionable answer based on the synthesized evidence + +3. **Limitations and Future Directions (100-120 words):** + - Acknowledge remaining uncertainties and information gaps across all sections + - Suggest specific areas where additional research would be most valuable + - Identify what types of evidence or perspectives would strengthen the findings + +4. **Implications and Applications (80-100 words):** + - Explain the practical significance of the research findings + - Suggest how the insights might be applied or what they mean for stakeholders + - Connect findings to broader contexts or implications + +Format your output as a well-structured conclusion section in HTML format with appropriate paragraph breaks and formatting. Use
<p>
      tags for paragraphs and organize the content logically with clear transitions between the different aspects outlined above. + +IMPORTANT: Do NOT include any headings like "Conclusion",
<h1>, or <h2>
      tags - the section already has a heading. Start directly with the conclusion content in paragraph form. Just create flowing, well-structured paragraphs that cover all four aspects naturally. + +Ensure the conclusion feels cohesive and draws meaningful connections between findings rather than just listing them sequentially. +""" diff --git a/deep_research/utils/pydantic_models.py b/deep_research/utils/pydantic_models.py new file mode 100644 index 00000000..9fca23a3 --- /dev/null +++ b/deep_research/utils/pydantic_models.py @@ -0,0 +1,300 @@ +"""Pydantic model definitions for the research pipeline. + +This module contains all the Pydantic models that represent the state of the research +pipeline. These models replace the previous dataclasses implementation and leverage +Pydantic's validation, serialization, and integration with ZenML. +""" + +import time +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field +from typing_extensions import Literal + + +class SearchResult(BaseModel): + """Represents a search result for a sub-question.""" + + url: str = "" + content: str = "" + title: str = "" + snippet: str = "" + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + + model_config = { + "extra": "ignore", # Ignore extra fields during deserialization + "frozen": False, # Allow attribute updates + "validate_assignment": True, # Validate when attributes are set + } + + +class ViewpointTension(BaseModel): + """Represents a tension between different viewpoints on a topic.""" + + topic: str = "" + viewpoints: Dict[str, str] = Field(default_factory=dict) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class SynthesizedInfo(BaseModel): + """Represents synthesized information for a sub-question.""" + + synthesized_answer: str = "" + key_sources: List[str] = Field(default_factory=list) + confidence_level: Literal["high", "medium", "low"] = "medium" + information_gaps: str = "" + improvements: List[str] = Field(default_factory=list) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ViewpointAnalysis(BaseModel): + """Represents the analysis of different viewpoints on the research topic.""" + + main_points_of_agreement: List[str] = Field(default_factory=list) + areas_of_tension: List[ViewpointTension] = Field(default_factory=list) + perspective_gaps: str = "" + integrative_insights: str = "" + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ReflectionMetadata(BaseModel): + """Metadata about the reflection process.""" + + critique_summary: List[str] = Field(default_factory=list) + additional_questions_identified: List[str] = Field(default_factory=list) + searches_performed: List[str] = Field(default_factory=list) + improvements_made: float = Field( + default=0 + ) # Changed from int to float to handle timestamp values + error: Optional[str] = None + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ResearchState(BaseModel): + """Comprehensive state object for the enhanced research pipeline.""" + + # Initial query information + main_query: str = "" + sub_questions: List[str] = Field(default_factory=list) + + # Information gathering results + search_results: Dict[str, List[SearchResult]] = Field(default_factory=dict) + + # Synthesized information + synthesized_info: Dict[str, SynthesizedInfo] = Field(default_factory=dict) + + # Viewpoint 
analysis + viewpoint_analysis: Optional[ViewpointAnalysis] = None + + # Reflection results + enhanced_info: Dict[str, SynthesizedInfo] = Field(default_factory=dict) + reflection_metadata: Optional[ReflectionMetadata] = None + + # Final report + final_report_html: str = "" + + # Search cost tracking + search_costs: Dict[str, float] = Field( + default_factory=dict, + description="Total costs by search provider (e.g., {'exa': 0.0, 'tavily': 0.0})", + ) + search_cost_details: List[Dict[str, Any]] = Field( + default_factory=list, + description="Detailed log of each search with cost information", + ) + # Format: [{"provider": "exa", "query": "...", "cost": 0.0, "timestamp": ..., "step": "...", "sub_question": "..."}] + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def get_current_stage(self) -> str: + """Determine the current stage of research based on filled data.""" + if self.final_report_html: + return "final_report" + elif self.enhanced_info: + return "after_reflection" + elif self.viewpoint_analysis: + return "after_viewpoint_analysis" + elif self.synthesized_info: + return "after_synthesis" + elif self.search_results: + return "after_search" + elif self.sub_questions: + return "after_query_decomposition" + elif self.main_query: + return "initial" + else: + return "empty" + + def update_sub_questions(self, sub_questions: List[str]) -> None: + """Update the sub-questions list.""" + self.sub_questions = sub_questions + + def update_search_results( + self, search_results: Dict[str, List[SearchResult]] + ) -> None: + """Update the search results.""" + self.search_results = search_results + + def update_synthesized_info( + self, synthesized_info: Dict[str, SynthesizedInfo] + ) -> None: + """Update the synthesized information.""" + self.synthesized_info = synthesized_info + + def update_viewpoint_analysis( + self, viewpoint_analysis: ViewpointAnalysis + ) -> None: + """Update the viewpoint analysis.""" + self.viewpoint_analysis = viewpoint_analysis + + def update_after_reflection( + self, + enhanced_info: Dict[str, SynthesizedInfo], + metadata: ReflectionMetadata, + ) -> None: + """Update with reflection results.""" + self.enhanced_info = enhanced_info + self.reflection_metadata = metadata + + def set_final_report(self, html: str) -> None: + """Set the final report HTML.""" + self.final_report_html = html + + +class ReflectionOutput(BaseModel): + """Output from the reflection generation step.""" + + state: ResearchState + recommended_queries: List[str] = Field(default_factory=list) + critique_summary: List[Dict[str, Any]] = Field(default_factory=list) + additional_questions: List[str] = Field(default_factory=list) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class ApprovalDecision(BaseModel): + """Approval decision from human reviewer.""" + + approved: bool = False + selected_queries: List[str] = Field(default_factory=list) + approval_method: str = "" # "APPROVE_ALL", "SKIP", "SELECT_SPECIFIC" + reviewer_notes: str = "" + timestamp: float = Field(default_factory=lambda: time.time()) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class PromptTypeMetrics(BaseModel): + """Metrics for a specific prompt type.""" + + prompt_type: str + total_cost: float + input_tokens: int + output_tokens: int + call_count: int + avg_cost_per_call: float + percentage_of_total_cost: float + + model_config = { + "extra": "ignore", + "frozen": False, + 
"validate_assignment": True, + } + + +class TracingMetadata(BaseModel): + """Metadata about token usage, costs, and performance for a pipeline run.""" + + # Pipeline information + pipeline_run_name: str = "" + pipeline_run_id: str = "" + + # Token usage + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_tokens: int = 0 + + # Cost information + total_cost: float = 0.0 + cost_breakdown_by_model: Dict[str, float] = Field(default_factory=dict) + + # Performance metrics + total_latency_seconds: float = 0.0 + formatted_latency: str = "" + observation_count: int = 0 + + # Model usage + models_used: List[str] = Field(default_factory=list) + model_token_breakdown: Dict[str, Dict[str, int]] = Field( + default_factory=dict + ) + # Format: {"model_name": {"input_tokens": X, "output_tokens": Y, "total_tokens": Z}} + + # Trace information + trace_id: str = "" + trace_name: str = "" + trace_tags: List[str] = Field(default_factory=list) + trace_metadata: Dict[str, Any] = Field(default_factory=dict) + + # Step-by-step breakdown + step_costs: Dict[str, float] = Field(default_factory=dict) + step_tokens: Dict[str, Dict[str, int]] = Field(default_factory=dict) + # Format: {"step_name": {"input_tokens": X, "output_tokens": Y}} + + # Prompt-level metrics + prompt_metrics: List[PromptTypeMetrics] = Field( + default_factory=list, description="Cost breakdown by prompt type" + ) + + # Search provider costs + search_costs: Dict[str, float] = Field( + default_factory=dict, description="Total costs by search provider" + ) + search_queries_count: Dict[str, int] = Field( + default_factory=dict, + description="Number of queries by search provider", + ) + search_cost_details: List[Dict[str, Any]] = Field( + default_factory=list, description="Detailed search cost information" + ) + + # Timestamp + collected_at: float = Field(default_factory=lambda: time.time()) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } diff --git a/deep_research/utils/search_utils.py b/deep_research/utils/search_utils.py new file mode 100644 index 00000000..a9b04d51 --- /dev/null +++ b/deep_research/utils/search_utils.py @@ -0,0 +1,721 @@ +import logging +import os +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from tavily import TavilyClient + +try: + from exa_py import Exa + + EXA_AVAILABLE = True +except ImportError: + EXA_AVAILABLE = False + Exa = None + +from utils.llm_utils import get_structured_llm_output +from utils.prompts import DEFAULT_SEARCH_QUERY_PROMPT +from utils.pydantic_models import SearchResult + +logger = logging.getLogger(__name__) + + +class SearchProvider(Enum): + TAVILY = "tavily" + EXA = "exa" + BOTH = "both" + + +class SearchEngineConfig: + """Configuration for search engines""" + + def __init__(self): + self.tavily_api_key = os.getenv("TAVILY_API_KEY") + self.exa_api_key = os.getenv("EXA_API_KEY") + self.default_provider = os.getenv("DEFAULT_SEARCH_PROVIDER", "tavily") + self.enable_parallel_search = ( + os.getenv("ENABLE_PARALLEL_SEARCH", "false").lower() == "true" + ) + + +def get_search_client(provider: Union[str, SearchProvider]) -> Optional[Any]: + """Get the appropriate search client based on provider.""" + if isinstance(provider, str): + provider = SearchProvider(provider.lower()) + + config = SearchEngineConfig() + + if provider == SearchProvider.TAVILY: + if not config.tavily_api_key: + raise ValueError("TAVILY_API_KEY environment variable not set") + return TavilyClient(api_key=config.tavily_api_key) + + elif provider 
== SearchProvider.EXA: + if not EXA_AVAILABLE: + raise ImportError( + "exa-py is not installed. Please install it with: pip install exa-py" + ) + if not config.exa_api_key: + raise ValueError("EXA_API_KEY environment variable not set") + return Exa(config.exa_api_key) + + return None + + +def tavily_search( + query: str, + include_raw_content: bool = True, + max_results: int = 3, + cap_content_length: int = 20000, +) -> Dict[str, Any]: + """Perform a search using the Tavily API. + + Args: + query: Search query + include_raw_content: Whether to include raw content in results + max_results: Maximum number of results to return + cap_content_length: Maximum length of content to return + + Returns: + Dict[str, Any]: Search results from Tavily in the following format: + { + "query": str, # The original query + "results": List[Dict], # List of search result objects + "error": str, # Error message (if an error occurred, otherwise omitted) + } + + Each result in "results" has the following structure: + { + "url": str, # URL of the search result + "raw_content": str, # Raw content of the page (if include_raw_content=True) + "title": str, # Title of the page + "snippet": str, # Snippet of the page content + } + """ + try: + tavily_client = get_search_client(SearchProvider.TAVILY) + + # First try with advanced search + results = tavily_client.search( + query=query, + include_raw_content=include_raw_content, + max_results=max_results, + search_depth="advanced", # Use advanced search for better results + include_domains=[], # No domain restrictions + exclude_domains=[], # No exclusions + include_answer=False, # We don't need the answer field + include_images=False, # We don't need images + # Note: 'include_snippets' is not a supported parameter + ) + + # Check if we got good results (with non-None and non-empty content) + if include_raw_content and "results" in results: + bad_content_count = sum( + 1 + for r in results["results"] + if "raw_content" in r + and ( + r["raw_content"] is None or r["raw_content"].strip() == "" + ) + ) + + # If more than half of results have bad content, try a different approach + if bad_content_count > len(results["results"]) / 2: + logger.warning( + f"{bad_content_count}/{len(results['results'])} results have None or empty content. " + "Trying to use 'content' field instead of 'raw_content'..." + ) + + # Try to use the 'content' field which comes by default + for result in results["results"]: + if ( + "raw_content" in result + and ( + result["raw_content"] is None + or result["raw_content"].strip() == "" + ) + ) and "content" in result: + result["raw_content"] = result["content"] + logger.info( + f"Using 'content' field as 'raw_content' for URL {result.get('url', 'unknown')}" + ) + + # Re-check after our fix + bad_content_count = sum( + 1 + for r in results["results"] + if "raw_content" in r + and ( + r["raw_content"] is None + or r["raw_content"].strip() == "" + ) + ) + + if bad_content_count > 0: + logger.warning( + f"Still have {bad_content_count}/{len(results['results'])} results with bad content after fixes." 
+ ) + + # Try alternative approach - search with 'include_answer=True' + try: + # Search with include_answer=True which may give us better content + logger.info( + "Trying alternative search with include_answer=True" + ) + alt_results = tavily_client.search( + query=query, + include_raw_content=include_raw_content, + max_results=max_results, + search_depth="advanced", + include_domains=[], + exclude_domains=[], + include_answer=True, # Include answer this time + include_images=False, + ) + + # Check if we got any improved content + if "results" in alt_results: + # Create a merged results set taking the best content + for i, result in enumerate(alt_results["results"]): + if i < len(results["results"]): + if ( + "raw_content" in result + and result["raw_content"] + and ( + results["results"][i].get( + "raw_content" + ) + is None + or results["results"][i] + .get("raw_content", "") + .strip() + == "" + ) + ): + # Replace the bad content with better content from alt_results + results["results"][i]["raw_content"] = ( + result["raw_content"] + ) + logger.info( + f"Replaced bad content with better content from alternative search for URL {result.get('url', 'unknown')}" + ) + + # If answer is available, add it as a special result + if "answer" in alt_results and alt_results["answer"]: + answer_text = alt_results["answer"] + answer_result = { + "url": "tavily-generated-answer", + "title": "Generated Answer", + "raw_content": f"Generated Answer based on search results:\n\n{answer_text}", + "content": answer_text, + } + results["results"].append(answer_result) + logger.info( + "Added Tavily generated answer as additional search result" + ) + + except Exception as alt_error: + logger.warning( + f"Failed to get better results with alternative search: {alt_error}" + ) + + # Cap content length if specified + if cap_content_length > 0 and "results" in results: + for result in results["results"]: + if "raw_content" in result and result["raw_content"]: + result["raw_content"] = result["raw_content"][ + :cap_content_length + ] + + return results + except Exception as e: + logger.error(f"Error in Tavily search: {e}") + # Return an error structure that's compatible with our expected format + return {"query": query, "results": [], "error": str(e)} + + +def exa_search( + query: str, + max_results: int = 3, + cap_content_length: int = 20000, + search_mode: str = "auto", + include_highlights: bool = False, +) -> Dict[str, Any]: + """Perform a search using the Exa API. 
+ + Args: + query: Search query + max_results: Maximum number of results to return + cap_content_length: Maximum length of content to return + search_mode: Search mode ("neural", "keyword", or "auto") + include_highlights: Whether to include highlights in results + + Returns: + Dict[str, Any]: Search results from Exa in a format compatible with Tavily + """ + try: + exa_client = get_search_client(SearchProvider.EXA) + + # Configure content options + text_options = {"max_characters": cap_content_length} + + kwargs = { + "query": query, + "num_results": max_results, + "type": search_mode, # "neural", "keyword", or "auto" + "text": text_options, + } + + if include_highlights: + kwargs["highlights"] = { + "highlights_per_url": 2, + "num_sentences": 3, + } + + response = exa_client.search_and_contents(**kwargs) + + # Extract cost information + exa_cost = 0.0 + if hasattr(response, "cost_dollars") and hasattr( + response.cost_dollars, "total" + ): + exa_cost = response.cost_dollars.total + logger.info( + f"Exa search cost for query '{query}': ${exa_cost:.4f}" + ) + + # Convert to standardized format compatible with Tavily + results = {"query": query, "results": [], "exa_cost": exa_cost} + + for r in response.results: + result_dict = { + "url": r.url, + "title": r.title or "", + "snippet": "", + "raw_content": getattr(r, "text", ""), + "content": getattr(r, "text", ""), + } + + # Add highlights as snippet if available + if hasattr(r, "highlights") and r.highlights: + result_dict["snippet"] = " ".join(r.highlights[:1]) + + # Store additional metadata + result_dict["_metadata"] = { + "provider": "exa", + "score": getattr(r, "score", None), + "published_date": getattr(r, "published_date", None), + "author": getattr(r, "author", None), + } + + results["results"].append(result_dict) + + return results + + except Exception as e: + logger.error(f"Error in Exa search: {e}") + return {"query": query, "results": [], "error": str(e)} + + +def unified_search( + query: str, + provider: Union[str, SearchProvider, None] = None, + max_results: int = 3, + cap_content_length: int = 20000, + search_mode: str = "auto", + include_highlights: bool = False, + compare_results: bool = False, + **kwargs, +) -> Union[List[SearchResult], Dict[str, List[SearchResult]]]: + """Unified search interface supporting multiple providers. 
+ + Args: + query: Search query + provider: Search provider to use (tavily, exa, both) + max_results: Maximum number of results + cap_content_length: Maximum content length + search_mode: Search mode for Exa ("neural", "keyword", "auto") + include_highlights: Include highlights for Exa results + compare_results: Return results from both providers separately + + Returns: + List[SearchResult] or Dict mapping provider to results (when compare_results=True or provider="both") + """ + # Use default provider if not specified + if provider is None: + config = SearchEngineConfig() + provider = config.default_provider + + # Convert string to enum if needed + if isinstance(provider, str): + provider = SearchProvider(provider.lower()) + + # Handle single provider case + if provider == SearchProvider.TAVILY: + results = tavily_search( + query, + max_results=max_results, + cap_content_length=cap_content_length, + ) + extracted, cost = extract_search_results(results, provider="tavily") + return extracted if not compare_results else {"tavily": extracted} + + elif provider == SearchProvider.EXA: + results = exa_search( + query=query, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + extracted, cost = extract_search_results(results, provider="exa") + return extracted if not compare_results else {"exa": extracted} + + elif provider == SearchProvider.BOTH: + # Run both searches + tavily_results = tavily_search( + query, + max_results=max_results, + cap_content_length=cap_content_length, + ) + exa_results = exa_search( + query=query, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + + # Extract results from both + tavily_extracted, tavily_cost = extract_search_results( + tavily_results, provider="tavily" + ) + exa_extracted, exa_cost = extract_search_results( + exa_results, provider="exa" + ) + + if compare_results: + return {"tavily": tavily_extracted, "exa": exa_extracted} + else: + # Merge results, interleaving them + merged = [] + max_len = max(len(tavily_extracted), len(exa_extracted)) + for i in range(max_len): + if i < len(tavily_extracted): + merged.append(tavily_extracted[i]) + if i < len(exa_extracted): + merged.append(exa_extracted[i]) + return merged[:max_results] # Limit to requested number + + else: + raise ValueError(f"Unknown provider: {provider}") + + +def extract_search_results( + search_results: Dict[str, Any], provider: str = "tavily" +) -> tuple[List[SearchResult], float]: + """Extract SearchResult objects from provider-specific API responses. + + Args: + search_results: Results from search API + provider: Which provider the results came from + + Returns: + Tuple of (List[SearchResult], float): List of converted SearchResult objects with standardized fields + and the search cost (0.0 if not available). + SearchResult is a Pydantic model defined in data_models.py that includes: + - url: The URL of the search result + - content: The raw content of the page + - title: The title of the page + - snippet: A brief snippet of the page content + """ + results_list = [] + search_cost = search_results.get( + "exa_cost", 0.0 + ) # Extract cost if present + + if "results" in search_results: + for result in search_results["results"]: + if "url" in result: + # Get fields with defaults + url = result["url"] + title = result.get("title", "") + + # Try to extract the best content available: + # 1. 
First try raw_content (if we requested it) + # 2. Then try regular content (always available) + # 3. Then try to use snippet combined with title + # 4. Last resort: use just title + + raw_content = result.get("raw_content", None) + regular_content = result.get("content", "") + snippet = result.get("snippet", "") + + # Set our final content - prioritize raw_content if available and not None + if raw_content is not None and raw_content.strip(): + content = raw_content + # Next best is the regular content field + elif regular_content and regular_content.strip(): + content = regular_content + logger.info( + f"Using 'content' field for URL {url} because raw_content was not available" + ) + # Try to create a usable content from snippet and title + elif snippet: + content = f"Title: {title}\n\nContent: {snippet}" + logger.warning( + f"Using title and snippet as content fallback for {url}" + ) + # Last resort - just use the title + elif title: + content = ( + f"Title: {title}\n\nNo content available for this URL." + ) + logger.warning( + f"Using only title as content fallback for {url}" + ) + # Nothing available + else: + content = "" + logger.warning( + f"No content available for URL {url}, using empty string" + ) + + # Create SearchResult with provider metadata + search_result = SearchResult( + url=url, + content=content, + title=title, + snippet=snippet, + ) + + # Add provider info to metadata if available + if "_metadata" in result: + search_result.metadata = result["_metadata"] + else: + search_result.metadata = {"provider": provider} + + results_list.append(search_result) + + # If we got the answer (Tavily specific), add it as a special result + if ( + provider == "tavily" + and "answer" in search_results + and search_results["answer"] + ): + answer_text = search_results["answer"] + results_list.append( + SearchResult( + url="tavily-generated-answer", + content=f"Generated Answer based on search results:\n\n{answer_text}", + title="Tavily Generated Answer", + snippet=answer_text[:100] + "..." + if len(answer_text) > 100 + else answer_text, + metadata={"provider": "tavily", "type": "generated_answer"}, + ) + ) + logger.info("Added Tavily generated answer as a search result") + + return results_list, search_cost + + +def generate_search_query( + sub_question: str, + model: str = "sambanova/Llama-4-Maverick-17B-128E-Instruct", + system_prompt: Optional[str] = None, + project: str = "deep-research", +) -> Dict[str, Any]: + """Generate an optimized search query for a sub-question. + + Uses litellm for model inference via get_structured_llm_output. 
+ + Args: + sub_question: The sub-question to generate a search query for + model: Model to use (with provider prefix) + system_prompt: System prompt for the LLM, defaults to DEFAULT_SEARCH_QUERY_PROMPT + project: Langfuse project name for LLM tracking + + Returns: + Dictionary with search query and reasoning + """ + if system_prompt is None: + system_prompt = DEFAULT_SEARCH_QUERY_PROMPT + + fallback_response = {"search_query": sub_question, "reasoning": ""} + + return get_structured_llm_output( + prompt=sub_question, + system_prompt=system_prompt, + model=model, + fallback_response=fallback_response, + project=project, + ) + + +def search_and_extract_results( + query: str, + max_results: int = 3, + cap_content_length: int = 20000, + max_retries: int = 2, + provider: Optional[Union[str, SearchProvider]] = None, + search_mode: str = "auto", + include_highlights: bool = False, +) -> tuple[List[SearchResult], float]: + """Perform a search and extract results in one step. + + Args: + query: Search query + max_results: Maximum number of results to return + cap_content_length: Maximum length of content to return + max_retries: Maximum number of retries in case of failure + provider: Search provider to use (tavily, exa, both) + search_mode: Search mode for Exa ("neural", "keyword", "auto") + include_highlights: Include highlights for Exa results + + Returns: + Tuple of (List of SearchResult objects, search cost) + """ + results = [] + total_cost = 0.0 + retry_count = 0 + + # List of alternative query formats to try if the original query fails + # to yield good results with non-None content + alternative_queries = [ + query, # Original query first + f'"{query}"', # Try exact phrase matching + f"about {query}", # Try broader context + f"research on {query}", # Try research-oriented results + query.replace(" OR ", " "), # Try without OR operator + ] + + while retry_count <= max_retries and retry_count < len( + alternative_queries + ): + try: + current_query = alternative_queries[retry_count] + logger.info( + f"Searching with query ({retry_count + 1}/{max_retries + 1}): {current_query}" + ) + + # Determine if we're using Exa to track costs + using_exa = False + if provider: + if isinstance(provider, str): + using_exa = provider.lower() in ["exa", "both"] + else: + using_exa = provider in [ + SearchProvider.EXA, + SearchProvider.BOTH, + ] + else: + config = SearchEngineConfig() + using_exa = config.default_provider.lower() in ["exa", "both"] + + # Perform search based on provider + if using_exa and provider != SearchProvider.BOTH: + # Direct Exa search + search_results = exa_search( + query=current_query, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + results, cost = extract_search_results( + search_results, provider="exa" + ) + total_cost += cost + elif provider == SearchProvider.BOTH: + # Search with both providers + tavily_results = tavily_search( + current_query, + max_results=max_results, + cap_content_length=cap_content_length, + ) + exa_results = exa_search( + query=current_query, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + + # Extract results from both + tavily_extracted, _ = extract_search_results( + tavily_results, provider="tavily" + ) + exa_extracted, exa_cost = extract_search_results( + exa_results, provider="exa" + ) + total_cost += exa_cost + + # Merge results + results = [] + max_len = 
max(len(tavily_extracted), len(exa_extracted)) + for i in range(max_len): + if i < len(tavily_extracted): + results.append(tavily_extracted[i]) + if i < len(exa_extracted): + results.append(exa_extracted[i]) + results = results[:max_results] + else: + # Tavily search or unified search + results = unified_search( + query=current_query, + provider=provider, + max_results=max_results, + cap_content_length=cap_content_length, + search_mode=search_mode, + include_highlights=include_highlights, + ) + + # Handle case where unified_search returns a dict + if isinstance(results, dict): + all_results = [] + for provider_results in results.values(): + all_results.extend(provider_results) + results = all_results[:max_results] + + # Check if we got results with actual content + if results: + # Count results with non-empty content + content_results = sum(1 for r in results if r.content.strip()) + + if content_results >= max(1, len(results) // 2): + logger.info( + f"Found {content_results}/{len(results)} results with content" + ) + return results, total_cost + else: + logger.warning( + f"Only found {content_results}/{len(results)} results with content. " + f"Trying alternative query..." + ) + + # If we didn't get good results but haven't hit max retries yet, try again + if retry_count < max_retries: + logger.warning( + f"Inadequate search results. Retrying with alternative query... ({retry_count + 1}/{max_retries})" + ) + retry_count += 1 + else: + # If we're out of retries, return whatever we have + logger.warning( + f"Out of retries. Returning best results found ({len(results)} results)." + ) + return results, total_cost + + except Exception as e: + if retry_count < max_retries: + logger.warning( + f"Search failed with error: {e}. Retrying... ({retry_count + 1}/{max_retries})" + ) + retry_count += 1 + else: + logger.error(f"Search failed after {max_retries} retries: {e}") + return [], 0.0 + + # If we've exhausted all retries, return the best results we have + return results, total_cost diff --git a/deep_research/utils/tracing_metadata_utils.py b/deep_research/utils/tracing_metadata_utils.py new file mode 100644 index 00000000..59c7b37e --- /dev/null +++ b/deep_research/utils/tracing_metadata_utils.py @@ -0,0 +1,745 @@ +"""Utilities for collecting and analyzing tracing metadata from Langfuse.""" + +import time +from datetime import datetime, timedelta, timezone +from functools import wraps +from typing import Any, Dict, List, Optional, Tuple + +from langfuse import Langfuse +from langfuse.api.core import ApiError +from langfuse.client import ObservationsView, TraceWithDetails +from rich import print +from rich.console import Console +from rich.table import Table + +console = Console() + +langfuse = Langfuse() + +# Prompt type identification keywords +PROMPT_IDENTIFIERS = { + "query_decomposition": [ + "MAIN RESEARCH QUERY", + "DIFFERENT DIMENSIONS", + "sub-questions", + ], + "search_query": ["Deep Research assistant", "effective search query"], + "synthesis": [ + "information synthesis", + "comprehensive answer", + "confidence level", + ], + "viewpoint_analysis": [ + "multi-perspective analysis", + "viewpoint categories", + ], + "reflection": ["critique and improve", "information gaps"], + "additional_synthesis": ["enhance the original synthesis"], + "conclusion_generation": [ + "Synthesis and Integration", + "Direct Response to Main Query", + ], + "executive_summary": [ + "executive summaries", + "Key Findings", + "250-400 words", + ], + "introduction": ["engaging introductions", "Context and 
Relevance"], +} + +# Rate limiting configuration +# Adjust these based on your Langfuse tier: +# - Hobby: 30 req/min for Other APIs -> ~2s between requests +# - Core: 100 req/min -> ~0.6s between requests +# - Pro: 1000 req/min -> ~0.06s between requests +RATE_LIMIT_DELAY = 0.1 # 100ms between requests (safe for most tiers) +MAX_RETRIES = 3 +INITIAL_BACKOFF = 1.0 # Initial backoff in seconds + +# Batch processing configuration +BATCH_DELAY = 0.5 # Additional delay between batches of requests + + +def rate_limited(func): + """Decorator to add rate limiting between API calls.""" + + @wraps(func) + def wrapper(*args, **kwargs): + time.sleep(RATE_LIMIT_DELAY) + return func(*args, **kwargs) + + return wrapper + + +def retry_with_backoff(func): + """Decorator to retry functions with exponential backoff on rate limit errors.""" + + @wraps(func) + def wrapper(*args, **kwargs): + backoff = INITIAL_BACKOFF + last_exception = None + + for attempt in range(MAX_RETRIES): + try: + return func(*args, **kwargs) + except ApiError as e: + if e.status_code == 429: # Rate limit error + last_exception = e + if attempt < MAX_RETRIES - 1: + wait_time = backoff * (2**attempt) + console.print( + f"[yellow]Rate limit hit. Retrying in {wait_time:.1f}s...[/yellow]" + ) + time.sleep(wait_time) + continue + raise + except Exception: + # For non-rate limit errors, raise immediately + raise + + # If we've exhausted all retries + if last_exception: + raise last_exception + + return wrapper + + +@rate_limited +@retry_with_backoff +def fetch_traces_safe(limit: Optional[int] = None) -> List[TraceWithDetails]: + """Safely fetch traces with rate limiting and retry logic.""" + return langfuse.fetch_traces(limit=limit).data + + +@rate_limited +@retry_with_backoff +def fetch_observations_safe(trace_id: str) -> List[ObservationsView]: + """Safely fetch observations with rate limiting and retry logic.""" + return langfuse.fetch_observations(trace_id=trace_id).data + + +def get_total_trace_cost(trace_id: str) -> float: + """Calculate the total cost for a single trace by summing all observation costs. + + Args: + trace_id: The ID of the trace to calculate cost for + + Returns: + Total cost across all observations in the trace + """ + try: + observations = fetch_observations_safe(trace_id=trace_id) + total_cost = 0.0 + + for obs in observations: + # Check multiple possible cost fields + if ( + hasattr(obs, "calculated_total_cost") + and obs.calculated_total_cost + ): + total_cost += obs.calculated_total_cost + elif hasattr(obs, "total_price") and obs.total_price: + total_cost += obs.total_price + elif hasattr(obs, "total_cost") and obs.total_cost: + total_cost += obs.total_cost + # If cost details are available, calculate from input/output costs + elif hasattr(obs, "calculated_input_cost") and hasattr( + obs, "calculated_output_cost" + ): + if obs.calculated_input_cost and obs.calculated_output_cost: + total_cost += ( + obs.calculated_input_cost + obs.calculated_output_cost + ) + + return total_cost + except Exception as e: + print(f"[red]Error calculating trace cost: {e}[/red]") + return 0.0 + + +def get_total_tokens_used(trace_id: str) -> Tuple[int, int]: + """Calculate total input and output tokens used for a trace. 
+ + Args: + trace_id: The ID of the trace to calculate tokens for + + Returns: + Tuple of (input_tokens, output_tokens) + """ + try: + observations = fetch_observations_safe(trace_id=trace_id) + total_input_tokens = 0 + total_output_tokens = 0 + + for obs in observations: + # Check for token fields in different possible locations + if hasattr(obs, "usage") and obs.usage: + if hasattr(obs.usage, "input") and obs.usage.input: + total_input_tokens += obs.usage.input + if hasattr(obs.usage, "output") and obs.usage.output: + total_output_tokens += obs.usage.output + # Also check for direct token fields + elif hasattr(obs, "promptTokens") and hasattr( + obs, "completionTokens" + ): + if obs.promptTokens: + total_input_tokens += obs.promptTokens + if obs.completionTokens: + total_output_tokens += obs.completionTokens + + return total_input_tokens, total_output_tokens + except Exception as e: + print(f"[red]Error calculating tokens: {e}[/red]") + return 0, 0 + + +def get_trace_stats(trace: TraceWithDetails) -> Dict[str, Any]: + """Get comprehensive statistics for a trace. + + Args: + trace: The trace object to analyze + + Returns: + Dictionary containing trace statistics including cost, latency, tokens, and metadata + """ + try: + # Get cost and token data + total_cost = get_total_trace_cost(trace.id) + input_tokens, output_tokens = get_total_tokens_used(trace.id) + + # Get observation count + observations = fetch_observations_safe(trace_id=trace.id) + observation_count = len(observations) + + # Extract model information from observations + models_used = set() + for obs in observations: + if hasattr(obs, "model") and obs.model: + models_used.add(obs.model) + + stats = { + "trace_id": trace.id, + "timestamp": trace.timestamp, + "total_cost": total_cost, + "latency_seconds": trace.latency + if hasattr(trace, "latency") + else 0, + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "observation_count": observation_count, + "models_used": list(models_used), + "metadata": trace.metadata if hasattr(trace, "metadata") else {}, + "tags": trace.tags if hasattr(trace, "tags") else [], + "user_id": trace.user_id if hasattr(trace, "user_id") else None, + "session_id": trace.session_id + if hasattr(trace, "session_id") + else None, + } + + # Add formatted latency + if stats["latency_seconds"]: + minutes = int(stats["latency_seconds"] // 60) + seconds = stats["latency_seconds"] % 60 + stats["latency_formatted"] = f"{minutes}m {seconds:.1f}s" + else: + stats["latency_formatted"] = "0m 0.0s" + + return stats + except Exception as e: + print(f"[red]Error getting trace stats: {e}[/red]") + return {} + + +def get_traces_by_name(name: str, limit: int = 1) -> List[TraceWithDetails]: + """Get traces by name using Langfuse API. + + Args: + name: The name of the trace to search for + limit: Maximum number of traces to return (default: 1) + + Returns: + List of traces matching the name + """ + try: + # Use the Langfuse API to get traces by name + traces_response = langfuse.get_traces(name=name, limit=limit) + return traces_response.data + except Exception as e: + print(f"[red]Error fetching traces by name: {e}[/red]") + return [] + + +def get_observations_for_trace(trace_id: str) -> List[ObservationsView]: + """Get all observations for a specific trace. 
+ + Args: + trace_id: The ID of the trace + + Returns: + List of observations for the trace + """ + try: + observations_response = langfuse.get_observations(trace_id=trace_id) + return observations_response.data + except Exception as e: + print(f"[red]Error fetching observations: {e}[/red]") + return [] + + +def filter_traces_by_date_range( + start_date: datetime, end_date: datetime, limit: Optional[int] = None +) -> List[TraceWithDetails]: + """Filter traces within a specific date range. + + Args: + start_date: Start of the date range (inclusive) + end_date: End of the date range (inclusive) + limit: Maximum number of traces to return + + Returns: + List of traces within the date range + """ + try: + # Ensure dates are timezone-aware + if start_date.tzinfo is None: + start_date = start_date.replace(tzinfo=timezone.utc) + if end_date.tzinfo is None: + end_date = end_date.replace(tzinfo=timezone.utc) + + # Fetch all traces (or up to API maximum limit of 100) + all_traces = fetch_traces_safe(limit=limit or 100) + + # Filter by date range + filtered_traces = [ + trace + for trace in all_traces + if start_date <= trace.timestamp <= end_date + ] + + # Sort by timestamp (most recent first) + filtered_traces.sort(key=lambda x: x.timestamp, reverse=True) + + # Apply limit if specified + if limit: + filtered_traces = filtered_traces[:limit] + + return filtered_traces + except Exception as e: + print(f"[red]Error filtering traces by date range: {e}[/red]") + return [] + + +def get_traces_last_n_days( + days: int, limit: Optional[int] = None +) -> List[TraceWithDetails]: + """Get traces from the last N days. + + Args: + days: Number of days to look back + limit: Maximum number of traces to return + + Returns: + List of traces from the last N days + """ + end_date = datetime.now(timezone.utc) + start_date = end_date - timedelta(days=days) + + return filter_traces_by_date_range(start_date, end_date, limit) + + +def get_trace_stats_batch( + traces: List[TraceWithDetails], show_progress: bool = True +) -> List[Dict[str, Any]]: + """Get statistics for multiple traces efficiently with progress tracking. + + Args: + traces: List of traces to analyze + show_progress: Whether to show progress bar + + Returns: + List of dictionaries containing trace statistics + """ + stats_list = [] + + for i, trace in enumerate(traces): + if show_progress and i % 5 == 0: + console.print( + f"[dim]Processing trace {i + 1}/{len(traces)}...[/dim]" + ) + + stats = get_trace_stats(trace) + stats_list.append(stats) + + return stats_list + + +def get_aggregate_stats_for_traces( + traces: List[TraceWithDetails], +) -> Dict[str, Any]: + """Calculate aggregate statistics for a list of traces. 
+ + Args: + traces: List of traces to analyze + + Returns: + Dictionary containing aggregate statistics + """ + if not traces: + return { + "trace_count": 0, + "total_cost": 0.0, + "total_input_tokens": 0, + "total_output_tokens": 0, + "total_tokens": 0, + "average_cost_per_trace": 0.0, + "average_latency_seconds": 0.0, + "total_observations": 0, + } + + total_cost = 0.0 + total_input_tokens = 0 + total_output_tokens = 0 + total_latency = 0.0 + total_observations = 0 + all_models = set() + + for trace in traces: + stats = get_trace_stats(trace) + total_cost += stats.get("total_cost", 0) + total_input_tokens += stats.get("input_tokens", 0) + total_output_tokens += stats.get("output_tokens", 0) + total_latency += stats.get("latency_seconds", 0) + total_observations += stats.get("observation_count", 0) + all_models.update(stats.get("models_used", [])) + + return { + "trace_count": len(traces), + "total_cost": total_cost, + "total_input_tokens": total_input_tokens, + "total_output_tokens": total_output_tokens, + "total_tokens": total_input_tokens + total_output_tokens, + "average_cost_per_trace": total_cost / len(traces) if traces else 0, + "average_latency_seconds": total_latency / len(traces) + if traces + else 0, + "total_observations": total_observations, + "models_used": list(all_models), + } + + +def display_trace_stats_table( + traces: List[TraceWithDetails], title: str = "Trace Statistics" +): + """Display trace statistics in a formatted table. + + Args: + traces: List of traces to display + title: Title for the table + """ + table = Table(title=title, show_header=True, header_style="bold magenta") + table.add_column("Trace ID", style="cyan", no_wrap=True) + table.add_column("Timestamp", style="yellow") + table.add_column("Cost ($)", justify="right", style="green") + table.add_column("Tokens (In/Out)", justify="right") + table.add_column("Latency", justify="right") + table.add_column("Observations", justify="right") + + for trace in traces[:10]: # Limit to 10 for display + stats = get_trace_stats(trace) + table.add_row( + stats["trace_id"][:12] + "...", + stats["timestamp"].strftime("%Y-%m-%d %H:%M"), + f"${stats['total_cost']:.4f}", + f"{stats['input_tokens']:,}/{stats['output_tokens']:,}", + stats["latency_formatted"], + str(stats["observation_count"]), + ) + + console.print(table) + + +def identify_prompt_type(observation: ObservationsView) -> str: + """Identify the prompt type based on keywords in the observation's input. + + Examines the system prompt in observation.input['messages'][0]['content'] + for unique keywords that identify each prompt type. + + Args: + observation: The observation to analyze + + Returns: + str: The prompt type name, or "unknown" if not identified + """ + try: + # Access the system prompt from the messages + if hasattr(observation, "input") and observation.input: + messages = observation.input.get("messages", []) + if messages and len(messages) > 0: + system_content = messages[0].get("content", "") + + # Check each prompt type's keywords + for prompt_type, keywords in PROMPT_IDENTIFIERS.items(): + # Check if any keyword is in the system prompt + for keyword in keywords: + if keyword in system_content: + return prompt_type + + return "unknown" + except Exception as e: + console.print( + f"[yellow]Warning: Could not identify prompt type: {e}[/yellow]" + ) + return "unknown" + + +def get_costs_by_prompt_type(trace_id: str) -> Dict[str, Dict[str, float]]: + """Get cost breakdown by prompt type for a given trace. 
+ + Uses observation.usage.input/output for token counts and + observation.calculated_total_cost for costs. + + Args: + trace_id: The ID of the trace to analyze + + Returns: + Dict mapping prompt_type to { + 'cost': float, + 'input_tokens': int, + 'output_tokens': int, + 'count': int # number of calls + } + """ + try: + observations = fetch_observations_safe(trace_id=trace_id) + prompt_metrics = {} + + for obs in observations: + # Identify prompt type + prompt_type = identify_prompt_type(obs) + + # Initialize metrics for this prompt type if needed + if prompt_type not in prompt_metrics: + prompt_metrics[prompt_type] = { + "cost": 0.0, + "input_tokens": 0, + "output_tokens": 0, + "count": 0, + } + + # Add cost + cost = 0.0 + if ( + hasattr(obs, "calculated_total_cost") + and obs.calculated_total_cost + ): + cost = obs.calculated_total_cost + prompt_metrics[prompt_type]["cost"] += cost + + # Add tokens + if hasattr(obs, "usage") and obs.usage: + if hasattr(obs.usage, "input") and obs.usage.input: + prompt_metrics[prompt_type]["input_tokens"] += ( + obs.usage.input + ) + if hasattr(obs.usage, "output") and obs.usage.output: + prompt_metrics[prompt_type]["output_tokens"] += ( + obs.usage.output + ) + + # Increment count + prompt_metrics[prompt_type]["count"] += 1 + + return prompt_metrics + except Exception as e: + print(f"[red]Error getting costs by prompt type: {e}[/red]") + return {} + + +def get_prompt_type_statistics(trace_id: str) -> Dict[str, Dict[str, Any]]: + """Get detailed statistics for each prompt type. + + Args: + trace_id: The ID of the trace to analyze + + Returns: + Dict mapping prompt_type to { + 'cost': float, + 'input_tokens': int, + 'output_tokens': int, + 'count': int, + 'avg_cost_per_call': float, + 'avg_input_tokens': float, + 'avg_output_tokens': float, + 'percentage_of_total_cost': float + } + """ + try: + # Get basic metrics + prompt_metrics = get_costs_by_prompt_type(trace_id) + + # Calculate total cost for percentage calculation + total_cost = sum( + metrics["cost"] for metrics in prompt_metrics.values() + ) + + # Enhance with statistics + enhanced_metrics = {} + for prompt_type, metrics in prompt_metrics.items(): + count = metrics["count"] + enhanced_metrics[prompt_type] = { + "cost": metrics["cost"], + "input_tokens": metrics["input_tokens"], + "output_tokens": metrics["output_tokens"], + "count": count, + "avg_cost_per_call": metrics["cost"] / count + if count > 0 + else 0, + "avg_input_tokens": metrics["input_tokens"] / count + if count > 0 + else 0, + "avg_output_tokens": metrics["output_tokens"] / count + if count > 0 + else 0, + "percentage_of_total_cost": ( + metrics["cost"] / total_cost * 100 + ) + if total_cost > 0 + else 0, + } + + return enhanced_metrics + except Exception as e: + print(f"[red]Error getting prompt type statistics: {e}[/red]") + return {} + + +if __name__ == "__main__": + print( + "[bold cyan]ZenML Deep Research - Tracing Metadata Utilities Demo[/bold cyan]\n" + ) + + try: + # Fetch recent traces + print("[yellow]Fetching recent traces...[/yellow]") + traces = fetch_traces_safe(limit=5) + + if not traces: + print("[red]No traces found![/red]") + exit(1) + except ApiError as e: + if e.status_code == 429: + print("[red]Rate limit exceeded. 
Please try again later.[/red]") + print( + "[yellow]Tip: Consider upgrading your Langfuse tier for higher rate limits.[/yellow]" + ) + else: + print(f"[red]API Error: {e}[/red]") + exit(1) + except Exception as e: + print(f"[red]Error fetching traces: {e}[/red]") + exit(1) + + # Demo 1: Get stats for a single trace + print("\n[bold]1. Single Trace Statistics:[/bold]") + first_trace = traces[0] + stats = get_trace_stats(first_trace) + + console.print(f"Trace ID: [cyan]{stats['trace_id']}[/cyan]") + console.print(f"Timestamp: [yellow]{stats['timestamp']}[/yellow]") + console.print(f"Total Cost: [green]${stats['total_cost']:.4f}[/green]") + console.print( + f"Tokens - Input: [blue]{stats['input_tokens']:,}[/blue], Output: [blue]{stats['output_tokens']:,}[/blue]" + ) + console.print(f"Latency: [magenta]{stats['latency_formatted']}[/magenta]") + console.print(f"Observations: [white]{stats['observation_count']}[/white]") + console.print( + f"Models Used: [cyan]{', '.join(stats['models_used'])}[/cyan]" + ) + + # Demo 2: Get traces from last 7 days + print("\n[bold]2. Traces from Last 7 Days:[/bold]") + recent_traces = get_traces_last_n_days(7, limit=10) + print( + f"Found [green]{len(recent_traces)}[/green] traces in the last 7 days" + ) + + if recent_traces: + display_trace_stats_table(recent_traces, "Last 7 Days Traces") + + # Demo 3: Filter traces by date range + print("\n[bold]3. Filter Traces by Date Range:[/bold]") + end_date = datetime.now(timezone.utc) + start_date = end_date - timedelta(days=3) + + filtered_traces = filter_traces_by_date_range(start_date, end_date) + print( + f"Found [green]{len(filtered_traces)}[/green] traces between {start_date.strftime('%Y-%m-%d')} and {end_date.strftime('%Y-%m-%d')}" + ) + + # Demo 4: Aggregate statistics + print("\n[bold]4. Aggregate Statistics for All Recent Traces:[/bold]") + agg_stats = get_aggregate_stats_for_traces(traces) + + table = Table( + title="Aggregate Statistics", + show_header=True, + header_style="bold magenta", + ) + table.add_column("Metric", style="cyan") + table.add_column("Value", justify="right", style="yellow") + + table.add_row("Total Traces", str(agg_stats["trace_count"])) + table.add_row("Total Cost", f"${agg_stats['total_cost']:.4f}") + table.add_row( + "Average Cost per Trace", f"${agg_stats['average_cost_per_trace']:.4f}" + ) + table.add_row("Total Input Tokens", f"{agg_stats['total_input_tokens']:,}") + table.add_row( + "Total Output Tokens", f"{agg_stats['total_output_tokens']:,}" + ) + table.add_row("Total Tokens", f"{agg_stats['total_tokens']:,}") + table.add_row( + "Average Latency", f"{agg_stats['average_latency_seconds']:.1f}s" + ) + table.add_row("Total Observations", str(agg_stats["total_observations"])) + + console.print(table) + + # Demo 5: Cost breakdown by observation + print("\n[bold]5. 
Cost Breakdown for First Trace:[/bold]") + observations = fetch_observations_safe(trace_id=first_trace.id) + + if observations: + table = Table( + title="Observation Cost Breakdown", + show_header=True, + header_style="bold magenta", + ) + table.add_column("Observation", style="cyan", no_wrap=True) + table.add_column("Model", style="yellow") + table.add_column("Tokens (In/Out)", justify="right") + table.add_column("Cost", justify="right", style="green") + + for i, obs in enumerate(observations[:5]): # Show first 5 + cost = 0.0 + if hasattr(obs, "calculated_total_cost"): + cost = obs.calculated_total_cost or 0.0 + + in_tokens = 0 + out_tokens = 0 + if hasattr(obs, "usage") and obs.usage: + in_tokens = obs.usage.input or 0 + out_tokens = obs.usage.output or 0 + elif hasattr(obs, "promptTokens"): + in_tokens = obs.promptTokens or 0 + out_tokens = obs.completionTokens or 0 + + table.add_row( + f"Obs {i + 1}", + obs.model if hasattr(obs, "model") else "Unknown", + f"{in_tokens:,}/{out_tokens:,}", + f"${cost:.4f}", + ) + + console.print(table) From 142f2aa19d3806ab01ab6a115a9bb1ccde0edf2f Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 27 May 2025 22:09:47 +0200 Subject: [PATCH 03/11] Refactor PromptBundle into individual prompts --- deep_research/materializers/__init__.py | 4 +- .../materializers/prompt_materializer.py | 390 ++++++++++++++ .../materializers/prompts_materializer.py | 509 ------------------ .../pipelines/parallel_research_pipeline.py | 29 +- deep_research/run.py | 26 +- deep_research/steps/cross_viewpoint_step.py | 13 +- .../steps/execute_approved_searches_step.py | 13 +- .../steps/generate_reflection_step.py | 12 +- .../steps/initialize_prompts_step.py | 125 ++++- .../steps/process_sub_question_step.py | 21 +- .../steps/pydantic_final_report_step.py | 78 ++- .../steps/query_decomposition_step.py | 13 +- deep_research/tests/test_prompt_loader.py | 108 ---- deep_research/tests/test_prompt_models.py | 191 ++----- deep_research/utils/prompt_loader.py | 136 ----- deep_research/utils/prompt_models.py | 98 +--- deep_research/utils/pydantic_models.py | 62 +++ 17 files changed, 717 insertions(+), 1111 deletions(-) create mode 100644 deep_research/materializers/prompt_materializer.py delete mode 100644 deep_research/materializers/prompts_materializer.py delete mode 100644 deep_research/tests/test_prompt_loader.py delete mode 100644 deep_research/utils/prompt_loader.py diff --git a/deep_research/materializers/__init__.py b/deep_research/materializers/__init__.py index 1479f72b..260eedb3 100644 --- a/deep_research/materializers/__init__.py +++ b/deep_research/materializers/__init__.py @@ -7,14 +7,14 @@ """ from .approval_decision_materializer import ApprovalDecisionMaterializer -from .prompts_materializer import PromptsBundleMaterializer +from .prompt_materializer import PromptMaterializer from .pydantic_materializer import ResearchStateMaterializer from .reflection_output_materializer import ReflectionOutputMaterializer from .tracing_metadata_materializer import TracingMetadataMaterializer __all__ = [ "ApprovalDecisionMaterializer", - "PromptsBundleMaterializer", + "PromptMaterializer", "ReflectionOutputMaterializer", "ResearchStateMaterializer", "TracingMetadataMaterializer", diff --git a/deep_research/materializers/prompt_materializer.py b/deep_research/materializers/prompt_materializer.py new file mode 100644 index 00000000..835c53a6 --- /dev/null +++ b/deep_research/materializers/prompt_materializer.py @@ -0,0 +1,390 @@ +"""Materializer for individual Prompt with 
custom HTML visualization. + +This module provides a materializer that creates beautiful HTML visualizations +for individual prompts in the ZenML dashboard. +""" + +import os +from typing import Dict + +from utils.pydantic_models import Prompt +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class PromptMaterializer(PydanticMaterializer): + """Materializer for Prompt with custom visualization.""" + + ASSOCIATED_TYPES = (Prompt,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: Prompt + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the Prompt. + + Args: + data: The Prompt to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + # Generate an HTML visualization + visualization_path = os.path.join(self.uri, "prompt.html") + + # Create HTML content + html_content = self._generate_visualization_html(data) + + # Write the HTML content to a file + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + # Return the visualization path and type + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, prompt: Prompt) -> str: + """Generate HTML visualization for a single prompt. + + Args: + prompt: The Prompt to visualize + + Returns: + HTML string + """ + # Determine tag colors + tag_html = "" + if prompt.tags: + tag_colors = { + "search": "search", + "synthesis": "synthesis", + "analysis": "analysis", + "reflection": "reflection", + "report": "report", + "query": "query", + "decomposition": "decomposition", + "viewpoint": "viewpoint", + "conclusion": "conclusion", + "summary": "summary", + "introduction": "introduction", + } + + tag_html = '
      ' + for tag in prompt.tags: + tag_class = tag_colors.get(tag, "default") + tag_html += f'{tag}' + tag_html += "
      " + + # Create HTML content + html = f""" + + + + {prompt.name} - Prompt + + + +
      +
      +

      + 🎯 {prompt.name} + v{prompt.version} +

      + {f'

      {prompt.description}

      ' if prompt.description else ""} + {tag_html} +
      + +
      +
      + {len(prompt.content.split())} + Words +
      +
      + {len(prompt.content)} + Characters +
      +
      + {len(prompt.content.splitlines())} + Lines +
      +
      + +
      +

      📝 Prompt Content

      +
      + + {self._escape_html(prompt.content)} +
      +
      +
+ + + + + """ + + return html + + def _escape_html(self, text: str) -> str: + """Escape HTML special characters. + + Args: + text: Text to escape + + Returns: + Escaped text + """ + return ( + text.replace("&", "&amp;") + .replace("<", "&lt;") + .replace(">", "&gt;") + .replace('"', "&quot;") + .replace("'", "&#39;") + ) diff --git a/deep_research/materializers/prompts_materializer.py b/deep_research/materializers/prompts_materializer.py deleted file mode 100644 index 96e35dde..00000000 --- a/deep_research/materializers/prompts_materializer.py +++ /dev/null @@ -1,509 +0,0 @@ -"""Materializer for PromptsBundle with custom HTML visualization. - -This module provides a materializer that creates beautiful HTML visualizations -for prompt bundles in the ZenML dashboard. -""" - -import os -from typing import Dict - -from utils.prompt_models import PromptsBundle -from zenml.enums import ArtifactType, VisualizationType -from zenml.io import fileio -from zenml.materializers import PydanticMaterializer - - -class PromptsBundleMaterializer(PydanticMaterializer): - """Materializer for PromptsBundle with custom visualization.""" - - ASSOCIATED_TYPES = (PromptsBundle,) - ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA - - def save_visualizations( - self, data: PromptsBundle - ) -> Dict[str, VisualizationType]: - """Create and save visualizations for the PromptsBundle. - - Args: - data: The PromptsBundle to visualize - - Returns: - Dictionary mapping file paths to visualization types - """ - # Generate an HTML visualization - visualization_path = os.path.join(self.uri, "prompts_bundle.html") - - # Create HTML content - html_content = self._generate_visualization_html(data) - - # Write the HTML content to a file - with fileio.open(visualization_path, "w") as f: - f.write(html_content) - - # Return the visualization path and type - return {visualization_path: VisualizationType.HTML} - - def _generate_visualization_html(self, bundle: PromptsBundle) -> str: - """Generate HTML visualization for the prompts bundle. - - Args: - bundle: The PromptsBundle to visualize - - Returns: - HTML string - """ - # Create HTML content - html = f""" - - - - Prompts Bundle - {bundle.created_at} - - - -
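As used elsewhere in this patch, the new materializer is attached to a step through the output_materializers argument, so any step that returns a Prompt gets the HTML visualization rendered in the ZenML dashboard. A minimal sketch of that wiring, with a hypothetical step name and placeholder prompt text:

from typing import Annotated

from materializers.prompt_materializer import PromptMaterializer
from utils.pydantic_models import Prompt
from zenml import step


@step(output_materializers=PromptMaterializer)
def example_prompt_step() -> Annotated[Prompt, "example_prompt"]:
    # Hypothetical prompt used only to illustrate the materializer wiring.
    return Prompt(
        content="Summarize the key findings in three bullet points.",
        name="example_prompt",
        description="Illustrative prompt for the PromptMaterializer wiring sketch",
        tags=["example"],
    )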
      -
      -

      🎯 Prompts Bundle

      - -
      -
      - {len(bundle.list_all_prompts())} - Total Prompts -
      -
      - {len([p for p in bundle.list_all_prompts().values() if p.tags])} - Tagged Prompts -
      -
      - {len(bundle.custom_prompts)} - Custom Prompts -
      -
      -
      - - - -
      - """ - - # Add each prompt - prompts = [ - ( - "search_query_prompt", - bundle.search_query_prompt, - ["search", "query"], - ), - ( - "query_decomposition_prompt", - bundle.query_decomposition_prompt, - ["analysis", "decomposition"], - ), - ( - "synthesis_prompt", - bundle.synthesis_prompt, - ["synthesis", "integration"], - ), - ( - "viewpoint_analysis_prompt", - bundle.viewpoint_analysis_prompt, - ["analysis", "viewpoint"], - ), - ( - "reflection_prompt", - bundle.reflection_prompt, - ["reflection", "critique"], - ), - ( - "additional_synthesis_prompt", - bundle.additional_synthesis_prompt, - ["synthesis", "enhancement"], - ), - ( - "conclusion_generation_prompt", - bundle.conclusion_generation_prompt, - ["report", "conclusion"], - ), - ] - - for prompt_type, prompt, default_tags in prompts: - # Use provided tags or default tags - tags = prompt.tags if prompt.tags else default_tags - - html += f""" -
      -
      -

      {prompt.name}

      - v{prompt.version} -
      - {f'

      {prompt.description}

      ' if prompt.description else ""} -
      - {"".join([f'{tag}' for tag in tags])} -
      -
      - - {self._escape_html(prompt.content)} -
      - -
      - """ - - # Add custom prompts if any - for name, prompt in bundle.custom_prompts.items(): - tags = prompt.tags if prompt.tags else ["custom"] - html += f""" -
      -
      -

      {prompt.name}

      - v{prompt.version} -
      - {f'

      {prompt.description}

      ' if prompt.description else ""} -
      - {"".join([f'{tag}' for tag in tags])} -
      -
      - - {self._escape_html(prompt.content)} -
      - -
      - """ - - html += """ -
      - -
      - - - - - """ - - return html - - def _escape_html(self, text: str) -> str: - """Escape HTML special characters. - - Args: - text: Text to escape - - Returns: - Escaped text - """ - return ( - text.replace("&", "&") - .replace("<", "<") - .replace(">", ">") - .replace('"', """) - .replace("'", "'") - ) diff --git a/deep_research/pipelines/parallel_research_pipeline.py b/deep_research/pipelines/parallel_research_pipeline.py index bd7afe14..669b4824 100644 --- a/deep_research/pipelines/parallel_research_pipeline.py +++ b/deep_research/pipelines/parallel_research_pipeline.py @@ -43,8 +43,18 @@ def parallelized_deep_research_pipeline( Returns: Formatted research report as HTML """ - # Initialize prompts bundle for tracking - prompts_bundle = initialize_prompts_step(pipeline_version="1.0.0") + # Initialize individual prompts for tracking + ( + search_query_prompt, + query_decomposition_prompt, + synthesis_prompt, + viewpoint_analysis_prompt, + reflection_prompt, + additional_synthesis_prompt, + conclusion_generation_prompt, + executive_summary_prompt, + introduction_prompt, + ) = initialize_prompts_step(pipeline_version="1.0.0") # Initialize the research state with the main query state = ResearchState(main_query=query) @@ -52,7 +62,7 @@ def parallelized_deep_research_pipeline( # Step 1: Decompose the query into sub-questions, limiting to max_sub_questions decomposed_state = initial_query_decomposition_step( state=state, - prompts_bundle=prompts_bundle, + query_decomposition_prompt=query_decomposition_prompt, max_sub_questions=max_sub_questions, langfuse_project_name=langfuse_project_name, ) @@ -64,7 +74,8 @@ def parallelized_deep_research_pipeline( # Process the i-th sub-question (if it exists) sub_state = process_sub_question_step( state=decomposed_state, - prompts_bundle=prompts_bundle, + search_query_prompt=search_query_prompt, + synthesis_prompt=synthesis_prompt, question_index=i, search_provider=search_provider, search_mode=search_mode, @@ -87,7 +98,7 @@ def parallelized_deep_research_pipeline( # Continue with subsequent steps analyzed_state = cross_viewpoint_analysis_step( state=merged_state, - prompts_bundle=prompts_bundle, + viewpoint_analysis_prompt=viewpoint_analysis_prompt, langfuse_project_name=langfuse_project_name, ) @@ -95,7 +106,7 @@ def parallelized_deep_research_pipeline( # Step 1: Generate reflection and recommendations (no searches yet) reflection_output = generate_reflection_step( state=analyzed_state, - prompts_bundle=prompts_bundle, + reflection_prompt=reflection_prompt, langfuse_project_name=langfuse_project_name, ) @@ -111,7 +122,7 @@ def parallelized_deep_research_pipeline( reflected_state = execute_approved_searches_step( reflection_output=reflection_output, approval_decision=approval_decision, - prompts_bundle=prompts_bundle, + additional_synthesis_prompt=additional_synthesis_prompt, search_provider=search_provider, search_mode=search_mode, num_results_per_search=num_results_per_search, @@ -122,7 +133,9 @@ def parallelized_deep_research_pipeline( # This returns a tuple (state, html_report) final_state, final_report = pydantic_final_report_step( state=reflected_state, - prompts_bundle=prompts_bundle, + conclusion_generation_prompt=conclusion_generation_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, langfuse_project_name=langfuse_project_name, ) diff --git a/deep_research/run.py b/deep_research/run.py index aa1745c3..76c6499c 100644 --- a/deep_research/run.py +++ b/deep_research/run.py @@ -148,18 +148,18 @@ 
help="Number of search results to return per query (default: 3)", ) def main( - mode: str = None, - config: str = "configs/enhanced_research.yaml", - no_cache: bool = False, - log_file: str = None, - debug: bool = False, - query: str = None, - max_sub_questions: int = 10, - require_approval: bool = False, - approval_timeout: int = 3600, - search_provider: str = None, - search_mode: str = "auto", - num_results: int = 3, + mode, + config, + no_cache, + log_file, + debug, + query, + max_sub_questions, + require_approval, + approval_timeout, + search_provider, + search_mode, + num_results, ): """Run the deep research pipeline. @@ -184,7 +184,7 @@ def main( # Apply mode presets if specified if mode: mode_config = RESEARCH_MODES[mode.lower()] - logger.info(f"\n{'=' * 80}") + logger.info("\n" + "=" * 80) logger.info(f"Using research mode: {mode.upper()}") logger.info(f"Description: {mode_config['description']}") diff --git a/deep_research/steps/cross_viewpoint_step.py b/deep_research/steps/cross_viewpoint_step.py index ad9c5ddb..506df12c 100644 --- a/deep_research/steps/cross_viewpoint_step.py +++ b/deep_research/steps/cross_viewpoint_step.py @@ -8,8 +8,8 @@ safe_json_loads, ) from utils.llm_utils import run_llm_completion -from utils.prompt_models import PromptsBundle from utils.pydantic_models import ( + Prompt, ResearchState, ViewpointAnalysis, ViewpointTension, @@ -22,7 +22,7 @@ @step(output_materializers=ResearchStateMaterializer) def cross_viewpoint_analysis_step( state: ResearchState, - prompts_bundle: PromptsBundle, + viewpoint_analysis_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", viewpoint_categories: List[str] = [ "scientific", @@ -38,7 +38,7 @@ def cross_viewpoint_analysis_step( Args: state: The current research state - prompts_bundle: Bundle containing all prompts for the pipeline + viewpoint_analysis_prompt: Prompt for viewpoint analysis llm_model: The model to use for viewpoint analysis viewpoint_categories: Categories of viewpoints to analyze @@ -69,15 +69,10 @@ def cross_viewpoint_analysis_step( # Perform viewpoint analysis try: logger.info(f"Calling {llm_model} for viewpoint analysis") - # Get the prompt from the bundle - system_prompt = prompts_bundle.get_prompt_content( - "viewpoint_analysis_prompt" - ) - # Use the run_llm_completion function from llm_utils content = run_llm_completion( prompt=json.dumps(analysis_input), - system_prompt=system_prompt, + system_prompt=str(viewpoint_analysis_prompt), model=llm_model, # Model name will be prefixed in the function max_tokens=3000, # Further increased for more comprehensive viewpoint analysis project=langfuse_project_name, diff --git a/deep_research/steps/execute_approved_searches_step.py b/deep_research/steps/execute_approved_searches_step.py index 90a15718..4c83be66 100644 --- a/deep_research/steps/execute_approved_searches_step.py +++ b/deep_research/steps/execute_approved_searches_step.py @@ -9,9 +9,9 @@ get_structured_llm_output, is_text_relevant, ) -from utils.prompt_models import PromptsBundle from utils.pydantic_models import ( ApprovalDecision, + Prompt, ReflectionMetadata, ReflectionOutput, ResearchState, @@ -43,7 +43,7 @@ def create_enhanced_info_copy(synthesized_info): def execute_approved_searches_step( reflection_output: ReflectionOutput, approval_decision: ApprovalDecision, - prompts_bundle: PromptsBundle, + additional_synthesis_prompt: Prompt, num_results_per_search: int = 3, cap_search_length: int = 20000, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", @@ -221,17 +221,10 
@@ def execute_approved_searches_step( ], } - # Get the prompt from the bundle - additional_synthesis_prompt = ( - prompts_bundle.get_prompt_content( - "additional_synthesis_prompt" - ) - ) - # Use the utility function for enhancement enhanced_synthesis = get_structured_llm_output( prompt=json.dumps(enhancement_input), - system_prompt=additional_synthesis_prompt, + system_prompt=str(additional_synthesis_prompt), model=llm_model, fallback_response={ "enhanced_synthesis": enhanced_info[ diff --git a/deep_research/steps/generate_reflection_step.py b/deep_research/steps/generate_reflection_step.py index e8547af6..9fbbe516 100644 --- a/deep_research/steps/generate_reflection_step.py +++ b/deep_research/steps/generate_reflection_step.py @@ -7,8 +7,7 @@ ReflectionOutputMaterializer, ) from utils.llm_utils import get_structured_llm_output -from utils.prompt_models import PromptsBundle -from utils.pydantic_models import ReflectionOutput, ResearchState +from utils.pydantic_models import Prompt, ReflectionOutput, ResearchState from zenml import log_metadata, step logger = logging.getLogger(__name__) @@ -17,7 +16,7 @@ @step(output_materializers=ReflectionOutputMaterializer) def generate_reflection_step( state: ResearchState, - prompts_bundle: PromptsBundle, + reflection_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", ) -> Annotated[ReflectionOutput, "reflection_output"]: @@ -29,7 +28,7 @@ def generate_reflection_step( Args: state: The current research state - prompts_bundle: Bundle containing all prompts for the pipeline + reflection_prompt: Prompt for generating reflection llm_model: The model to use for reflection Returns: @@ -77,9 +76,6 @@ def generate_reflection_step( # Get reflection critique logger.info(f"Generating self-critique via {llm_model}") - # Get the prompt from the bundle - reflection_prompt = prompts_bundle.get_prompt_content("reflection_prompt") - # Define fallback for reflection result fallback_reflection = { "critique": [], @@ -90,7 +86,7 @@ def generate_reflection_step( # Use utility function to get structured output reflection_result = get_structured_llm_output( prompt=json.dumps(reflection_input), - system_prompt=reflection_prompt, + system_prompt=str(reflection_prompt), model=llm_model, fallback_response=fallback_reflection, project=langfuse_project_name, diff --git a/deep_research/steps/initialize_prompts_step.py b/deep_research/steps/initialize_prompts_step.py index 377c6df4..fbb5eed8 100644 --- a/deep_research/steps/initialize_prompts_step.py +++ b/deep_research/steps/initialize_prompts_step.py @@ -1,45 +1,132 @@ -"""Step to initialize and track prompts as artifacts. +"""Step to initialize and track prompts as individual artifacts. -This step creates a PromptsBundle artifact at the beginning of the pipeline, +This step creates individual Prompt artifacts at the beginning of the pipeline, making all prompts trackable and versioned in ZenML. 
""" import logging -from typing import Annotated +from typing import Annotated, Tuple -from materializers.prompts_materializer import PromptsBundleMaterializer -from utils.prompt_loader import load_prompts_bundle -from utils.prompt_models import PromptsBundle +from materializers.prompt_materializer import PromptMaterializer +from utils import prompts +from utils.pydantic_models import Prompt from zenml import step logger = logging.getLogger(__name__) -@step(output_materializers=PromptsBundleMaterializer) +@step(output_materializers=PromptMaterializer) def initialize_prompts_step( pipeline_version: str = "1.1.0", -) -> Annotated[PromptsBundle, "prompts_bundle"]: - """Initialize the prompts bundle for the pipeline. +) -> Tuple[ + Annotated[Prompt, "search_query_prompt"], + Annotated[Prompt, "query_decomposition_prompt"], + Annotated[Prompt, "synthesis_prompt"], + Annotated[Prompt, "viewpoint_analysis_prompt"], + Annotated[Prompt, "reflection_prompt"], + Annotated[Prompt, "additional_synthesis_prompt"], + Annotated[Prompt, "conclusion_generation_prompt"], + Annotated[Prompt, "executive_summary_prompt"], + Annotated[Prompt, "introduction_prompt"], +]: + """Initialize individual prompts for the pipeline. This step loads all prompts from the prompts.py module and creates - a PromptsBundle artifact that can be tracked and visualized in ZenML. + individual Prompt artifacts that can be tracked and visualized in ZenML. Args: pipeline_version: Version of the pipeline using these prompts Returns: - PromptsBundle containing all prompts used in the pipeline + Tuple of individual Prompt artifacts used in the pipeline """ logger.info( - f"Initializing prompts bundle for pipeline version {pipeline_version}" + f"Initializing prompts for pipeline version {pipeline_version}" ) - # Load all prompts into a bundle - prompts_bundle = load_prompts_bundle(pipeline_version=pipeline_version) + # Create individual prompt instances + search_query_prompt = Prompt( + content=prompts.DEFAULT_SEARCH_QUERY_PROMPT, + name="search_query_prompt", + description="Generates effective search queries from sub-questions", + version="1.0.0", + tags=["search", "query", "information-gathering"], + ) + + query_decomposition_prompt = Prompt( + content=prompts.QUERY_DECOMPOSITION_PROMPT, + name="query_decomposition_prompt", + description="Breaks down complex research queries into specific sub-questions", + version="1.0.0", + tags=["analysis", "decomposition", "planning"], + ) + + synthesis_prompt = Prompt( + content=prompts.SYNTHESIS_PROMPT, + name="synthesis_prompt", + description="Synthesizes search results into comprehensive answers for sub-questions", + version="1.1.0", + tags=["synthesis", "integration", "analysis"], + ) + + viewpoint_analysis_prompt = Prompt( + content=prompts.VIEWPOINT_ANALYSIS_PROMPT, + name="viewpoint_analysis_prompt", + description="Analyzes synthesized answers across different perspectives and viewpoints", + version="1.1.0", + tags=["analysis", "viewpoint", "perspective"], + ) + + reflection_prompt = Prompt( + content=prompts.REFLECTION_PROMPT, + name="reflection_prompt", + description="Evaluates research and identifies gaps, biases, and areas for improvement", + version="1.0.0", + tags=["reflection", "critique", "improvement"], + ) - # Log some statistics - all_prompts = prompts_bundle.list_all_prompts() - logger.info(f"Loaded {len(all_prompts)} prompts into bundle") - logger.info(f"Prompts: {', '.join(all_prompts.keys())}") + additional_synthesis_prompt = Prompt( + content=prompts.ADDITIONAL_SYNTHESIS_PROMPT, 
+ name="additional_synthesis_prompt", + description="Enhances original synthesis with new information and addresses critique points", + version="1.1.0", + tags=["synthesis", "enhancement", "integration"], + ) + + conclusion_generation_prompt = Prompt( + content=prompts.CONCLUSION_GENERATION_PROMPT, + name="conclusion_generation_prompt", + description="Synthesizes all research findings into a comprehensive conclusion", + version="1.0.0", + tags=["report", "conclusion", "synthesis"], + ) + + executive_summary_prompt = Prompt( + content=prompts.EXECUTIVE_SUMMARY_GENERATION_PROMPT, + name="executive_summary_prompt", + description="Creates a compelling, insight-driven executive summary", + version="1.1.0", + tags=["report", "summary", "insights"], + ) - return prompts_bundle + introduction_prompt = Prompt( + content=prompts.INTRODUCTION_GENERATION_PROMPT, + name="introduction_prompt", + description="Creates a contextual, engaging introduction", + version="1.1.0", + tags=["report", "introduction", "context"], + ) + + logger.info(f"Loaded 9 individual prompts") + + return ( + search_query_prompt, + query_decomposition_prompt, + synthesis_prompt, + viewpoint_analysis_prompt, + reflection_prompt, + additional_synthesis_prompt, + conclusion_generation_prompt, + executive_summary_prompt, + introduction_prompt, + ) diff --git a/deep_research/steps/process_sub_question_step.py b/deep_research/steps/process_sub_question_step.py index f6b077c7..cfc309c2 100644 --- a/deep_research/steps/process_sub_question_step.py +++ b/deep_research/steps/process_sub_question_step.py @@ -12,8 +12,7 @@ from materializers.pydantic_materializer import ResearchStateMaterializer from utils.llm_utils import synthesize_information -from utils.prompt_models import PromptsBundle -from utils.pydantic_models import ResearchState, SynthesizedInfo +from utils.pydantic_models import Prompt, ResearchState, SynthesizedInfo from utils.search_utils import ( generate_search_query, search_and_extract_results, @@ -26,7 +25,8 @@ @step(output_materializers=ResearchStateMaterializer) def process_sub_question_step( state: ResearchState, - prompts_bundle: PromptsBundle, + search_query_prompt: Prompt, + synthesis_prompt: Prompt, question_index: int, llm_model_search: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", llm_model_synthesis: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", @@ -43,7 +43,8 @@ def process_sub_question_step( Args: state: The original research state with all sub-questions - prompts_bundle: Bundle containing all prompts for the pipeline + search_query_prompt: Prompt for generating search queries + synthesis_prompt: Prompt for synthesizing search results question_index: The index of the sub-question to process llm_model_search: Model to use for search query generation llm_model_synthesis: Model to use for synthesis @@ -100,14 +101,11 @@ def process_sub_question_step( # === INFORMATION GATHERING === search_phase_start = time.time() - # Generate search query with prompt from bundle - search_query_prompt = prompts_bundle.get_prompt_content( - "search_query_prompt" - ) + # Generate search query with prompt search_query_data = generate_search_query( sub_question=sub_question, model=llm_model_search, - system_prompt=search_query_prompt, + system_prompt=str(search_query_prompt), project=langfuse_project_name, ) search_query = search_query_data.get( @@ -175,12 +173,11 @@ def process_sub_question_step( "sources": sources, } - # Synthesize information with prompt from bundle - synthesis_prompt = 
prompts_bundle.get_prompt_content("synthesis_prompt") + # Synthesize information with prompt synthesis_result = synthesize_information( synthesis_input=synthesis_input, model=llm_model_synthesis, - system_prompt=synthesis_prompt, + system_prompt=str(synthesis_prompt), project=langfuse_project_name, ) diff --git a/deep_research/steps/pydantic_final_report_step.py b/deep_research/steps/pydantic_final_report_step.py index dc05b1bb..1c23f49c 100644 --- a/deep_research/steps/pydantic_final_report_step.py +++ b/deep_research/steps/pydantic_final_report_step.py @@ -17,13 +17,12 @@ remove_reasoning_from_output, ) from utils.llm_utils import run_llm_completion -from utils.prompt_models import PromptsBundle from utils.prompts import ( STATIC_HTML_TEMPLATE, SUB_QUESTION_TEMPLATE, VIEWPOINT_ANALYSIS_TEMPLATE, ) -from utils.pydantic_models import ResearchState +from utils.pydantic_models import Prompt, ResearchState from zenml import log_metadata, step from zenml.types import HTMLString @@ -139,7 +138,7 @@ def code_block_replace(match): def generate_executive_summary( state: ResearchState, - prompts_bundle: PromptsBundle, + executive_summary_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", ) -> str: @@ -147,7 +146,7 @@ def generate_executive_summary( Args: state: The current research state - prompts_bundle: Bundle containing all prompts for the pipeline + executive_summary_prompt: Prompt for generating executive summary llm_model: The model to use for generation langfuse_project_name: Name of the Langfuse project for tracking @@ -179,24 +178,19 @@ def generate_executive_summary( for tension in state.viewpoint_analysis.areas_of_tension[:2]: context += f"- {tension.topic}\n" - # Get the executive summary prompt + # Use the executive summary prompt try: - executive_summary_prompt = prompts_bundle.get_prompt_content( - "executive_summary_prompt" - ) + executive_summary_prompt_str = str(executive_summary_prompt) logger.info("Successfully retrieved executive_summary_prompt") except Exception as e: logger.error(f"Failed to get executive_summary_prompt: {e}") - logger.info( - f"Available prompts: {list(prompts_bundle.list_all_prompts().keys())}" - ) return generate_fallback_executive_summary(state) try: # Call LLM to generate executive summary result = run_llm_completion( prompt=context, - system_prompt=executive_summary_prompt, + system_prompt=executive_summary_prompt_str, model=llm_model, temperature=0.7, max_tokens=800, @@ -221,7 +215,7 @@ def generate_executive_summary( def generate_introduction( state: ResearchState, - prompts_bundle: PromptsBundle, + introduction_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", ) -> str: @@ -229,7 +223,7 @@ def generate_introduction( Args: state: The current research state - prompts_bundle: Bundle containing all prompts for the pipeline + introduction_prompt: Prompt for generating introduction llm_model: The model to use for generation langfuse_project_name: Name of the Langfuse project for tracking @@ -246,22 +240,17 @@ def generate_introduction( # Get the introduction prompt try: - introduction_prompt = prompts_bundle.get_prompt_content( - "introduction_prompt" - ) + introduction_prompt_str = str(introduction_prompt) logger.info("Successfully retrieved introduction_prompt") except Exception as e: logger.error(f"Failed to get introduction_prompt: {e}") - logger.info( - f"Available prompts: 
{list(prompts_bundle.list_all_prompts().keys())}" - ) return generate_fallback_introduction(state) try: # Call LLM to generate introduction result = run_llm_completion( prompt=context, - system_prompt=introduction_prompt, + system_prompt=introduction_prompt_str, model=llm_model, temperature=0.7, max_tokens=600, @@ -316,7 +305,7 @@ def generate_fallback_introduction(state: ResearchState) -> str: def generate_conclusion( state: ResearchState, - prompts_bundle: PromptsBundle, + conclusion_generation_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", ) -> str: @@ -324,7 +313,7 @@ def generate_conclusion( Args: state: The ResearchState containing all research findings - prompts_bundle: Bundle containing all prompts for the pipeline + conclusion_generation_prompt: Prompt for generating conclusion llm_model: The model to use for conclusion generation Returns: @@ -382,15 +371,13 @@ def generate_conclusion( } try: - # Get the prompt from the bundle - conclusion_prompt = prompts_bundle.get_prompt_content( - "conclusion_generation_prompt" - ) + # Use the conclusion generation prompt + conclusion_prompt_str = str(conclusion_generation_prompt) # Generate conclusion using LLM conclusion_html = run_llm_completion( prompt=json.dumps(conclusion_input, indent=2), - system_prompt=conclusion_prompt, + system_prompt=conclusion_prompt_str, model=llm_model, clean_output=True, max_tokens=1500, # Sufficient for comprehensive conclusion @@ -439,7 +426,9 @@ def generate_conclusion( def generate_report_from_template( state: ResearchState, - prompts_bundle: PromptsBundle, + conclusion_generation_prompt: Prompt, + executive_summary_prompt: Prompt, + introduction_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", ) -> str: @@ -450,7 +439,9 @@ def generate_report_from_template( Args: state: The current research state - prompts_bundle: Bundle containing all prompts for the pipeline + conclusion_generation_prompt: Prompt for generating conclusion + executive_summary_prompt: Prompt for generating executive summary + introduction_prompt: Prompt for generating introduction llm_model: The model to use for conclusion generation Returns: @@ -604,7 +595,7 @@ def generate_report_from_template( # Generate dynamic executive summary using LLM logger.info("Generating dynamic executive summary...") executive_summary = generate_executive_summary( - state, prompts_bundle, llm_model, langfuse_project_name + state, executive_summary_prompt, llm_model, langfuse_project_name ) logger.info( f"Executive summary generated: {len(executive_summary)} characters" @@ -613,13 +604,13 @@ def generate_report_from_template( # Generate dynamic introduction using LLM logger.info("Generating dynamic introduction...") introduction_html = generate_introduction( - state, prompts_bundle, llm_model, langfuse_project_name + state, introduction_prompt, llm_model, langfuse_project_name ) logger.info(f"Introduction generated: {len(introduction_html)} characters") # Generate comprehensive conclusion using LLM conclusion_html = generate_conclusion( - state, prompts_bundle, llm_model, langfuse_project_name + state, conclusion_generation_prompt, llm_model, langfuse_project_name ) # Generate complete HTML report @@ -993,7 +984,9 @@ def _generate_fallback_report(state: ResearchState) -> str: ) def pydantic_final_report_step( state: ResearchState, - prompts_bundle: PromptsBundle, + conclusion_generation_prompt: Prompt, + 
executive_summary_prompt: Prompt, + introduction_prompt: Prompt, use_static_template: bool = True, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", @@ -1009,7 +1002,9 @@ def pydantic_final_report_step( Args: state: The current research state (Pydantic model) - prompts_bundle: Bundle containing all prompts for the pipeline + conclusion_generation_prompt: Prompt for generating conclusions + executive_summary_prompt: Prompt for generating executive summary + introduction_prompt: Prompt for generating introduction use_static_template: Whether to use a static template instead of LLM generation llm_model: The model to use for report generation with provider prefix @@ -1023,7 +1018,12 @@ def pydantic_final_report_step( # Use the static HTML template approach logger.info("Using static HTML template for report generation") html_content = generate_report_from_template( - state, prompts_bundle, llm_model, langfuse_project_name + state, + conclusion_generation_prompt, + executive_summary_prompt, + introduction_prompt, + llm_model, + langfuse_project_name, ) # Update the state with the final report HTML @@ -1112,10 +1112,8 @@ def pydantic_final_report_step( try: logger.info(f"Calling {llm_model} to generate final report") - # Get the prompt from the bundle - report_prompt = prompts_bundle.get_prompt_content( - "report_generation_prompt" - ) + # Use a default report generation prompt + report_prompt = "Generate a comprehensive HTML research report based on the provided research data. Include proper HTML structure with sections for executive summary, introduction, findings, and conclusion." # Use the utility function to run LLM completion html_content = run_llm_completion( diff --git a/deep_research/steps/query_decomposition_step.py b/deep_research/steps/query_decomposition_step.py index 78e50b9f..48369cf2 100644 --- a/deep_research/steps/query_decomposition_step.py +++ b/deep_research/steps/query_decomposition_step.py @@ -4,8 +4,7 @@ from materializers.pydantic_materializer import ResearchStateMaterializer from utils.llm_utils import get_structured_llm_output -from utils.prompt_models import PromptsBundle -from utils.pydantic_models import ResearchState +from utils.pydantic_models import Prompt, ResearchState from zenml import log_metadata, step logger = logging.getLogger(__name__) @@ -14,7 +13,7 @@ @step(output_materializers=ResearchStateMaterializer) def initial_query_decomposition_step( state: ResearchState, - prompts_bundle: PromptsBundle, + query_decomposition_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", max_sub_questions: int = 8, langfuse_project_name: str = "deep-research", @@ -23,7 +22,7 @@ def initial_query_decomposition_step( Args: state: The current research state - prompts_bundle: Bundle containing all prompts for the pipeline + query_decomposition_prompt: Prompt for query decomposition llm_model: The reasoning model to use with provider prefix max_sub_questions: Maximum number of sub-questions to generate @@ -33,10 +32,8 @@ def initial_query_decomposition_step( start_time = time.time() logger.info(f"Decomposing research query: {state.main_query}") - # Get the prompt from the bundle - system_prompt = prompts_bundle.get_prompt_content( - "query_decomposition_prompt" - ) + # Get the prompt content + system_prompt = str(query_decomposition_prompt) try: # Call OpenAI API to decompose the query diff --git a/deep_research/tests/test_prompt_loader.py b/deep_research/tests/test_prompt_loader.py deleted file mode 
100644 index fdc7b7da..00000000 --- a/deep_research/tests/test_prompt_loader.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Unit tests for prompt loader utilities.""" - -import pytest -from utils.prompt_loader import get_prompt_for_step, load_prompts_bundle -from utils.prompt_models import PromptsBundle - - -class TestPromptLoader: - """Test cases for prompt loader functions.""" - - def test_load_prompts_bundle(self): - """Test loading prompts bundle from prompts.py.""" - bundle = load_prompts_bundle(pipeline_version="2.0.0") - - # Check it returns a PromptsBundle - assert isinstance(bundle, PromptsBundle) - - # Check pipeline version - assert bundle.pipeline_version == "2.0.0" - - # Check all core prompts are loaded - assert bundle.search_query_prompt is not None - assert bundle.query_decomposition_prompt is not None - assert bundle.synthesis_prompt is not None - assert bundle.viewpoint_analysis_prompt is not None - assert bundle.reflection_prompt is not None - assert bundle.additional_synthesis_prompt is not None - assert bundle.conclusion_generation_prompt is not None - - # Check prompts have correct metadata - assert bundle.search_query_prompt.name == "search_query_prompt" - assert ( - bundle.search_query_prompt.description - == "Generates effective search queries from sub-questions" - ) - assert bundle.search_query_prompt.version == "1.0.0" - assert "search" in bundle.search_query_prompt.tags - - # Check that actual prompt content is loaded - assert "search query" in bundle.search_query_prompt.content.lower() - assert "json schema" in bundle.search_query_prompt.content.lower() - - def test_load_prompts_bundle_default_version(self): - """Test loading prompts bundle with default version.""" - bundle = load_prompts_bundle() - assert bundle.pipeline_version == "1.0.0" - - def test_get_prompt_for_step(self): - """Test getting prompt content for specific steps.""" - bundle = load_prompts_bundle() - - # Test valid step names - test_cases = [ - ("query_decomposition", "query_decomposition_prompt"), - ("search_query_generation", "search_query_prompt"), - ("synthesis", "synthesis_prompt"), - ("viewpoint_analysis", "viewpoint_analysis_prompt"), - ("reflection", "reflection_prompt"), - ("additional_synthesis", "additional_synthesis_prompt"), - ("conclusion_generation", "conclusion_generation_prompt"), - ] - - for step_name, expected_prompt_attr in test_cases: - content = get_prompt_for_step(bundle, step_name) - expected_content = getattr(bundle, expected_prompt_attr).content - assert content == expected_content - - def test_get_prompt_for_step_invalid(self): - """Test getting prompt for invalid step name.""" - bundle = load_prompts_bundle() - - with pytest.raises( - ValueError, match="No prompt mapping found for step: invalid_step" - ): - get_prompt_for_step(bundle, "invalid_step") - - def test_all_prompts_have_content(self): - """Test that all loaded prompts have non-empty content.""" - bundle = load_prompts_bundle() - - all_prompts = bundle.list_all_prompts() - for name, prompt in all_prompts.items(): - assert prompt.content, f"Prompt {name} has empty content" - assert ( - len(prompt.content) > 50 - ), f"Prompt {name} content seems too short" - - def test_all_prompts_have_descriptions(self): - """Test that all loaded prompts have descriptions.""" - bundle = load_prompts_bundle() - - all_prompts = bundle.list_all_prompts() - for name, prompt in all_prompts.items(): - assert prompt.description, f"Prompt {name} has no description" - assert ( - len(prompt.description) > 10 - ), f"Prompt {name} description 
seems too short" - - def test_all_prompts_have_tags(self): - """Test that all loaded prompts have at least one tag.""" - bundle = load_prompts_bundle() - - all_prompts = bundle.list_all_prompts() - for name, prompt in all_prompts.items(): - assert prompt.tags, f"Prompt {name} has no tags" - assert ( - len(prompt.tags) >= 1 - ), f"Prompt {name} should have at least one tag" diff --git a/deep_research/tests/test_prompt_models.py b/deep_research/tests/test_prompt_models.py index f8b0519f..fcc437d4 100644 --- a/deep_research/tests/test_prompt_models.py +++ b/deep_research/tests/test_prompt_models.py @@ -1,7 +1,7 @@ """Unit tests for prompt models and utilities.""" -import pytest -from utils.prompt_models import PromptsBundle, PromptTemplate +from utils.prompt_models import PromptTemplate +from utils.pydantic_models import Prompt class TestPromptTemplate: @@ -36,148 +36,75 @@ def test_prompt_template_minimal(self): assert prompt.tags == [] -class TestPromptsBundle: - """Test cases for PromptsBundle model.""" - - @pytest.fixture - def sample_prompts(self): - """Create sample prompts for testing.""" - return { - "search_query_prompt": PromptTemplate( - name="search_query_prompt", - content="Search query content", - description="Generates search queries", - tags=["search"], - ), - "query_decomposition_prompt": PromptTemplate( - name="query_decomposition_prompt", - content="Query decomposition content", - description="Decomposes queries", - tags=["analysis"], - ), - "synthesis_prompt": PromptTemplate( - name="synthesis_prompt", - content="Synthesis content", - description="Synthesizes information", - tags=["synthesis"], - ), - "viewpoint_analysis_prompt": PromptTemplate( - name="viewpoint_analysis_prompt", - content="Viewpoint analysis content", - description="Analyzes viewpoints", - tags=["analysis"], - ), - "reflection_prompt": PromptTemplate( - name="reflection_prompt", - content="Reflection content", - description="Reflects on research", - tags=["reflection"], - ), - "additional_synthesis_prompt": PromptTemplate( - name="additional_synthesis_prompt", - content="Additional synthesis content", - description="Additional synthesis", - tags=["synthesis"], - ), - "conclusion_generation_prompt": PromptTemplate( - name="conclusion_generation_prompt", - content="Conclusion generation content", - description="Generates conclusions", - tags=["report"], - ), - } - - def test_prompts_bundle_creation(self, sample_prompts): - """Test creating a prompts bundle.""" - bundle = PromptsBundle(**sample_prompts) - - assert bundle.search_query_prompt.name == "search_query_prompt" - assert ( - bundle.query_decomposition_prompt.name - == "query_decomposition_prompt" - ) - assert bundle.pipeline_version == "1.0.0" - assert isinstance(bundle.created_at, str) - assert bundle.custom_prompts == {} - - def test_prompts_bundle_with_custom_prompts(self, sample_prompts): - """Test creating a prompts bundle with custom prompts.""" - custom_prompt = PromptTemplate( - name="custom_prompt", - content="Custom prompt content", - description="A custom prompt", - ) +class TestPrompt: + """Test cases for the new Prompt model.""" - bundle = PromptsBundle( - **sample_prompts, custom_prompts={"custom_prompt": custom_prompt} + def test_prompt_creation(self): + """Test creating a prompt with all fields.""" + prompt = Prompt( + name="test_prompt", + content="This is a test prompt", + description="A test prompt for unit testing", + version="1.0.0", + tags=["test", "unit"], ) - assert "custom_prompt" in bundle.custom_prompts - assert 
bundle.custom_prompts["custom_prompt"].name == "custom_prompt" - - def test_get_prompt_by_name(self, sample_prompts): - """Test retrieving prompts by name.""" - bundle = PromptsBundle(**sample_prompts) - - # Test getting a core prompt - prompt = bundle.get_prompt_by_name("search_query_prompt") - assert prompt is not None - assert prompt.name == "search_query_prompt" + assert prompt.name == "test_prompt" + assert prompt.content == "This is a test prompt" + assert prompt.description == "A test prompt for unit testing" + assert prompt.version == "1.0.0" + assert prompt.tags == ["test", "unit"] - # Test getting a non-existent prompt - prompt = bundle.get_prompt_by_name("non_existent") - assert prompt is None + def test_prompt_minimal(self): + """Test creating a prompt with minimal fields.""" + prompt = Prompt(name="minimal_prompt", content="Minimal content") - def test_get_prompt_by_name_custom(self, sample_prompts): - """Test retrieving custom prompts by name.""" - custom_prompt = PromptTemplate( - name="custom_prompt", content="Custom content" - ) + assert prompt.name == "minimal_prompt" + assert prompt.content == "Minimal content" + assert prompt.description == "" + assert prompt.version == "1.0.0" + assert prompt.tags == [] - bundle = PromptsBundle( - **sample_prompts, custom_prompts={"custom_prompt": custom_prompt} + def test_prompt_str_conversion(self): + """Test converting prompt to string returns content.""" + prompt = Prompt( + name="test_prompt", + content="This is the prompt content", + description="Test prompt", ) - prompt = bundle.get_prompt_by_name("custom_prompt") - assert prompt is not None - assert prompt.name == "custom_prompt" + assert str(prompt) == "This is the prompt content" - def test_list_all_prompts(self, sample_prompts): - """Test listing all prompts.""" - bundle = PromptsBundle(**sample_prompts) + def test_prompt_repr(self): + """Test prompt representation.""" + prompt = Prompt(name="test_prompt", content="Content", version="2.0.0") - all_prompts = bundle.list_all_prompts() - assert len(all_prompts) == 7 # 7 core prompts - assert "search_query_prompt" in all_prompts - assert "conclusion_generation_prompt" in all_prompts - - def test_list_all_prompts_with_custom(self, sample_prompts): - """Test listing all prompts including custom ones.""" - custom_prompt = PromptTemplate( - name="custom_prompt", content="Custom content" - ) + assert repr(prompt) == "Prompt(name='test_prompt', version='2.0.0')" - bundle = PromptsBundle( - **sample_prompts, custom_prompts={"custom_prompt": custom_prompt} + def test_prompt_create_factory(self): + """Test creating prompt using factory method.""" + prompt = Prompt.create( + content="Factory created prompt", + name="factory_prompt", + description="Created via factory", + version="1.1.0", + tags=["factory", "test"], ) - all_prompts = bundle.list_all_prompts() - assert len(all_prompts) == 8 # 7 core + 1 custom - assert "custom_prompt" in all_prompts + assert prompt.name == "factory_prompt" + assert prompt.content == "Factory created prompt" + assert prompt.description == "Created via factory" + assert prompt.version == "1.1.0" + assert prompt.tags == ["factory", "test"] - def test_get_prompt_content(self, sample_prompts): - """Test getting prompt content by type.""" - bundle = PromptsBundle(**sample_prompts) - - content = bundle.get_prompt_content("search_query_prompt") - assert content == "Search query content" - - content = bundle.get_prompt_content("synthesis_prompt") - assert content == "Synthesis content" - - def 
test_get_prompt_content_invalid(self, sample_prompts): - """Test getting prompt content with invalid type.""" - bundle = PromptsBundle(**sample_prompts) + def test_prompt_create_factory_minimal(self): + """Test creating prompt using factory method with minimal args.""" + prompt = Prompt.create( + content="Minimal factory prompt", name="minimal_factory" + ) - with pytest.raises(AttributeError): - bundle.get_prompt_content("invalid_prompt_type") + assert prompt.name == "minimal_factory" + assert prompt.content == "Minimal factory prompt" + assert prompt.description == "" + assert prompt.version == "1.0.0" + assert prompt.tags == [] diff --git a/deep_research/utils/prompt_loader.py b/deep_research/utils/prompt_loader.py deleted file mode 100644 index c2b3d2f7..00000000 --- a/deep_research/utils/prompt_loader.py +++ /dev/null @@ -1,136 +0,0 @@ -"""Utility functions for loading prompts into the PromptsBundle model. - -This module provides functions to create PromptsBundle instances from -the existing prompt definitions in prompts.py. -""" - -from utils import prompts -from utils.prompt_models import PromptsBundle, PromptTemplate - - -def load_prompts_bundle(pipeline_version: str = "1.2.0") -> PromptsBundle: - """Load all prompts from prompts.py into a PromptsBundle. - - Args: - pipeline_version: Version of the pipeline using these prompts - - Returns: - PromptsBundle containing all prompts - """ - # Create PromptTemplate instances for each prompt - search_query_prompt = PromptTemplate( - name="search_query_prompt", - content=prompts.DEFAULT_SEARCH_QUERY_PROMPT, - description="Generates effective search queries from sub-questions", - version="1.0.0", - tags=["search", "query", "information-gathering"], - ) - - query_decomposition_prompt = PromptTemplate( - name="query_decomposition_prompt", - content=prompts.QUERY_DECOMPOSITION_PROMPT, - description="Breaks down complex research queries into specific sub-questions", - version="1.0.0", - tags=["analysis", "decomposition", "planning"], - ) - - synthesis_prompt = PromptTemplate( - name="synthesis_prompt", - content=prompts.SYNTHESIS_PROMPT, - description="Synthesizes search results into comprehensive answers for sub-questions", - version="1.1.0", - tags=["synthesis", "integration", "analysis"], - ) - - viewpoint_analysis_prompt = PromptTemplate( - name="viewpoint_analysis_prompt", - content=prompts.VIEWPOINT_ANALYSIS_PROMPT, - description="Analyzes synthesized answers across different perspectives and viewpoints", - version="1.1.0", - tags=["analysis", "viewpoint", "perspective"], - ) - - reflection_prompt = PromptTemplate( - name="reflection_prompt", - content=prompts.REFLECTION_PROMPT, - description="Evaluates research and identifies gaps, biases, and areas for improvement", - version="1.0.0", - tags=["reflection", "critique", "improvement"], - ) - - additional_synthesis_prompt = PromptTemplate( - name="additional_synthesis_prompt", - content=prompts.ADDITIONAL_SYNTHESIS_PROMPT, - description="Enhances original synthesis with new information and addresses critique points", - version="1.1.0", - tags=["synthesis", "enhancement", "integration"], - ) - - conclusion_generation_prompt = PromptTemplate( - name="conclusion_generation_prompt", - content=prompts.CONCLUSION_GENERATION_PROMPT, - description="Synthesizes all research findings into a comprehensive conclusion", - version="1.0.0", - tags=["report", "conclusion", "synthesis"], - ) - - executive_summary_prompt = PromptTemplate( - name="executive_summary_prompt", - 
content=prompts.EXECUTIVE_SUMMARY_GENERATION_PROMPT, - description="Creates a compelling, insight-driven executive summary", - version="1.1.0", - tags=["report", "summary", "insights"], - ) - - introduction_prompt = PromptTemplate( - name="introduction_prompt", - content=prompts.INTRODUCTION_GENERATION_PROMPT, - description="Creates a contextual, engaging introduction", - version="1.1.0", - tags=["report", "introduction", "context"], - ) - - # Create and return the bundle - return PromptsBundle( - search_query_prompt=search_query_prompt, - query_decomposition_prompt=query_decomposition_prompt, - synthesis_prompt=synthesis_prompt, - viewpoint_analysis_prompt=viewpoint_analysis_prompt, - reflection_prompt=reflection_prompt, - additional_synthesis_prompt=additional_synthesis_prompt, - conclusion_generation_prompt=conclusion_generation_prompt, - executive_summary_prompt=executive_summary_prompt, - introduction_prompt=introduction_prompt, - pipeline_version=pipeline_version, - ) - - -def get_prompt_for_step(bundle: PromptsBundle, step_name: str) -> str: - """Get the appropriate prompt content for a specific step. - - Args: - bundle: The PromptsBundle containing all prompts - step_name: Name of the step requesting the prompt - - Returns: - The prompt content string - - Raises: - ValueError: If no prompt mapping exists for the step - """ - # Map step names to prompt attributes - step_to_prompt_mapping = { - "query_decomposition": "query_decomposition_prompt", - "search_query_generation": "search_query_prompt", - "synthesis": "synthesis_prompt", - "viewpoint_analysis": "viewpoint_analysis_prompt", - "reflection": "reflection_prompt", - "additional_synthesis": "additional_synthesis_prompt", - "conclusion_generation": "conclusion_generation_prompt", - } - - prompt_attr = step_to_prompt_mapping.get(step_name) - if not prompt_attr: - raise ValueError(f"No prompt mapping found for step: {step_name}") - - return bundle.get_prompt_content(prompt_attr) diff --git a/deep_research/utils/prompt_models.py b/deep_research/utils/prompt_models.py index c96a7bd0..00e08157 100644 --- a/deep_research/utils/prompt_models.py +++ b/deep_research/utils/prompt_models.py @@ -1,12 +1,9 @@ """Pydantic models for prompt tracking and management. -This module contains models for bundling prompts as trackable artifacts +This module contains models for tracking prompts as artifacts in the ZenML pipeline, enabling better observability and version control. """ -from datetime import datetime -from typing import Dict, Optional - from pydantic import BaseModel, Field @@ -28,96 +25,3 @@ class PromptTemplate(BaseModel): "frozen": False, "validate_assignment": True, } - - -class PromptsBundle(BaseModel): - """Bundle of all prompts used in the research pipeline. - - This model serves as a single artifact that contains all prompts, - making them trackable, versionable, and visualizable in the ZenML dashboard. 
- """ - - # Core prompts used in the pipeline - search_query_prompt: PromptTemplate - query_decomposition_prompt: PromptTemplate - synthesis_prompt: PromptTemplate - viewpoint_analysis_prompt: PromptTemplate - reflection_prompt: PromptTemplate - additional_synthesis_prompt: PromptTemplate - conclusion_generation_prompt: PromptTemplate - executive_summary_prompt: PromptTemplate - introduction_prompt: PromptTemplate - - # Metadata - pipeline_version: str = Field( - "1.0.0", description="Version of the pipeline using these prompts" - ) - created_at: str = Field( - default_factory=lambda: datetime.now().isoformat(), - description="Timestamp when this bundle was created", - ) - - # Additional prompts can be stored here - custom_prompts: Dict[str, PromptTemplate] = Field( - default_factory=dict, - description="Additional custom prompts not part of the core set", - ) - - model_config = { - "extra": "ignore", - "frozen": False, - "validate_assignment": True, - } - - def get_prompt_by_name(self, name: str) -> Optional[PromptTemplate]: - """Retrieve a prompt by its name. - - Args: - name: Name of the prompt to retrieve - - Returns: - PromptTemplate if found, None otherwise - """ - # Check core prompts - for field_name, field_value in self.__dict__.items(): - if ( - isinstance(field_value, PromptTemplate) - and field_value.name == name - ): - return field_value - - # Check custom prompts - return self.custom_prompts.get(name) - - def list_all_prompts(self) -> Dict[str, PromptTemplate]: - """Get all prompts as a dictionary. - - Returns: - Dictionary mapping prompt names to PromptTemplate objects - """ - all_prompts = {} - - # Add core prompts - for field_name, field_value in self.__dict__.items(): - if isinstance(field_value, PromptTemplate): - all_prompts[field_value.name] = field_value - - # Add custom prompts - all_prompts.update(self.custom_prompts) - - return all_prompts - - def get_prompt_content(self, prompt_type: str) -> str: - """Get the content of a specific prompt by its type. - - Args: - prompt_type: Type of prompt (e.g., 'search_query_prompt', 'synthesis_prompt') - - Returns: - The prompt content string - - Raises: - AttributeError: If prompt type doesn't exist - """ - prompt = getattr(self, prompt_type) - return prompt.content diff --git a/deep_research/utils/pydantic_models.py b/deep_research/utils/pydantic_models.py index 9fca23a3..02ecc3c0 100644 --- a/deep_research/utils/pydantic_models.py +++ b/deep_research/utils/pydantic_models.py @@ -12,6 +12,68 @@ from typing_extensions import Literal +class Prompt(BaseModel): + """A single prompt with metadata for tracking and visualization. + + This class is designed to be simple and intuitive to use. You can access + the prompt content directly via the content attribute or by converting + to string. 
+ """ + + content: str = Field(..., description="The actual prompt text") + name: str = Field(..., description="Unique identifier for the prompt") + description: str = Field( + "", description="Human-readable description of what this prompt does" + ) + version: str = Field("1.0.0", description="Version of the prompt") + tags: List[str] = Field( + default_factory=list, description="Tags for categorizing the prompt" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def __str__(self) -> str: + """Return the prompt content as a string.""" + return self.content + + def __repr__(self) -> str: + """Return a readable representation of the prompt.""" + return f"Prompt(name='{self.name}', version='{self.version}')" + + @classmethod + def create( + cls, + content: str, + name: str, + description: str = "", + version: str = "1.0.0", + tags: Optional[List[str]] = None, + ) -> "Prompt": + """Factory method to create a Prompt instance. + + Args: + content: The prompt text + name: Unique identifier for the prompt + description: Optional description of the prompt's purpose + version: Version string (defaults to "1.0.0") + tags: Optional list of tags for categorization + + Returns: + A new Prompt instance + """ + return cls( + content=content, + name=name, + description=description, + version=version, + tags=tags or [], + ) + + class SearchResult(BaseModel): """Represents a search result for a sub-question.""" From 5d3bd7762254f4130ddd9aae4bd86f15a2d298c6 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Tue, 27 May 2025 22:16:23 +0200 Subject: [PATCH 04/11] Fix the tracingmetadata visualization --- .../materializers/tracing_metadata_materializer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deep_research/materializers/tracing_metadata_materializer.py b/deep_research/materializers/tracing_metadata_materializer.py index 7cf7b51b..62bfab2c 100644 --- a/deep_research/materializers/tracing_metadata_materializer.py +++ b/deep_research/materializers/tracing_metadata_materializer.py @@ -502,7 +502,9 @@ def _generate_visualization_html(self, metadata: TracingMetadata) -> str:
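In the refactored pipeline, initialize_prompts_step returns these nine Prompt artifacts as a tuple that is unpacked and passed to the individual steps, and each step converts its Prompt to a string at the call site. A minimal sketch of that consumption pattern, reusing the run_llm_completion helper from utils.llm_utils as the steps in this patch do (the prompt text, notes, and token limit below are illustrative placeholders):

from utils.llm_utils import run_llm_completion
from utils.pydantic_models import Prompt

# Hypothetical prompt; in the pipeline these objects come from initialize_prompts_step.
summary_prompt = Prompt.create(
    content="You are a research assistant. Summarize the provided notes.",
    name="example_summary_prompt",
    description="Illustrative prompt for the str() usage pattern",
    tags=["example"],
)

# str(prompt) returns the raw content, so a Prompt can be passed anywhere the
# previous code passed a plain system-prompt string.
answer = run_llm_completion(
    prompt="Notes: example research notes go here.",
    system_prompt=str(summary_prompt),
    model="sambanova/DeepSeek-R1-Distill-Llama-70B",
    max_tokens=500,
    project="deep-research",
)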

    Cost Breakdown Chart

    - +
    + +
    + + + """ + + return html diff --git a/deep_research/materializers/pydantic_materializer.py b/deep_research/materializers/pydantic_materializer.py deleted file mode 100644 index ee01281b..00000000 --- a/deep_research/materializers/pydantic_materializer.py +++ /dev/null @@ -1,764 +0,0 @@ -"""Pydantic materializer for research state objects. - -This module contains an extended version of ZenML's PydanticMaterializer -that adds visualization capabilities for the ResearchState model. -""" - -import os -from typing import Dict - -from utils.pydantic_models import ResearchState -from zenml.enums import ArtifactType, VisualizationType -from zenml.io import fileio -from zenml.materializers import PydanticMaterializer - - -class ResearchStateMaterializer(PydanticMaterializer): - """Materializer for the ResearchState class with visualizations.""" - - ASSOCIATED_TYPES = (ResearchState,) - ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA - - def save_visualizations( - self, data: ResearchState - ) -> Dict[str, VisualizationType]: - """Create and save visualizations for the ResearchState. - - Args: - data: The ResearchState to visualize - - Returns: - Dictionary mapping file paths to visualization types - """ - # Generate an HTML visualization - visualization_path = os.path.join(self.uri, "research_state.html") - - # Create HTML content based on current stage - html_content = self._generate_visualization_html(data) - - # Write the HTML content to a file - with fileio.open(visualization_path, "w") as f: - f.write(html_content) - - # Return the visualization path and type - return {visualization_path: VisualizationType.HTML} - - def _generate_visualization_html(self, state: ResearchState) -> str: - """Generate HTML visualization for the research state. - - Args: - state: The ResearchState to visualize - - Returns: - HTML string - """ - # Base structure for the HTML - html = f""" - - - - Research State: {state.main_query} - - - - -
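Each materializer in this series follows the same ZenML contract: subclass PydanticMaterializer, build an HTML string for the artifact, write it next to the artifact URI with fileio, and return the path keyed by VisualizationType.HTML. Condensed to its essentials, and with the full tabbed report reduced to a placeholder page, the pattern looks roughly like this:

import os
from typing import Dict

from utils.pydantic_models import ResearchState
from zenml.enums import ArtifactType, VisualizationType
from zenml.io import fileio
from zenml.materializers import PydanticMaterializer


class MinimalResearchStateMaterializer(PydanticMaterializer):
    """Condensed illustration of the visualization pattern used by these materializers."""

    ASSOCIATED_TYPES = (ResearchState,)
    ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA

    def save_visualizations(
        self, data: ResearchState
    ) -> Dict[str, VisualizationType]:
        path = os.path.join(self.uri, "research_state.html")
        # The real materializer renders a full tabbed HTML report here.
        html = f"<html><body><h1>Research State: {data.main_query}</h1></body></html>"
        with fileio.open(path, "w") as f:
            f.write(html)
        return {path: VisualizationType.HTML}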
[Deleted template body; the HTML markup was lost in extraction and only its text content survives. The ResearchStateMaterializer rendered a "Research State" page for the main query with a progress bar across the stages Initial Query, Query Decomposition, Information Gathering, Information Synthesis, Viewpoint Analysis, Reflection & Enhancement, and Final Report, plus tabs that were emitted only when the corresponding data existed:
• Overview: the main query, or "No main query specified".
• Sub-Questions: the numbered list of sub-questions.
• Search Results: per-question result counts and result titles with their source domains (with special handling for "tavily-generated-answer" and stripping of "www." prefixes).
• Synthesis: each question's synthesized answer, confidence level, up to three key sources ("...and N more sources"), and information gaps.
• Viewpoints: points of agreement, areas of tension (topic with per-viewpoint descriptions), perspective gaps, and integrative insights.
• Reflection: the critique summary, additional questions identified, and enhanced information for questions with improvements.
• Final Report: a note that the final HTML report is available as a separate artifact.
The default active tab was derived from state.get_current_stage() via a stage-to-tab map, falling back to the first tab actually rendered.]
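The deleted view derived its progress bar from ResearchState.get_current_stage(); the stage-to-percentage lookup is still legible in the collapsed code that follows and amounts to this (a sketch of just that lookup, with the mapping copied from the deleted _calculate_progress):

STAGE_PERCENTAGES = {
    "empty": 0,
    "initial": 5,
    "after_query_decomposition": 20,
    "after_search": 40,
    "after_synthesis": 60,
    "after_viewpoint_analysis": 75,
    "after_reflection": 90,
    "final_report": 100,
}


def calculate_progress(current_stage: str) -> int:
    """Progress percentage shown for a given research stage."""
    return STAGE_PERCENTAGES.get(current_stage, 0)


print(calculate_progress("after_synthesis"))  # 60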
    - - - """ - - return html - - def _get_stage_class(self, state: ResearchState, stage: str) -> str: - """Get CSS class for a stage based on current progress. - - Args: - state: ResearchState object - stage: Stage name - - Returns: - CSS class string - """ - current_stage = state.get_current_stage() - - # These are the stages in order - stages = [ - "empty", - "initial", - "after_query_decomposition", - "after_search", - "after_synthesis", - "after_viewpoint_analysis", - "after_reflection", - "final_report", - ] - - current_index = ( - stages.index(current_stage) if current_stage in stages else 0 - ) - stage_index = stages.index(stage) if stage in stages else 0 - - if stage_index == current_index: - return "active" - elif stage_index < current_index: - return "completed" - else: - return "" - - def _calculate_progress(self, state: ResearchState) -> int: - """Calculate overall progress percentage. - - Args: - state: ResearchState object - - Returns: - Progress percentage (0-100) - """ - # Map stages to progress percentages - stage_percentages = { - "empty": 0, - "initial": 5, - "after_query_decomposition": 20, - "after_search": 40, - "after_synthesis": 60, - "after_viewpoint_analysis": 75, - "after_reflection": 90, - "final_report": 100, - } - - current_stage = state.get_current_stage() - return stage_percentages.get(current_stage, 0) diff --git a/deep_research/materializers/query_context_materializer.py b/deep_research/materializers/query_context_materializer.py new file mode 100644 index 00000000..ebb7088b --- /dev/null +++ b/deep_research/materializers/query_context_materializer.py @@ -0,0 +1,272 @@ +"""Materializer for QueryContext with interactive mind map visualization.""" + +import os +from typing import Dict + +from utils.pydantic_models import QueryContext +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class QueryContextMaterializer(PydanticMaterializer): + """Materializer for QueryContext with mind map visualization.""" + + ASSOCIATED_TYPES = (QueryContext,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: QueryContext + ) -> Dict[str, VisualizationType]: + """Create and save mind map visualization for the QueryContext. + + Args: + data: The QueryContext to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "query_context.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, context: QueryContext) -> str: + """Generate HTML mind map visualization for the query context. + + Args: + context: The QueryContext to visualize + + Returns: + HTML string + """ + # Create sub-questions HTML + sub_questions_html = "" + if context.sub_questions: + for i, sub_q in enumerate(context.sub_questions, 1): + sub_questions_html += f""" +
[Template body added by this file; markup lost in extraction. The page is titled "Query Context - {main_query[:50]}..." and headed "Query Decomposition Mind Map" with a "Created: {timestamp}" line, the timestamp formatted from context.decomposition_timestamp as "%Y-%m-%d %H:%M:%S UTC". A central node shows the main query with one numbered branch per sub-question ("No sub-questions decomposed yet" when the list is empty), followed by stat cards for the number of sub-questions, the word count of the main query, and the total word count across sub-questions.]
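The stat cards at the bottom of the mind map are straightforward arithmetic over the QueryContext fields. A sketch of the same three numbers, using a simplified stand-in dataclass (QueryContextLike is illustrative, not the real model):

from dataclasses import dataclass, field
from typing import List


@dataclass
class QueryContextLike:
    """Stand-in with only the fields the template reads."""

    main_query: str
    sub_questions: List[str] = field(default_factory=list)


def mind_map_stats(context: QueryContextLike) -> dict:
    # The same three numbers the visualization renders as stat cards.
    return {
        "sub_questions": len(context.sub_questions),
        "words_in_query": len(context.main_query.split()),
        "total_sub_question_words": sum(
            len(q.split()) for q in context.sub_questions
        ),
    }


ctx = QueryContextLike(
    main_query="How does carbon pricing affect heavy industry?",
    sub_questions=[
        "What carbon pricing schemes are currently in force?",
        "How have steel and cement producers responded?",
    ],
)
print(mind_map_stats(ctx))
# {'sub_questions': 2, 'words_in_query': 7, 'total_sub_question_words': 15}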
    + + + """ + + return html diff --git a/deep_research/materializers/reflection_output_materializer.py b/deep_research/materializers/reflection_output_materializer.py deleted file mode 100644 index 1e8b37ae..00000000 --- a/deep_research/materializers/reflection_output_materializer.py +++ /dev/null @@ -1,279 +0,0 @@ -"""Materializer for ReflectionOutput with custom visualization.""" - -import os -from typing import Dict - -from utils.pydantic_models import ReflectionOutput -from zenml.enums import ArtifactType, VisualizationType -from zenml.io import fileio -from zenml.materializers import PydanticMaterializer - - -class ReflectionOutputMaterializer(PydanticMaterializer): - """Materializer for the ReflectionOutput class with visualizations.""" - - ASSOCIATED_TYPES = (ReflectionOutput,) - ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA - - def save_visualizations( - self, data: ReflectionOutput - ) -> Dict[str, VisualizationType]: - """Create and save visualizations for the ReflectionOutput. - - Args: - data: The ReflectionOutput to visualize - - Returns: - Dictionary mapping file paths to visualization types - """ - # Generate an HTML visualization - visualization_path = os.path.join(self.uri, "reflection_output.html") - - # Create HTML content - html_content = self._generate_visualization_html(data) - - # Write the HTML content to a file - with fileio.open(visualization_path, "w") as f: - f.write(html_content) - - # Return the visualization path and type - return {visualization_path: VisualizationType.HTML} - - def _generate_visualization_html(self, output: ReflectionOutput) -> str: - """Generate HTML visualization for the reflection output. - - Args: - output: The ReflectionOutput to visualize - - Returns: - HTML string - """ - html = f""" - - - - Reflection Output - - - -
[Deleted template body; markup lost in extraction. The page was headed "🔍 Reflection & Analysis Output" and contained: a "📝 Critique Summary" section with an item count, rendering each critique either as key/value pairs (for dict entries) or as plain text, with a "No critique summary available" fallback; an "Additional Questions Identified" section with a count and a "No additional questions identified" fallback; a "📊 Research State Summary" listing the main query, current stage, number of sub-questions, "{N} queries with results", and "{N} topics synthesized"; and a closing metadata section.]
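The deleted template's one non-obvious branch is that critique entries may arrive either as dicts or as plain strings; the same mixed shape reappears in the approval step later in this series. A sketch of that normalization (render_critique_lines is an illustrative helper name, not part of the codebase):

from typing import Dict, List, Union

Critique = Union[str, Dict[str, str]]


def render_critique_lines(critiques: List[Critique]) -> List[str]:
    """Flatten string-or-dict critique entries into display lines,
    mirroring the branch the deleted template used."""
    lines: List[str] = []
    for critique in critiques:
        if isinstance(critique, dict):
            for key, value in critique.items():
                lines.append(f"{key}: {value}")
        else:
            lines.append(str(critique))
    return lines


print(render_critique_lines([
    {"issue": "Missing recent sources", "importance": "high"},
    "Confidence levels are uneven across sub-questions",
]))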
    - - - """ - - return html diff --git a/deep_research/materializers/search_data_materializer.py b/deep_research/materializers/search_data_materializer.py new file mode 100644 index 00000000..7c0ef883 --- /dev/null +++ b/deep_research/materializers/search_data_materializer.py @@ -0,0 +1,394 @@ +"""Materializer for SearchData with cost breakdown charts and search results visualization.""" + +import json +import os +from typing import Dict + +from utils.pydantic_models import SearchData +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class SearchDataMaterializer(PydanticMaterializer): + """Materializer for SearchData with interactive visualizations.""" + + ASSOCIATED_TYPES = (SearchData,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: SearchData + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the SearchData. + + Args: + data: The SearchData to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "search_data.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, data: SearchData) -> str: + """Generate HTML visualization for the search data. + + Args: + data: The SearchData to visualize + + Returns: + HTML string + """ + # Prepare data for charts + cost_data = [ + {"provider": k, "cost": v} for k, v in data.search_costs.items() + ] + + # Create search results HTML + results_html = "" + for sub_q, results in data.search_results.items(): + results_html += f""" +
[Template body added by this file; markup lost in extraction. The "Search Data Analysis" page shows stat cards for total searches, number of sub-questions, total results, and total cost (the sum of data.search_costs); a "Cost Analysis" chart; a "Cost Breakdown by Provider" table with Provider, Cost, and Percentage columns (percentage guarded against a zero total); and a "Search Results" section listing, per sub-question, "{N} results found" and the first five results (title, snippet or truncated content, and a "View Source" link) with an "... and N more results" note. "No search results yet" is the empty-state fallback.]
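The provider table above is plain arithmetic over SearchData.search_costs, guarding against a zero total exactly as the template's f-string does. A sketch of that breakdown as a standalone helper (cost_breakdown is an illustrative name):

from typing import Dict, List, Tuple


def cost_breakdown(
    search_costs: Dict[str, float],
) -> List[Tuple[str, float, float]]:
    """Return (provider, cost, percentage) rows for the breakdown table."""
    total_cost = sum(search_costs.values())
    return [
        (
            provider,
            cost,
            cost / total_cost * 100 if total_cost > 0 else 0.0,
        )
        for provider, cost in search_costs.items()
    ]


for provider, cost, pct in cost_breakdown({"exa": 0.0125, "tavily": 0.0075}):
    print(f"{provider}: ${cost:.4f} ({pct:.1f}%)")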
    + + + + + """ + + return html diff --git a/deep_research/materializers/synthesis_data_materializer.py b/deep_research/materializers/synthesis_data_materializer.py new file mode 100644 index 00000000..fb5d68f2 --- /dev/null +++ b/deep_research/materializers/synthesis_data_materializer.py @@ -0,0 +1,431 @@ +"""Materializer for SynthesisData with confidence metrics and synthesis quality visualization.""" + +import os +from typing import Dict + +from utils.pydantic_models import SynthesisData +from zenml.enums import ArtifactType, VisualizationType +from zenml.io import fileio +from zenml.materializers import PydanticMaterializer + + +class SynthesisDataMaterializer(PydanticMaterializer): + """Materializer for SynthesisData with quality metrics visualization.""" + + ASSOCIATED_TYPES = (SynthesisData,) + ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA + + def save_visualizations( + self, data: SynthesisData + ) -> Dict[str, VisualizationType]: + """Create and save visualizations for the SynthesisData. + + Args: + data: The SynthesisData to visualize + + Returns: + Dictionary mapping file paths to visualization types + """ + visualization_path = os.path.join(self.uri, "synthesis_data.html") + html_content = self._generate_visualization_html(data) + + with fileio.open(visualization_path, "w") as f: + f.write(html_content) + + return {visualization_path: VisualizationType.HTML} + + def _generate_visualization_html(self, data: SynthesisData) -> str: + """Generate HTML visualization for the synthesis data. + + Args: + data: The SynthesisData to visualize + + Returns: + HTML string + """ + # Count confidence levels + confidence_counts = {"high": 0, "medium": 0, "low": 0} + for info in data.synthesized_info.values(): + confidence_counts[info.confidence_level] += 1 + + # Create synthesis cards HTML + synthesis_html = "" + for sub_q, info in data.synthesized_info.items(): + confidence_color = { + "high": "#2dce89", + "medium": "#ffd600", + "low": "#f5365c", + }.get(info.confidence_level, "#666") + + sources_html = "" + if info.key_sources: + sources_html = ( + "
[Template body added by this file; markup lost in extraction. The "Synthesis Quality Analysis" page tallies confidence levels (high / medium / low, color-coded) and shows stat cards for total syntheses, enhanced syntheses, average key sources per synthesis, and high-confidence answers; a "Confidence Distribution" chart; and a "Synthesized Information" section with one card per sub-question containing the original synthesized answer, its confidence level, up to three key sources ("... and N more"), information gaps, suggested improvements, and, where an entry exists in enhanced_info, an "Enhanced" badge plus an "Enhanced Answer" block with its own confidence level. "No synthesis data available yet" is the empty-state fallback.]
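The summary cards in this materializer reduce to two aggregations over synthesized_info: a confidence histogram and an average key-source count with a max(..., 1) guard. A sketch with a minimal stand-in for SynthesizedInfo (only the two fields read here; SynthesizedInfoLike is illustrative):

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SynthesizedInfoLike:
    """Stand-in with just the fields the summary cards read."""

    confidence_level: str  # "high" | "medium" | "low"
    key_sources: List[str] = field(default_factory=list)


def synthesis_summary(
    synthesized_info: Dict[str, SynthesizedInfoLike],
) -> dict:
    # Histogram of confidence levels, as in the stat cards above.
    confidence_counts = {"high": 0, "medium": 0, "low": 0}
    for info in synthesized_info.values():
        confidence_counts[info.confidence_level] += 1
    # Average number of key sources, guarded against an empty mapping.
    avg_sources = sum(
        len(info.key_sources) for info in synthesized_info.values()
    ) / max(len(synthesized_info), 1)
    return {
        "confidence_counts": confidence_counts,
        "avg_sources": avg_sources,
    }


print(synthesis_summary({
    "q1": SynthesizedInfoLike("high", ["https://a.example", "https://b.example"]),
    "q2": SynthesizedInfoLike("medium", ["https://c.example"]),
}))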
    + + + + + """ + + return html diff --git a/deep_research/pipelines/parallel_research_pipeline.py b/deep_research/pipelines/parallel_research_pipeline.py index 669b4824..fabdc204 100644 --- a/deep_research/pipelines/parallel_research_pipeline.py +++ b/deep_research/pipelines/parallel_research_pipeline.py @@ -8,7 +8,6 @@ from steps.process_sub_question_step import process_sub_question_step from steps.pydantic_final_report_step import pydantic_final_report_step from steps.query_decomposition_step import initial_query_decomposition_step -from utils.pydantic_models import ResearchState from zenml import pipeline @@ -56,24 +55,22 @@ def parallelized_deep_research_pipeline( introduction_prompt, ) = initialize_prompts_step(pipeline_version="1.0.0") - # Initialize the research state with the main query - state = ResearchState(main_query=query) - # Step 1: Decompose the query into sub-questions, limiting to max_sub_questions - decomposed_state = initial_query_decomposition_step( - state=state, + query_context = initial_query_decomposition_step( + main_query=query, query_decomposition_prompt=query_decomposition_prompt, max_sub_questions=max_sub_questions, langfuse_project_name=langfuse_project_name, ) # Fan out: Process each sub-question in parallel - # Collect artifacts to establish dependencies for the merge step - after = [] + # Collect step names to establish dependencies for the merge step + parallel_step_names = [] for i in range(max_sub_questions): # Process the i-th sub-question (if it exists) - sub_state = process_sub_question_step( - state=decomposed_state, + step_name = f"process_question_{i + 1}" + search_data, synthesis_data = process_sub_question_step( + query_context=query_context, search_query_prompt=search_query_prompt, synthesis_prompt=synthesis_prompt, question_index=i, @@ -81,66 +78,87 @@ def parallelized_deep_research_pipeline( search_mode=search_mode, num_results_per_search=num_results_per_search, langfuse_project_name=langfuse_project_name, - id=f"process_question_{i + 1}", + id=step_name, + after="initial_query_decomposition_step", ) - after.append(sub_state) + parallel_step_names.append(step_name) # Fan in: Merge results from all parallel processing # The 'after' parameter ensures this step runs after all processing steps - # It doesn't directly use the processed_states input - merged_state = merge_sub_question_results_step( - original_state=decomposed_state, - step_prefix="process_question_", - output_name="output", - after=after, # This creates the dependency + merged_search_data, merged_synthesis_data = ( + merge_sub_question_results_step( + step_prefix="process_question_", + after=parallel_step_names, # Wait for all parallel steps to complete + ) ) # Continue with subsequent steps - analyzed_state = cross_viewpoint_analysis_step( - state=merged_state, + analysis_data = cross_viewpoint_analysis_step( + query_context=query_context, + synthesis_data=merged_synthesis_data, viewpoint_analysis_prompt=viewpoint_analysis_prompt, langfuse_project_name=langfuse_project_name, + after="merge_sub_question_results_step", ) # New 3-step reflection flow with optional human approval # Step 1: Generate reflection and recommendations (no searches yet) - reflection_output = generate_reflection_step( - state=analyzed_state, + analysis_with_reflection, recommended_queries = generate_reflection_step( + query_context=query_context, + synthesis_data=merged_synthesis_data, + analysis_data=analysis_data, reflection_prompt=reflection_prompt, langfuse_project_name=langfuse_project_name, + 
after="cross_viewpoint_analysis_step", ) # Step 2: Get approval for recommended searches approval_decision = get_research_approval_step( - reflection_output=reflection_output, + query_context=query_context, + synthesis_data=merged_synthesis_data, + analysis_data=analysis_with_reflection, + recommended_queries=recommended_queries, require_approval=require_approval, timeout=approval_timeout, max_queries=max_additional_searches, + after="generate_reflection_step", ) # Step 3: Execute approved searches (if any) - reflected_state = execute_approved_searches_step( - reflection_output=reflection_output, - approval_decision=approval_decision, - additional_synthesis_prompt=additional_synthesis_prompt, - search_provider=search_provider, - search_mode=search_mode, - num_results_per_search=num_results_per_search, - langfuse_project_name=langfuse_project_name, + enhanced_search_data, enhanced_synthesis_data, enhanced_analysis_data = ( + execute_approved_searches_step( + query_context=query_context, + search_data=merged_search_data, + synthesis_data=merged_synthesis_data, + analysis_data=analysis_with_reflection, + recommended_queries=recommended_queries, + approval_decision=approval_decision, + additional_synthesis_prompt=additional_synthesis_prompt, + search_provider=search_provider, + search_mode=search_mode, + num_results_per_search=num_results_per_search, + langfuse_project_name=langfuse_project_name, + after="get_research_approval_step", + ) ) # Use our new Pydantic-based final report step - # This returns a tuple (state, html_report) - final_state, final_report = pydantic_final_report_step( - state=reflected_state, + pydantic_final_report_step( + query_context=query_context, + search_data=enhanced_search_data, + synthesis_data=enhanced_synthesis_data, + analysis_data=enhanced_analysis_data, conclusion_generation_prompt=conclusion_generation_prompt, executive_summary_prompt=executive_summary_prompt, introduction_prompt=introduction_prompt, langfuse_project_name=langfuse_project_name, + after="execute_approved_searches_step", ) # Collect tracing metadata for the entire pipeline run - _, tracing_metadata = collect_tracing_metadata_step( - state=final_state, + collect_tracing_metadata_step( + query_context=query_context, + search_data=enhanced_search_data, langfuse_project_name=langfuse_project_name, + after="pydantic_final_report_step", ) diff --git a/deep_research/steps/approval_step.py b/deep_research/steps/approval_step.py index c74277c9..93566049 100644 --- a/deep_research/steps/approval_step.py +++ b/deep_research/steps/approval_step.py @@ -1,26 +1,62 @@ import logging import time -from typing import Annotated +from typing import Annotated, List from materializers.approval_decision_materializer import ( ApprovalDecisionMaterializer, ) from utils.approval_utils import ( format_approval_request, - summarize_research_progress, ) -from utils.pydantic_models import ApprovalDecision, ReflectionOutput -from zenml import add_tags, log_metadata, step +from utils.pydantic_models import ( + AnalysisData, + ApprovalDecision, + QueryContext, + SynthesisData, +) +from zenml import log_metadata, step from zenml.client import Client logger = logging.getLogger(__name__) +def summarize_research_progress_from_artifacts( + synthesis_data: SynthesisData, analysis_data: AnalysisData +) -> dict: + """Summarize research progress from the new artifact structure.""" + completed_count = len(synthesis_data.synthesized_info) + + # Calculate confidence levels from synthesis data + confidence_levels = [] + for info in 
synthesis_data.synthesized_info.values(): + confidence_levels.append(info.confidence_level) + + # Calculate average confidence (high=1.0, medium=0.5, low=0.0) + confidence_map = {"high": 1.0, "medium": 0.5, "low": 0.0} + avg_confidence = sum( + confidence_map.get(c.lower(), 0.5) for c in confidence_levels + ) / max(len(confidence_levels), 1) + + low_confidence_count = sum( + 1 for c in confidence_levels if c.lower() == "low" + ) + + return { + "completed_count": completed_count, + "avg_confidence": round(avg_confidence, 2), + "low_confidence_count": low_confidence_count, + } + + @step( - enable_cache=False, output_materializers=ApprovalDecisionMaterializer + enable_cache=False, + output_materializers={"approval_decision": ApprovalDecisionMaterializer}, ) # Never cache approval decisions def get_research_approval_step( - reflection_output: ReflectionOutput, + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + recommended_queries: List[str], require_approval: bool = True, alerter_type: str = "slack", timeout: int = 3600, @@ -33,7 +69,10 @@ def get_research_approval_step( automatically approves all queries. Args: - reflection_output: Output from the reflection generation step + query_context: Context containing the main query and sub-questions + synthesis_data: Synthesized information from research + analysis_data: Analysis including viewpoints and critique + recommended_queries: List of recommended additional queries require_approval: Whether to require human approval alerter_type: Type of alerter to use (slack, email, etc.) timeout: Timeout in seconds for approval response @@ -45,7 +84,7 @@ def get_research_approval_step( start_time = time.time() # Limit queries to max_queries - limited_queries = reflection_output.recommended_queries[:max_queries] + limited_queries = recommended_queries[:max_queries] # If approval not required, auto-approve all if not require_approval: @@ -61,9 +100,7 @@ def get_research_approval_step( "execution_time_seconds": execution_time, "approval_required": False, "approval_method": "AUTO_APPROVED", - "num_queries_recommended": len( - reflection_output.recommended_queries - ), + "num_queries_recommended": len(recommended_queries), "num_queries_approved": len(limited_queries), "max_queries_allowed": max_queries, "approval_status": "approved", @@ -108,11 +145,29 @@ def get_research_approval_step( ) # Prepare approval request - progress_summary = summarize_research_progress(reflection_output.state) + progress_summary = summarize_research_progress_from_artifacts( + synthesis_data, analysis_data + ) + + # Extract critique points from analysis data + critique_points = [] + if analysis_data.critique_summary: + # Convert critique summary to list of dicts for compatibility + for i, critique in enumerate( + analysis_data.critique_summary.split("\n") + ): + if critique.strip(): + critique_points.append( + { + "issue": critique.strip(), + "importance": "high" if i < 3 else "medium", + } + ) + message = format_approval_request( - main_query=reflection_output.state.main_query, + main_query=query_context.main_query, progress_summary=progress_summary, - critique_points=reflection_output.critique_summary, + critique_points=critique_points, proposed_queries=limited_queries, timeout=timeout, ) @@ -140,9 +195,7 @@ def get_research_approval_step( "approval_required": require_approval, "approval_method": "NO_ALERTER_AUTO_APPROVED", "alerter_type": "none", - "num_queries_recommended": len( - reflection_output.recommended_queries - ), + 
"num_queries_recommended": len(recommended_queries), "num_queries_approved": len(limited_queries), "max_queries_allowed": max_queries, "approval_status": "auto_approved", @@ -202,7 +255,7 @@ def get_research_approval_step( "approval_method": "DISCORD_APPROVED", "alerter_type": alerter_type, "num_queries_recommended": len( - reflection_output.recommended_queries + recommended_queries ), "num_queries_approved": len(limited_queries), "max_queries_allowed": max_queries, @@ -230,7 +283,7 @@ def get_research_approval_step( "approval_method": "DISCORD_REJECTED", "alerter_type": alerter_type, "num_queries_recommended": len( - reflection_output.recommended_queries + recommended_queries ), "num_queries_approved": 0, "max_queries_allowed": max_queries, @@ -260,9 +313,7 @@ def get_research_approval_step( "approval_required": require_approval, "approval_method": "ALERTER_ERROR", "alerter_type": alerter_type, - "num_queries_recommended": len( - reflection_output.recommended_queries - ), + "num_queries_recommended": len(recommended_queries), "num_queries_approved": 0, "max_queries_allowed": max_queries, "approval_status": "error", @@ -289,9 +340,7 @@ def get_research_approval_step( "execution_time_seconds": execution_time, "approval_required": require_approval, "approval_method": "ERROR", - "num_queries_recommended": len( - reflection_output.recommended_queries - ), + "num_queries_recommended": len(recommended_queries), "num_queries_approved": 0, "max_queries_allowed": max_queries, "approval_status": "error", @@ -301,7 +350,7 @@ def get_research_approval_step( ) # Add tag to the approval decision artifact - add_tags(tags=["hitl"], artifact="approval_decision") + # add_tags(tags=["hitl"], artifact_name="approval_decision", infer_artifact=True) return ApprovalDecision( approved=False, diff --git a/deep_research/steps/collect_tracing_metadata_step.py b/deep_research/steps/collect_tracing_metadata_step.py index d7e2e8a6..bfaa70e8 100644 --- a/deep_research/steps/collect_tracing_metadata_step.py +++ b/deep_research/steps/collect_tracing_metadata_step.py @@ -1,15 +1,15 @@ """Step to collect tracing metadata from Langfuse for the pipeline run.""" import logging -from typing import Annotated, Tuple +from typing import Annotated, Dict -from materializers.pydantic_materializer import ResearchStateMaterializer from materializers.tracing_metadata_materializer import ( TracingMetadataMaterializer, ) from utils.pydantic_models import ( PromptTypeMetrics, - ResearchState, + QueryContext, + SearchData, TracingMetadata, ) from utils.tracing_metadata_utils import ( @@ -18,7 +18,7 @@ get_trace_stats, get_traces_by_name, ) -from zenml import add_tags, get_step_context, step +from zenml import get_step_context, step logger = logging.getLogger(__name__) @@ -26,28 +26,26 @@ @step( enable_cache=False, output_materializers={ - "state": ResearchStateMaterializer, "tracing_metadata": TracingMetadataMaterializer, }, ) def collect_tracing_metadata_step( - state: ResearchState, + query_context: QueryContext, + search_data: SearchData, langfuse_project_name: str, -) -> Tuple[ - Annotated[ResearchState, "state"], - Annotated[TracingMetadata, "tracing_metadata"], -]: +) -> Annotated[TracingMetadata, "tracing_metadata"]: """Collect tracing metadata from Langfuse for the current pipeline run. This step gathers comprehensive metrics about token usage, costs, and performance for the entire pipeline run, providing insights into resource consumption. 
Args: - state: The final research state + query_context: The query context (for reference) + search_data: The search data containing cost information langfuse_project_name: Langfuse project name for accessing traces Returns: - Tuple of (ResearchState, TracingMetadata) - the state is passed through unchanged + TracingMetadata with comprehensive cost and performance metrics """ ctx = get_step_context() pipeline_run_name = ctx.pipeline_run.name @@ -74,7 +72,9 @@ def collect_tracing_metadata_step( logger.warning( f"No trace found for pipeline run: {pipeline_run_name}" ) - return state, metadata + # Still add search costs before returning + _add_search_costs_to_metadata(metadata, search_data) + return metadata trace = traces[0] @@ -179,26 +179,8 @@ def collect_tracing_metadata_step( except Exception as e: logger.warning(f"Failed to collect prompt-level metrics: {str(e)}") - # Add search costs from the state - if hasattr(state, "search_costs") and state.search_costs: - metadata.search_costs = state.search_costs.copy() - logger.info(f"Added search costs: {metadata.search_costs}") - - if hasattr(state, "search_cost_details") and state.search_cost_details: - metadata.search_cost_details = state.search_cost_details.copy() - - # Count queries by provider - search_queries_count = {} - for detail in state.search_cost_details: - provider = detail.get("provider", "unknown") - search_queries_count[provider] = ( - search_queries_count.get(provider, 0) + 1 - ) - metadata.search_queries_count = search_queries_count - - logger.info( - f"Added {len(metadata.search_cost_details)} search cost detail entries" - ) + # Add search costs from the SearchData artifact + _add_search_costs_to_metadata(metadata, search_data) total_search_cost = sum(metadata.search_costs.values()) logger.info( @@ -216,25 +198,55 @@ def collect_tracing_metadata_step( f"Failed to collect tracing metadata for pipeline run {pipeline_run_name}: {str(e)}" ) # Return metadata with whatever we could collect - # Still try to get search costs even if Langfuse failed - if hasattr(state, "search_costs") and state.search_costs: - metadata.search_costs = state.search_costs.copy() - if hasattr(state, "search_cost_details") and state.search_cost_details: - metadata.search_cost_details = state.search_cost_details.copy() - # Count queries by provider - search_queries_count = {} - for detail in state.search_cost_details: - provider = detail.get("provider", "unknown") - search_queries_count[provider] = ( - search_queries_count.get(provider, 0) + 1 - ) - metadata.search_queries_count = search_queries_count + _add_search_costs_to_metadata(metadata, search_data) - # Add tags to the artifacts - add_tags(tags=["state"], artifact="state") - add_tags( - tags=["exa", "tavily", "llm", "cost"], artifact="tracing_metadata" - ) + # Add tags to the artifact + # add_tags( + # tags=["exa", "tavily", "llm", "cost", "tracing"], + # artifact_name="tracing_metadata", + # infer_artifact=True, + # ) - return state, metadata + return metadata + + +def _add_search_costs_to_metadata( + metadata: TracingMetadata, search_data: SearchData +) -> None: + """Add search costs from SearchData to TracingMetadata. 
+ + Args: + metadata: The TracingMetadata object to update + search_data: The SearchData containing cost information + """ + if search_data.search_costs: + metadata.search_costs = search_data.search_costs.copy() + logger.info(f"Added search costs: {metadata.search_costs}") + + if search_data.search_cost_details: + # Convert SearchCostDetail objects to dicts for backward compatibility + metadata.search_cost_details = [ + { + "provider": detail.provider, + "query": detail.query, + "cost": detail.cost, + "timestamp": detail.timestamp, + "step": detail.step, + "sub_question": detail.sub_question, + } + for detail in search_data.search_cost_details + ] + + # Count queries by provider + search_queries_count: Dict[str, int] = {} + for detail in search_data.search_cost_details: + provider = detail.provider + search_queries_count[provider] = ( + search_queries_count.get(provider, 0) + 1 + ) + metadata.search_queries_count = search_queries_count + + logger.info( + f"Added {len(metadata.search_cost_details)} search cost detail entries" + ) diff --git a/deep_research/steps/cross_viewpoint_step.py b/deep_research/steps/cross_viewpoint_step.py index 0c24b6e3..9212608e 100644 --- a/deep_research/steps/cross_viewpoint_step.py +++ b/deep_research/steps/cross_viewpoint_step.py @@ -3,25 +3,28 @@ import time from typing import Annotated, List -from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.analysis_data_materializer import AnalysisDataMaterializer from utils.helper_functions import ( safe_json_loads, ) from utils.llm_utils import run_llm_completion from utils.pydantic_models import ( + AnalysisData, Prompt, - ResearchState, + QueryContext, + SynthesisData, ViewpointAnalysis, ViewpointTension, ) -from zenml import add_tags, log_metadata, step +from zenml import log_metadata, step logger = logging.getLogger(__name__) -@step(output_materializers=ResearchStateMaterializer) +@step(output_materializers={"analysis_data": AnalysisDataMaterializer}) def cross_viewpoint_analysis_step( - state: ResearchState, + query_context: QueryContext, + synthesis_data: SynthesisData, viewpoint_analysis_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", viewpoint_categories: List[str] = [ @@ -33,27 +36,32 @@ def cross_viewpoint_analysis_step( "historical", ], langfuse_project_name: str = "deep-research", -) -> Annotated[ResearchState, "analyzed_state"]: +) -> Annotated[AnalysisData, "analysis_data"]: """Analyze synthesized information across different viewpoints. 
Args: - state: The current research state + query_context: The query context with main query and sub-questions + synthesis_data: The synthesized information to analyze viewpoint_analysis_prompt: Prompt for viewpoint analysis llm_model: The model to use for viewpoint analysis viewpoint_categories: Categories of viewpoints to analyze + langfuse_project_name: Project name for tracing Returns: - Updated research state with viewpoint analysis + AnalysisData containing viewpoint analysis """ start_time = time.time() logger.info( - f"Performing cross-viewpoint analysis on {len(state.synthesized_info)} sub-questions" + f"Performing cross-viewpoint analysis on {len(synthesis_data.synthesized_info)} sub-questions" ) + # Initialize analysis data + analysis_data = AnalysisData() + # Prepare input for viewpoint analysis analysis_input = { - "main_query": state.main_query, - "sub_questions": state.sub_questions, + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, "synthesized_information": { question: { "synthesized_answer": info.synthesized_answer, @@ -61,7 +69,7 @@ def cross_viewpoint_analysis_step( "confidence_level": info.confidence_level, "information_gaps": info.information_gaps, } - for question, info in state.synthesized_info.items() + for question, info in synthesis_data.synthesized_info.items() }, "viewpoint_categories": viewpoint_categories, } @@ -113,8 +121,8 @@ def cross_viewpoint_analysis_step( logger.info("Completed viewpoint analysis") - # Update the state with the viewpoint analysis - state.update_viewpoint_analysis(viewpoint_analysis) + # Update the analysis data with the viewpoint analysis + analysis_data.viewpoint_analysis = viewpoint_analysis # Calculate execution time execution_time = time.time() - start_time @@ -133,7 +141,9 @@ def cross_viewpoint_analysis_step( "viewpoint_analysis": { "execution_time_seconds": execution_time, "llm_model": llm_model, - "num_sub_questions_analyzed": len(state.synthesized_info), + "num_sub_questions_analyzed": len( + synthesis_data.synthesized_info + ), "viewpoint_categories_requested": viewpoint_categories, "num_agreement_points": len( viewpoint_analysis.main_points_of_agreement @@ -172,7 +182,7 @@ def cross_viewpoint_analysis_step( # Log artifact metadata log_metadata( metadata={ - "state_with_viewpoint_analysis": { + "analysis_data_characteristics": { "has_viewpoint_analysis": True, "total_viewpoints_analyzed": sum( tension_categories.values() @@ -184,13 +194,14 @@ def cross_viewpoint_analysis_step( else None, } }, + artifact_name="analysis_data", infer_artifact=True, ) # Add tags to the artifact - add_tags(tags=["state", "viewpoint"], artifact="analyzed_state") + # add_tags(tags=["analysis", "viewpoint"], artifact_name="analysis_data", infer_artifact=True) - return state + return analysis_data except Exception as e: logger.error(f"Error performing viewpoint analysis: {e}") @@ -204,8 +215,8 @@ def cross_viewpoint_analysis_step( integrative_insights="No insights available due to analysis failure.", ) - # Update the state with the fallback analysis - state.update_viewpoint_analysis(fallback_analysis) + # Update the analysis data with the fallback analysis + analysis_data.viewpoint_analysis = fallback_analysis # Log error metadata execution_time = time.time() - start_time @@ -214,7 +225,9 @@ def cross_viewpoint_analysis_step( "viewpoint_analysis": { "execution_time_seconds": execution_time, "llm_model": llm_model, - "num_sub_questions_analyzed": len(state.synthesized_info), + "num_sub_questions_analyzed": len( + 
synthesis_data.synthesized_info + ), "viewpoint_categories_requested": viewpoint_categories, "analysis_success": False, "error_message": str(e), @@ -224,6 +237,10 @@ def cross_viewpoint_analysis_step( ) # Add tags to the artifact - add_tags(tags=["state", "viewpoint"], artifact="analyzed_state") + # add_tags( + # tags=["analysis", "viewpoint", "fallback"], + # artifact_name="analysis_data", + # infer_artifact=True, + # ) - return state + return analysis_data diff --git a/deep_research/steps/execute_approved_searches_step.py b/deep_research/steps/execute_approved_searches_step.py index db1b46e9..f1eb8625 100644 --- a/deep_research/steps/execute_approved_searches_step.py +++ b/deep_research/steps/execute_approved_searches_step.py @@ -1,47 +1,45 @@ import json import logging import time -from typing import Annotated +from typing import Annotated, List, Tuple -from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.analysis_data_materializer import AnalysisDataMaterializer +from materializers.search_data_materializer import SearchDataMaterializer +from materializers.synthesis_data_materializer import SynthesisDataMaterializer from utils.llm_utils import ( find_most_relevant_string, get_structured_llm_output, is_text_relevant, ) from utils.pydantic_models import ( + AnalysisData, ApprovalDecision, Prompt, - ReflectionMetadata, - ReflectionOutput, - ResearchState, + QueryContext, + SearchCostDetail, + SearchData, + SynthesisData, SynthesizedInfo, ) from utils.search_utils import search_and_extract_results -from zenml import add_tags, log_metadata, step +from zenml import log_metadata, step logger = logging.getLogger(__name__) -def create_enhanced_info_copy(synthesized_info): - """Create a deep copy of synthesized info for enhancement.""" - return { - k: SynthesizedInfo( - synthesized_answer=v.synthesized_answer, - key_sources=v.key_sources.copy(), - confidence_level=v.confidence_level, - information_gaps=v.information_gaps, - improvements=v.improvements.copy() - if hasattr(v, "improvements") - else [], - ) - for k, v in synthesized_info.items() +@step( + output_materializers={ + "enhanced_search_data": SearchDataMaterializer, + "enhanced_synthesis_data": SynthesisDataMaterializer, + "updated_analysis_data": AnalysisDataMaterializer, } - - -@step(output_materializers=ResearchStateMaterializer) +) def execute_approved_searches_step( - reflection_output: ReflectionOutput, + query_context: QueryContext, + search_data: SearchData, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, + recommended_queries: List[str], approval_decision: ApprovalDecision, additional_synthesis_prompt: Prompt, num_results_per_search: int = 3, @@ -50,32 +48,51 @@ def execute_approved_searches_step( search_provider: str = "tavily", search_mode: str = "auto", langfuse_project_name: str = "deep-research", -) -> Annotated[ResearchState, "updated_state"]: - """Execute approved searches and enhance the research state. +) -> Tuple[ + Annotated[SearchData, "enhanced_search_data"], + Annotated[SynthesisData, "enhanced_synthesis_data"], + Annotated[AnalysisData, "updated_analysis_data"], +]: + """Execute approved searches and enhance the research artifacts. This step receives the approval decision and only executes searches that were approved by the human reviewer (or auto-approved). 
Args: - reflection_output: Output from the reflection generation step + query_context: The query context with main query and sub-questions + search_data: The existing search data + synthesis_data: The existing synthesis data + analysis_data: The analysis data with viewpoint and reflection metadata + recommended_queries: The recommended queries from reflection approval_decision: Human approval decision + additional_synthesis_prompt: Prompt for synthesis enhancement num_results_per_search: Number of results to fetch per search cap_search_length: Maximum length of content to process from search results llm_model: The model to use for synthesis enhancement - prompts_bundle: Bundle containing all prompts for the pipeline search_provider: Search provider to use search_mode: Search mode for the provider + langfuse_project_name: Project name for tracing Returns: - Updated research state with enhanced information and reflection metadata + Tuple of enhanced SearchData, SynthesisData, and updated AnalysisData """ start_time = time.time() logger.info( f"Processing approval decision: {approval_decision.approval_method}" ) - state = reflection_output.state - enhanced_info = create_enhanced_info_copy(state.synthesized_info) + # Create copies of the data to enhance + enhanced_search_data = SearchData( + search_results=search_data.search_results.copy(), + search_costs=search_data.search_costs.copy(), + search_cost_details=search_data.search_cost_details.copy(), + total_searches=search_data.total_searches, + ) + + enhanced_synthesis_data = SynthesisData( + synthesized_info=synthesis_data.synthesized_info.copy(), + enhanced_info={}, # Will be populated with enhanced versions + ) # Track improvements count improvements_count = 0 @@ -87,39 +104,10 @@ def execute_approved_searches_step( ): logger.info("No additional searches approved") - # Add any additional questions as new synthesized entries (from reflection) - for new_question in reflection_output.additional_questions: - if ( - new_question not in state.sub_questions - and new_question not in enhanced_info - ): - enhanced_info[new_question] = SynthesizedInfo( - synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", - key_sources=[], - confidence_level="low", - information_gaps="This question requires additional research.", - ) - - # Create metadata indicating no additional research - reflection_metadata = ReflectionMetadata( - critique_summary=[ - c.get("issue", "") for c in reflection_output.critique_summary - ], - additional_questions_identified=reflection_output.additional_questions, - searches_performed=[], - improvements_made=improvements_count, - ) - - # Add approval decision info to metadata - if hasattr(reflection_metadata, "__dict__"): - reflection_metadata.__dict__["user_decision"] = ( - approval_decision.approval_method - ) - reflection_metadata.__dict__["reviewer_notes"] = ( - approval_decision.reviewer_notes - ) - - state.update_after_reflection(enhanced_info, reflection_metadata) + # Update reflection metadata with no searches + if analysis_data.reflection_metadata: + analysis_data.reflection_metadata.searches_performed = [] + analysis_data.reflection_metadata.improvements_made = 0.0 # Log metadata for no approved searches execution_time = time.time() - start_time @@ -133,9 +121,7 @@ def execute_approved_searches_step( else "no_queries", "num_queries_approved": 0, "num_searches_executed": 0, - "num_additional_questions": len( - reflection_output.additional_questions - ), + 
"num_recommended": len(recommended_queries), "improvements_made": improvements_count, "search_provider": search_provider, "llm_model": llm_model, @@ -143,10 +129,20 @@ def execute_approved_searches_step( } ) - # Add tags to the artifact - add_tags(tags=["state", "enhanced"], artifact="updated_state") - - return state + # Add tags to the artifacts + # add_tags( + # tags=["search", "not-enhanced"], artifact_name="enhanced_search_data", infer_artifact=True + # ) + # add_tags( + # tags=["synthesis", "not-enhanced"], + # artifact_name="enhanced_synthesis_data", + # infer_artifact=True, + # ) + # add_tags( + # tags=["analysis", "no-searches"], artifact_name="updated_analysis_data", infer_artifact=True + # ) + + return enhanced_search_data, enhanced_synthesis_data, analysis_data # Execute approved searches logger.info( @@ -175,53 +171,72 @@ def execute_approved_searches_step( and search_cost > 0 ): # Update total costs - state.search_costs["exa"] = ( - state.search_costs.get("exa", 0.0) + search_cost + enhanced_search_data.search_costs["exa"] = ( + enhanced_search_data.search_costs.get("exa", 0.0) + + search_cost ) # Add detailed cost entry - state.search_cost_details.append( - { - "provider": "exa", - "query": query, - "cost": search_cost, - "timestamp": time.time(), - "step": "execute_approved_searches", - "purpose": "reflection_enhancement", - } + enhanced_search_data.search_cost_details.append( + SearchCostDetail( + provider="exa", + query=query, + cost=search_cost, + timestamp=time.time(), + step="execute_approved_searches", + sub_question=None, # These are reflection queries + ) ) logger.info( f"Exa search cost for approved query: ${search_cost:.4f}" ) + # Update total searches + enhanced_search_data.total_searches += 1 + # Extract raw contents raw_contents = [result.content for result in search_results] # Find the most relevant sub-question for this query most_relevant_question = find_most_relevant_string( query, - state.sub_questions, + query_context.sub_questions, llm_model, project=langfuse_project_name, ) if ( most_relevant_question - and most_relevant_question in enhanced_info + and most_relevant_question in synthesis_data.synthesized_info ): + # Store the search results under the relevant question + if ( + most_relevant_question + in enhanced_search_data.search_results + ): + enhanced_search_data.search_results[ + most_relevant_question + ].extend(search_results) + else: + enhanced_search_data.search_results[ + most_relevant_question + ] = search_results + # Enhance the synthesis with new information + original_synthesis = synthesis_data.synthesized_info[ + most_relevant_question + ] + enhancement_input = { - "original_synthesis": enhanced_info[ - most_relevant_question - ].synthesized_answer, + "original_synthesis": original_synthesis.synthesized_answer, "new_information": raw_contents, "critique": [ item - for item in reflection_output.critique_summary - if is_text_relevant( - item.get("issue", ""), most_relevant_question - ) - ], + for item in analysis_data.reflection_metadata.critique_summary + if is_text_relevant(item, most_relevant_question) + ] + if analysis_data.reflection_metadata + else [], } # Use the utility function for enhancement @@ -230,9 +245,7 @@ def execute_approved_searches_step( system_prompt=str(additional_synthesis_prompt), model=llm_model, fallback_response={ - "enhanced_synthesis": enhanced_info[ - most_relevant_question - ].synthesized_answer, + "enhanced_synthesis": original_synthesis.synthesized_answer, "improvements_made": ["Failed to enhance synthesis"], 
"remaining_limitations": "Enhancement process failed.", }, @@ -243,20 +256,31 @@ def execute_approved_searches_step( enhanced_synthesis and "enhanced_synthesis" in enhanced_synthesis ): - # Update the synthesized answer - enhanced_info[ + # Create enhanced synthesis info + enhanced_info = SynthesizedInfo( + synthesized_answer=enhanced_synthesis[ + "enhanced_synthesis" + ], + key_sources=original_synthesis.key_sources + + [r.url for r in search_results[:2]], + confidence_level="high" + if original_synthesis.confidence_level == "medium" + else original_synthesis.confidence_level, + information_gaps=enhanced_synthesis.get( + "remaining_limitations", "" + ), + improvements=original_synthesis.improvements + + enhanced_synthesis.get("improvements_made", []), + ) + + # Store in enhanced_info + enhanced_synthesis_data.enhanced_info[ most_relevant_question - ].synthesized_answer = enhanced_synthesis[ - "enhanced_synthesis" - ] + ] = enhanced_info - # Add improvements improvements = enhanced_synthesis.get( "improvements_made", [] ) - enhanced_info[most_relevant_question].improvements.extend( - improvements - ) improvements_count += len(improvements) # Track enhancement for metadata @@ -274,44 +298,19 @@ def execute_approved_searches_step( } ) - # Add any additional questions as new synthesized entries - for new_question in reflection_output.additional_questions: - if ( - new_question not in state.sub_questions - and new_question not in enhanced_info - ): - enhanced_info[new_question] = SynthesizedInfo( - synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", - key_sources=[], - confidence_level="low", - information_gaps="This question requires additional research.", - ) - - # Create final metadata with approval info - reflection_metadata = ReflectionMetadata( - critique_summary=[ - c.get("issue", "") for c in reflection_output.critique_summary - ], - additional_questions_identified=reflection_output.additional_questions, - searches_performed=approval_decision.selected_queries, - improvements_made=improvements_count, - ) - - # Add approval decision info to metadata - if hasattr(reflection_metadata, "__dict__"): - reflection_metadata.__dict__["user_decision"] = ( - approval_decision.approval_method + # Update reflection metadata with search info + if analysis_data.reflection_metadata: + analysis_data.reflection_metadata.searches_performed = ( + approval_decision.selected_queries ) - reflection_metadata.__dict__["reviewer_notes"] = ( - approval_decision.reviewer_notes + analysis_data.reflection_metadata.improvements_made = float( + improvements_count ) logger.info( f"Completed approved searches with {improvements_count} improvements" ) - state.update_after_reflection(enhanced_info, reflection_metadata) - # Calculate metrics for metadata execution_time = time.time() - start_time total_results = sum( @@ -332,9 +331,7 @@ def execute_approved_searches_step( "execution_time_seconds": execution_time, "approval_method": approval_decision.approval_method, "approval_status": "approved", - "num_queries_recommended": len( - reflection_output.recommended_queries - ), + "num_queries_recommended": len(recommended_queries), "num_queries_approved": len( approval_decision.selected_queries ), @@ -344,14 +341,13 @@ def execute_approved_searches_step( "total_search_results": total_results, "questions_enhanced": questions_enhanced, "improvements_made": improvements_count, - "num_additional_questions": len( - reflection_output.additional_questions - ), 
"search_provider": search_provider, "search_mode": search_mode, "llm_model": llm_model, "success": True, - "total_search_cost": state.search_costs.get("exa", 0.0), + "total_search_cost": enhanced_search_data.search_costs.get( + "exa", 0.0 + ), } } ) @@ -359,44 +355,69 @@ def execute_approved_searches_step( # Log artifact metadata log_metadata( metadata={ - "enhanced_state_after_approval": { - "total_questions": len(enhanced_info), - "questions_with_improvements": sum( - 1 - for info in enhanced_info.values() - if info.improvements + "search_data_characteristics": { + "new_searches": len(approval_decision.selected_queries), + "total_searches": enhanced_search_data.total_searches, + "additional_cost": enhanced_search_data.search_costs.get( + "exa", 0.0 + ) + - search_data.search_costs.get("exa", 0.0), + } + }, + artifact_name="enhanced_search_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "synthesis_data_characteristics": { + "questions_enhanced": questions_enhanced, + "total_enhancements": len( + enhanced_synthesis_data.enhanced_info ), - "total_improvements": sum( - len(info.improvements) - for info in enhanced_info.values() + "improvements_made": improvements_count, + } + }, + artifact_name="enhanced_synthesis_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "analysis_data_characteristics": { + "searches_performed": len( + approval_decision.selected_queries ), "approval_method": approval_decision.approval_method, } }, + artifact_name="updated_analysis_data", infer_artifact=True, ) - # Add tags to the artifact - add_tags(tags=["state", "enhanced"], artifact="updated_state") + # Add tags to the artifacts + # add_tags(tags=["search", "enhanced"], artifact_name="enhanced_search_data", infer_artifact=True) + # add_tags( + # tags=["synthesis", "enhanced"], artifact_name="enhanced_synthesis_data", infer_artifact=True + # ) + # add_tags( + # tags=["analysis", "with-searches"], + # artifact_name="updated_analysis_data", + # infer_artifact=True, + # ) - return state + return enhanced_search_data, enhanced_synthesis_data, analysis_data except Exception as e: logger.error(f"Error during approved search execution: {e}") - # Create error metadata - error_metadata = ReflectionMetadata( - error=f"Approved search execution failed: {str(e)}", - critique_summary=[ - c.get("issue", "") for c in reflection_output.critique_summary - ], - additional_questions_identified=reflection_output.additional_questions, - searches_performed=[], - improvements_made=0, - ) - - # Update the state with the original synthesized info as enhanced info - state.update_after_reflection(state.synthesized_info, error_metadata) + # Update reflection metadata with error + if analysis_data.reflection_metadata: + analysis_data.reflection_metadata.error = ( + f"Approved search execution failed: {str(e)}" + ) + analysis_data.reflection_metadata.searches_performed = [] + analysis_data.reflection_metadata.improvements_made = 0.0 # Log error metadata execution_time = time.time() - start_time @@ -419,7 +440,11 @@ def execute_approved_searches_step( } ) - # Add tags to the artifact - add_tags(tags=["state", "enhanced"], artifact="updated_state") + # Add tags to the artifacts + # add_tags(tags=["search", "error"], artifact_name="enhanced_search_data", infer_artifact=True) + # add_tags( + # tags=["synthesis", "error"], artifact_name="enhanced_synthesis_data", infer_artifact=True + # ) + # add_tags(tags=["analysis", "error"], artifact_name="updated_analysis_data", infer_artifact=True) - return state + return 
enhanced_search_data, enhanced_synthesis_data, analysis_data diff --git a/deep_research/steps/generate_reflection_step.py b/deep_research/steps/generate_reflection_step.py index 9081cc84..61cf4637 100644 --- a/deep_research/steps/generate_reflection_step.py +++ b/deep_research/steps/generate_reflection_step.py @@ -1,25 +1,38 @@ import json import logging import time -from typing import Annotated +from typing import Annotated, List, Tuple -from materializers.reflection_output_materializer import ( - ReflectionOutputMaterializer, -) +from materializers.analysis_data_materializer import AnalysisDataMaterializer from utils.llm_utils import get_structured_llm_output -from utils.pydantic_models import Prompt, ReflectionOutput, ResearchState -from zenml import add_tags, log_metadata, step +from utils.pydantic_models import ( + AnalysisData, + Prompt, + QueryContext, + ReflectionMetadata, + SynthesisData, +) +from zenml import log_metadata, step logger = logging.getLogger(__name__) -@step(output_materializers=ReflectionOutputMaterializer) +@step( + output_materializers={ + "analysis_data": AnalysisDataMaterializer, + } +) def generate_reflection_step( - state: ResearchState, + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, reflection_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", -) -> Annotated[ReflectionOutput, "reflection_output"]: +) -> Tuple[ + Annotated[AnalysisData, "analysis_data"], + Annotated[List[str], "recommended_queries"], +]: """ Generate reflection and recommendations WITHOUT executing searches. @@ -27,12 +40,15 @@ def generate_reflection_step( for additional research that could improve the quality of the results. Args: - state: The current research state + query_context: The query context with main query and sub-questions + synthesis_data: The synthesized information + analysis_data: The analysis data with viewpoint analysis reflection_prompt: Prompt for generating reflection llm_model: The model to use for reflection + langfuse_project_name: Project name for tracing Returns: - ReflectionOutput containing the state, recommendations, and critique + Tuple of updated AnalysisData and recommended queries """ start_time = time.time() logger.info("Generating reflection on research") @@ -45,28 +61,28 @@ def generate_reflection_step( "confidence_level": info.confidence_level, "information_gaps": info.information_gaps, } - for question, info in state.synthesized_info.items() + for question, info in synthesis_data.synthesized_info.items() } viewpoint_analysis_dict = None - if state.viewpoint_analysis: + if analysis_data.viewpoint_analysis: # Convert the viewpoint analysis to a dict for the LLM tension_list = [] - for tension in state.viewpoint_analysis.areas_of_tension: + for tension in analysis_data.viewpoint_analysis.areas_of_tension: tension_list.append( {"topic": tension.topic, "viewpoints": tension.viewpoints} ) viewpoint_analysis_dict = { - "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, + "main_points_of_agreement": analysis_data.viewpoint_analysis.main_points_of_agreement, "areas_of_tension": tension_list, - "perspective_gaps": state.viewpoint_analysis.perspective_gaps, - "integrative_insights": state.viewpoint_analysis.integrative_insights, + "perspective_gaps": analysis_data.viewpoint_analysis.perspective_gaps, + "integrative_insights": analysis_data.viewpoint_analysis.integrative_insights, } reflection_input = { - "main_query": 
state.main_query, - "sub_questions": state.sub_questions, + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, "synthesized_information": synthesized_info_dict, } @@ -92,14 +108,21 @@ def generate_reflection_step( project=langfuse_project_name, ) - # Prepare return value - reflection_output = ReflectionOutput( - state=state, - recommended_queries=reflection_result.get( - "recommended_search_queries", [] - ), - critique_summary=reflection_result.get("critique", []), - additional_questions=reflection_result.get("additional_questions", []), + # Extract results + recommended_queries = reflection_result.get( + "recommended_search_queries", [] + ) + critique_summary = reflection_result.get("critique", []) + additional_questions = reflection_result.get("additional_questions", []) + + # Update analysis data with reflection metadata + analysis_data.reflection_metadata = ReflectionMetadata( + critique_summary=[ + str(c) for c in critique_summary + ], # Convert to strings + additional_questions_identified=additional_questions, + searches_performed=[], # Will be populated by execute_approved_searches_step + improvements_made=0.0, # Will be updated later ) # Calculate execution time @@ -107,7 +130,8 @@ def generate_reflection_step( # Count confidence levels in synthesized info confidence_levels = [ - info.confidence_level for info in state.synthesized_info.values() + info.confidence_level + for info in synthesis_data.synthesized_info.values() ] confidence_distribution = { "high": confidence_levels.count("high"), @@ -121,20 +145,18 @@ def generate_reflection_step( "reflection_generation": { "execution_time_seconds": execution_time, "llm_model": llm_model, - "num_sub_questions_analyzed": len(state.sub_questions), - "num_synthesized_answers": len(state.synthesized_info), - "viewpoint_analysis_included": bool(viewpoint_analysis_dict), - "num_critique_points": len(reflection_output.critique_summary), - "num_additional_questions": len( - reflection_output.additional_questions - ), - "num_recommended_queries": len( - reflection_output.recommended_queries + "num_sub_questions_analyzed": len(query_context.sub_questions), + "num_synthesized_answers": len( + synthesis_data.synthesized_info ), + "viewpoint_analysis_included": bool(viewpoint_analysis_dict), + "num_critique_points": len(critique_summary), + "num_additional_questions": len(additional_questions), + "num_recommended_queries": len(recommended_queries), "confidence_distribution": confidence_distribution, "has_information_gaps": any( info.information_gaps - for info in state.synthesized_info.values() + for info in synthesis_data.synthesized_info.values() ), } } @@ -143,24 +165,33 @@ def generate_reflection_step( # Log artifact metadata log_metadata( metadata={ - "reflection_output_characteristics": { - "has_recommendations": bool( - reflection_output.recommended_queries - ), - "has_critique": bool(reflection_output.critique_summary), - "has_additional_questions": bool( - reflection_output.additional_questions - ), - "total_recommendations": len( - reflection_output.recommended_queries - ) - + len(reflection_output.additional_questions), + "analysis_data_characteristics": { + "has_reflection_metadata": True, + "has_viewpoint_analysis": analysis_data.viewpoint_analysis + is not None, + "num_critique_points": len(critique_summary), + "num_additional_questions": len(additional_questions), + } + }, + artifact_name="analysis_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + 
"recommended_queries_characteristics": { + "num_queries": len(recommended_queries), + "has_recommendations": bool(recommended_queries), } }, + artifact_name="recommended_queries", infer_artifact=True, ) - # Add tags to the artifact - add_tags(tags=["reflection", "critique"], artifact="reflection_output") + # Add tags to the artifacts + # add_tags(tags=["analysis", "reflection"], artifact_name="analysis_data", infer_artifact=True) + # add_tags( + # tags=["recommendations", "queries"], artifact_name="recommended_queries", infer_artifact=True + # ) - return reflection_output + return analysis_data, recommended_queries diff --git a/deep_research/steps/initialize_prompts_step.py b/deep_research/steps/initialize_prompts_step.py index f8df4c75..fe47c395 100644 --- a/deep_research/steps/initialize_prompts_step.py +++ b/deep_research/steps/initialize_prompts_step.py @@ -10,7 +10,7 @@ from materializers.prompt_materializer import PromptMaterializer from utils import prompts from utils.pydantic_models import Prompt -from zenml import add_tags, step +from zenml import step logger = logging.getLogger(__name__) @@ -119,26 +119,27 @@ def initialize_prompts_step( logger.info(f"Loaded 9 individual prompts") - # add tags to all prompts - add_tags(tags=["prompt", "search"], artifact="search_query_prompt") - add_tags( - tags=["prompt", "generation"], artifact="query_decomposition_prompt" - ) - add_tags(tags=["prompt", "generation"], artifact="synthesis_prompt") - add_tags( - tags=["prompt", "generation"], artifact="viewpoint_analysis_prompt" - ) - add_tags(tags=["prompt", "generation"], artifact="reflection_prompt") - add_tags( - tags=["prompt", "generation"], artifact="additional_synthesis_prompt" - ) - add_tags( - tags=["prompt", "generation"], artifact="conclusion_generation_prompt" - ) - add_tags( - tags=["prompt", "generation"], artifact="executive_summary_prompt" - ) - add_tags(tags=["prompt", "generation"], artifact="introduction_prompt") + # # add tags to all prompts + # add_tags(tags=["prompt", "search"], artifact_name="search_query_prompt", infer_artifact=True) + + # add_tags( + # tags=["prompt", "generation"], artifact_name="query_decomposition_prompt", infer_artifact=True + # ) + # add_tags(tags=["prompt", "generation"], artifact_name="synthesis_prompt", infer_artifact=True) + # add_tags( + # tags=["prompt", "generation"], artifact_name="viewpoint_analysis_prompt", infer_artifact=True + # ) + # add_tags(tags=["prompt", "generation"], artifact_name="reflection_prompt", infer_artifact=True) + # add_tags( + # tags=["prompt", "generation"], artifact_name="additional_synthesis_prompt", infer_artifact=True + # ) + # add_tags( + # tags=["prompt", "generation"], artifact_name="conclusion_generation_prompt", infer_artifact=True + # ) + # add_tags( + # tags=["prompt", "generation"], artifact_name="executive_summary_prompt", infer_artifact=True + # ) + # add_tags(tags=["prompt", "generation"], artifact_name="introduction_prompt", infer_artifact=True) return ( search_query_prompt, diff --git a/deep_research/steps/iterative_reflection_step.py b/deep_research/steps/iterative_reflection_step.py deleted file mode 100644 index 593c26d9..00000000 --- a/deep_research/steps/iterative_reflection_step.py +++ /dev/null @@ -1,391 +0,0 @@ -import json -import logging -import time -from typing import Annotated - -from materializers.pydantic_materializer import ResearchStateMaterializer -from utils.llm_utils import ( - find_most_relevant_string, - get_structured_llm_output, - is_text_relevant, -) -from utils.prompts import 
ADDITIONAL_SYNTHESIS_PROMPT, REFLECTION_PROMPT -from utils.pydantic_models import ( - ReflectionMetadata, - ResearchState, - SynthesizedInfo, -) -from utils.search_utils import search_and_extract_results -from zenml import add_tags, log_metadata, step - -logger = logging.getLogger(__name__) - - -@step(output_materializers=ResearchStateMaterializer) -def iterative_reflection_step( - state: ResearchState, - max_additional_searches: int = 2, - num_results_per_search: int = 3, - cap_search_length: int = 20000, - llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", - reflection_prompt: str = REFLECTION_PROMPT, - additional_synthesis_prompt: str = ADDITIONAL_SYNTHESIS_PROMPT, -) -> Annotated[ResearchState, "reflected_state"]: - """Perform iterative reflection on the research, identifying gaps and improving it. - - Args: - state: The current research state - max_additional_searches: Maximum number of additional searches to perform - num_results_per_search: Number of results to fetch per search - cap_search_length: Maximum length of content to process from search results - llm_model: The model to use for reflection - reflection_prompt: System prompt for the reflection - additional_synthesis_prompt: System prompt for incorporating additional information - - Returns: - Updated research state with enhanced information and reflection metadata - """ - start_time = time.time() - logger.info("Starting iterative reflection on research") - - # Prepare input for reflection - synthesized_info_dict = { - question: { - "synthesized_answer": info.synthesized_answer, - "key_sources": info.key_sources, - "confidence_level": info.confidence_level, - "information_gaps": info.information_gaps, - } - for question, info in state.synthesized_info.items() - } - - viewpoint_analysis_dict = None - if state.viewpoint_analysis: - # Convert the viewpoint analysis to a dict for the LLM - tension_list = [] - for tension in state.viewpoint_analysis.areas_of_tension: - tension_list.append( - {"topic": tension.topic, "viewpoints": tension.viewpoints} - ) - - viewpoint_analysis_dict = { - "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, - "areas_of_tension": tension_list, - "perspective_gaps": state.viewpoint_analysis.perspective_gaps, - "integrative_insights": state.viewpoint_analysis.integrative_insights, - } - - reflection_input = { - "main_query": state.main_query, - "sub_questions": state.sub_questions, - "synthesized_information": synthesized_info_dict, - } - - if viewpoint_analysis_dict: - reflection_input["viewpoint_analysis"] = viewpoint_analysis_dict - - # Get reflection critique - try: - logger.info(f"Generating self-critique via {llm_model}") - - # Define fallback for reflection result - fallback_reflection = { - "critique": [], - "additional_questions": [], - "recommended_search_queries": [], - } - - # Use utility function to get structured output - reflection_result = get_structured_llm_output( - prompt=json.dumps(reflection_input), - system_prompt=reflection_prompt, - model=llm_model, - fallback_response=fallback_reflection, - ) - - # Make a deep copy of the synthesized info to create enhanced_info - enhanced_info = { - k: SynthesizedInfo( - synthesized_answer=v.synthesized_answer, - key_sources=v.key_sources.copy(), - confidence_level=v.confidence_level, - information_gaps=v.information_gaps, - improvements=v.improvements.copy() - if hasattr(v, "improvements") - else [], - ) - for k, v in state.synthesized_info.items() - } - - # Perform additional searches based on 
recommendations - search_queries = reflection_result.get( - "recommended_search_queries", [] - ) - if max_additional_searches > 0 and search_queries: - # Limit to max_additional_searches - search_queries = search_queries[:max_additional_searches] - - for query in search_queries: - logger.info(f"Performing additional search: {query}") - # Execute the search using the utility function - search_results, search_cost = search_and_extract_results( - query=query, - max_results=num_results_per_search, - cap_content_length=cap_search_length, - ) - - # Extract raw contents - raw_contents = [result.content for result in search_results] - - # Find the most relevant sub-question for this query - most_relevant_question = find_most_relevant_string( - query, state.sub_questions, llm_model - ) - - # Track search costs if using Exa (default provider) - # Note: This step doesn't have a search_provider parameter, so we check the default - from utils.search_utils import SearchEngineConfig - - config = SearchEngineConfig() - if ( - config.default_provider.lower() in ["exa", "both"] - and search_cost > 0 - ): - # Update total costs - state.search_costs["exa"] = ( - state.search_costs.get("exa", 0.0) + search_cost - ) - - # Add detailed cost entry - state.search_cost_details.append( - { - "provider": "exa", - "query": query, - "cost": search_cost, - "timestamp": time.time(), - "step": "iterative_reflection", - "purpose": "gap_filling", - "relevant_question": most_relevant_question, - } - ) - logger.info( - f"Exa search cost for reflection query: ${search_cost:.4f}" - ) - - if ( - most_relevant_question - and most_relevant_question in enhanced_info - ): - # Enhance the synthesis with new information - enhancement_input = { - "original_synthesis": enhanced_info[ - most_relevant_question - ].synthesized_answer, - "new_information": raw_contents, - "critique": [ - item - for item in reflection_result.get("critique", []) - if is_text_relevant( - item.get("issue", ""), most_relevant_question - ) - ], - } - - # Use the utility function for enhancement - enhanced_synthesis = get_structured_llm_output( - prompt=json.dumps(enhancement_input), - system_prompt=additional_synthesis_prompt, - model=llm_model, - fallback_response={ - "enhanced_synthesis": enhanced_info[ - most_relevant_question - ].synthesized_answer, - "improvements_made": [ - "Failed to enhance synthesis" - ], - "remaining_limitations": "Enhancement process failed.", - }, - ) - - if ( - enhanced_synthesis - and "enhanced_synthesis" in enhanced_synthesis - ): - # Update the synthesized answer - enhanced_info[ - most_relevant_question - ].synthesized_answer = enhanced_synthesis[ - "enhanced_synthesis" - ] - - # Add improvements - improvements = enhanced_synthesis.get( - "improvements_made", [] - ) - enhanced_info[ - most_relevant_question - ].improvements.extend(improvements) - - # Add any additional questions as new synthesized entries - for new_question in reflection_result.get("additional_questions", []): - if ( - new_question not in state.sub_questions - and new_question not in enhanced_info - ): - enhanced_info[new_question] = SynthesizedInfo( - synthesized_answer=f"This question was identified during reflection but has not yet been researched: {new_question}", - key_sources=[], - confidence_level="low", - information_gaps="This question requires additional research.", - ) - - # Prepare metadata about the reflection process - reflection_metadata = ReflectionMetadata( - critique_summary=[ - item.get("issue", "") - for item in 
reflection_result.get("critique", []) - ], - additional_questions_identified=reflection_result.get( - "additional_questions", [] - ), - searches_performed=search_queries, - improvements_made=sum( - [len(info.improvements) for info in enhanced_info.values()] - ), - ) - - logger.info( - f"Completed iterative reflection with {reflection_metadata.improvements_made} improvements" - ) - - # Update the state with enhanced info and metadata - state.update_after_reflection(enhanced_info, reflection_metadata) - - # Calculate execution time - execution_time = time.time() - start_time - - # Count questions that were enhanced - questions_enhanced = 0 - for question, enhanced in enhanced_info.items(): - if question in state.synthesized_info: - original = state.synthesized_info[question] - if enhanced.synthesized_answer != original.synthesized_answer: - questions_enhanced += 1 - - # Calculate confidence level changes - confidence_improvements = {"improved": 0, "unchanged": 0, "new": 0} - for question, enhanced in enhanced_info.items(): - if question in state.synthesized_info: - original = state.synthesized_info[question] - original_level = original.confidence_level.lower() - enhanced_level = enhanced.confidence_level.lower() - - level_map = {"low": 0, "medium": 1, "high": 2} - if enhanced_level in level_map and original_level in level_map: - if level_map[enhanced_level] > level_map[original_level]: - confidence_improvements["improved"] += 1 - else: - confidence_improvements["unchanged"] += 1 - else: - confidence_improvements["new"] += 1 - - # Log metadata - log_metadata( - metadata={ - "iterative_reflection": { - "execution_time_seconds": execution_time, - "llm_model": llm_model, - "max_additional_searches": max_additional_searches, - "searches_performed": len(search_queries), - "num_critique_points": len( - reflection_result.get("critique", []) - ), - "num_additional_questions": len( - reflection_result.get("additional_questions", []) - ), - "questions_enhanced": questions_enhanced, - "total_improvements": reflection_metadata.improvements_made, - "confidence_improvements": confidence_improvements, - "has_viewpoint_analysis": bool(viewpoint_analysis_dict), - "total_search_cost": state.search_costs.get("exa", 0.0), - } - } - ) - - # Log model metadata for cross-pipeline tracking - log_metadata( - metadata={ - "improvement_metrics": { - "confidence_improvements": confidence_improvements, - "total_improvements": reflection_metadata.improvements_made, - } - }, - infer_model=True, - ) - - # Log artifact metadata - log_metadata( - metadata={ - "enhanced_state_characteristics": { - "total_questions": len(enhanced_info), - "questions_with_improvements": sum( - 1 - for info in enhanced_info.values() - if info.improvements - ), - "high_confidence_count": sum( - 1 - for info in enhanced_info.values() - if info.confidence_level.lower() == "high" - ), - "medium_confidence_count": sum( - 1 - for info in enhanced_info.values() - if info.confidence_level.lower() == "medium" - ), - "low_confidence_count": sum( - 1 - for info in enhanced_info.values() - if info.confidence_level.lower() == "low" - ), - } - }, - infer_artifact=True, - ) - - # Add tags to the artifact - add_tags(tags=["state", "reflected"], artifact="reflected_state") - - return state - - except Exception as e: - logger.error(f"Error during iterative reflection: {e}") - - # Create error metadata - error_metadata = ReflectionMetadata( - error=f"Reflection failed: {str(e)}" - ) - - # Update the state with the original synthesized info as enhanced info - # and 
the error metadata - state.update_after_reflection(state.synthesized_info, error_metadata) - - # Log error metadata - execution_time = time.time() - start_time - log_metadata( - metadata={ - "iterative_reflection": { - "execution_time_seconds": execution_time, - "llm_model": llm_model, - "max_additional_searches": max_additional_searches, - "searches_performed": 0, - "status": "failed", - "error_message": str(e), - } - } - ) - - # Add tags to the artifact - add_tags(tags=["state", "reflected"], artifact="reflected_state") - - return state diff --git a/deep_research/steps/merge_results_step.py b/deep_research/steps/merge_results_step.py index d334c610..4802c98f 100644 --- a/deep_research/steps/merge_results_step.py +++ b/deep_research/steps/merge_results_step.py @@ -1,35 +1,38 @@ -import copy import logging import time -from typing import Annotated +from typing import Annotated, Tuple -from materializers.pydantic_materializer import ResearchStateMaterializer -from utils.pydantic_models import ResearchState -from zenml import add_tags, get_step_context, log_metadata, step +from materializers.search_data_materializer import SearchDataMaterializer +from materializers.synthesis_data_materializer import SynthesisDataMaterializer +from utils.pydantic_models import SearchData, SynthesisData +from zenml import get_step_context, log_metadata, step from zenml.client import Client logger = logging.getLogger(__name__) -@step(output_materializers=ResearchStateMaterializer) +@step( + output_materializers={ + "merged_search_data": SearchDataMaterializer, + "merged_synthesis_data": SynthesisDataMaterializer, + } +) def merge_sub_question_results_step( - original_state: ResearchState, step_prefix: str = "process_question_", - output_name: str = "output", -) -> Annotated[ResearchState, "merged_state"]: +) -> Tuple[ + Annotated[SearchData, "merged_search_data"], + Annotated[SynthesisData, "merged_synthesis_data"], +]: """Merge results from individual sub-question processing steps. This step collects the results from the parallel sub-question processing steps - and combines them into a single, comprehensive state object. + and combines them into single SearchData and SynthesisData artifacts. 
Args: - original_state: The original research state with all sub-questions step_prefix: The prefix used in step IDs for the parallel processing steps - output_name: The name of the output artifact from the processing steps Returns: - Annotated[ResearchState, "merged_state"]: A merged ResearchState with combined - results from all sub-questions + Tuple of merged SearchData and SynthesisData artifacts Note: This step is typically configured with the 'after' parameter in the pipeline @@ -38,23 +41,16 @@ def merge_sub_question_results_step( """ start_time = time.time() - # Start with the original state that has all sub-questions - merged_state = copy.deepcopy(original_state) - - # Initialize empty dictionaries for the results - merged_state.search_results = {} - merged_state.synthesized_info = {} - - # Initialize search cost tracking - merged_state.search_costs = {} - merged_state.search_cost_details = [] + # Initialize merged artifacts + merged_search_data = SearchData() + merged_synthesis_data = SynthesisData() # Get pipeline run information to access outputs try: ctx = get_step_context() if not ctx or not ctx.pipeline_run: logger.error("Could not get pipeline run context") - return merged_state + return merged_search_data, merged_synthesis_data run_name = ctx.pipeline_run.name client = Client() @@ -80,77 +76,45 @@ def merge_sub_question_results_step( f"Processing results from step: {step_name} (index: {index})" ) - # Get the output artifact - if output_name in step_info.outputs: - output_artifacts = step_info.outputs[output_name] - if output_artifacts: - output_artifact = output_artifacts[0] - sub_state = output_artifact.load() - - # Check if the sub-state has valid data - if ( - hasattr(sub_state, "sub_questions") - and sub_state.sub_questions - ): - sub_question = sub_state.sub_questions[0] + # Get the search_data artifact + if "search_data" in step_info.outputs: + search_artifacts = step_info.outputs["search_data"] + if search_artifacts: + search_artifact = search_artifacts[0] + sub_search_data = search_artifact.load() + + # Merge search data + merged_search_data.merge(sub_search_data) + + # Track processed questions + for sub_q in sub_search_data.search_results: + processed_questions.add(sub_q) logger.info( - f"Found results for sub-question: {sub_question}" + f"Merged search results for: {sub_q}" ) - parallel_steps_processed += 1 - processed_questions.add(sub_question) - - # Merge search results - if ( - hasattr(sub_state, "search_results") - and sub_question - in sub_state.search_results - ): - merged_state.search_results[ - sub_question - ] = sub_state.search_results[ - sub_question - ] - logger.info( - f"Added search results for: {sub_question}" - ) - - # Merge synthesized info - if ( - hasattr(sub_state, "synthesized_info") - and sub_question - in sub_state.synthesized_info - ): - merged_state.synthesized_info[ - sub_question - ] = sub_state.synthesized_info[ - sub_question - ] - logger.info( - f"Added synthesized info for: {sub_question}" - ) - - # Merge search costs - if hasattr(sub_state, "search_costs"): - for ( - provider, - cost, - ) in sub_state.search_costs.items(): - merged_state.search_costs[ - provider - ] = ( - merged_state.search_costs.get( - provider, 0.0 - ) - + cost - ) - - # Merge search cost details - if hasattr( - sub_state, "search_cost_details" - ): - merged_state.search_cost_details.extend( - sub_state.search_cost_details - ) + + # Get the synthesis_data artifact + if "synthesis_data" in step_info.outputs: + synthesis_artifacts = step_info.outputs[ + 
"synthesis_data" + ] + if synthesis_artifacts: + synthesis_artifact = synthesis_artifacts[0] + sub_synthesis_data = synthesis_artifact.load() + + # Merge synthesis data + merged_synthesis_data.merge(sub_synthesis_data) + + # Track processed questions + for ( + sub_q + ) in sub_synthesis_data.synthesized_info: + logger.info( + f"Merged synthesis info for: {sub_q}" + ) + + parallel_steps_processed += 1 + except (ValueError, IndexError, KeyError, AttributeError) as e: logger.warning(f"Error processing step {step_name}: {e}") continue @@ -164,24 +128,22 @@ def merge_sub_question_results_step( ) # Log search cost summary - if merged_state.search_costs: - total_cost = sum(merged_state.search_costs.values()) + if merged_search_data.search_costs: + total_cost = sum(merged_search_data.search_costs.values()) logger.info( - f"Total search costs merged: ${total_cost:.4f} across {len(merged_state.search_cost_details)} queries" + f"Total search costs merged: ${total_cost:.4f} across {len(merged_search_data.search_cost_details)} queries" ) - for provider, cost in merged_state.search_costs.items(): + for provider, cost in merged_search_data.search_costs.items(): logger.info(f" {provider}: ${cost:.4f}") - # Check for any missing sub-questions - for sub_q in merged_state.sub_questions: - if sub_q not in processed_questions: - logger.warning(f"Missing results for sub-question: {sub_q}") - except Exception as e: logger.error(f"Error during merge step: {e}") # Final check for empty results - if not merged_state.search_results or not merged_state.synthesized_info: + if ( + not merged_search_data.search_results + or not merged_synthesis_data.synthesized_info + ): logger.warning( "No results were found or merged from parallel processing steps!" ) @@ -189,80 +151,81 @@ def merge_sub_question_results_step( # Calculate execution time execution_time = time.time() - start_time - # Calculate metrics - missing_questions = [ - q for q in merged_state.sub_questions if q not in processed_questions - ] - # Count total search results across all questions total_search_results = sum( - len(results) for results in merged_state.search_results.values() + len(results) for results in merged_search_data.search_results.values() ) # Get confidence distribution for merged results confidence_distribution = {"high": 0, "medium": 0, "low": 0} - for info in merged_state.synthesized_info.values(): + for info in merged_synthesis_data.synthesized_info.values(): level = info.confidence_level.lower() if level in confidence_distribution: confidence_distribution[level] += 1 - # Calculate completeness ratio - completeness_ratio = ( - len(processed_questions) / len(merged_state.sub_questions) - if merged_state.sub_questions - else 0 - ) - # Log metadata log_metadata( metadata={ "merge_results": { "execution_time_seconds": execution_time, - "total_sub_questions": len(merged_state.sub_questions), "parallel_steps_processed": parallel_steps_processed, "questions_successfully_merged": len(processed_questions), - "missing_questions_count": len(missing_questions), - "missing_questions": missing_questions[:5] - if missing_questions - else [], # Limit to 5 for metadata "total_search_results": total_search_results, "confidence_distribution": confidence_distribution, "merge_success": bool( - merged_state.search_results - and merged_state.synthesized_info + merged_search_data.search_results + and merged_synthesis_data.synthesized_info + ), + "total_search_costs": merged_search_data.search_costs, + "total_search_queries": len( + 
merged_search_data.search_cost_details + ), + "total_exa_cost": merged_search_data.search_costs.get( + "exa", 0.0 ), - "total_search_costs": merged_state.search_costs, - "total_search_queries": len(merged_state.search_cost_details), - "total_exa_cost": merged_state.search_costs.get("exa", 0.0), } } ) - # Log model metadata for cross-pipeline tracking + # Log artifact metadata log_metadata( metadata={ - "research_quality": { - "completeness_ratio": completeness_ratio, + "search_data_characteristics": { + "total_searches": merged_search_data.total_searches, + "search_results_count": len(merged_search_data.search_results), + "total_cost": sum(merged_search_data.search_costs.values()), } }, - infer_model=True, + artifact_name="merged_search_data", + infer_artifact=True, ) - # Log artifact metadata log_metadata( metadata={ - "merged_state_characteristics": { - "has_search_results": bool(merged_state.search_results), - "has_synthesized_info": bool(merged_state.synthesized_info), - "search_results_count": len(merged_state.search_results), - "synthesized_info_count": len(merged_state.synthesized_info), - "completeness_ratio": completeness_ratio, + "synthesis_data_characteristics": { + "synthesized_info_count": len( + merged_synthesis_data.synthesized_info + ), + "enhanced_info_count": len( + merged_synthesis_data.enhanced_info + ), + "confidence_distribution": confidence_distribution, } }, + artifact_name="merged_synthesis_data", infer_artifact=True, ) - # Add tags to the artifact - add_tags(tags=["state", "merged"], artifact="merged_state") - - return merged_state + # Add tags to the artifacts + # add_tags( + # tags=["search", "merged"], + # artifact_name="merged_search_data", + # infer_artifact=True, + # ) + # add_tags( + # tags=["synthesis", "merged"], + # artifact_name="merged_synthesis_data", + # infer_artifact=True, + # ) + + return merged_search_data, merged_synthesis_data diff --git a/deep_research/steps/process_sub_question_step.py b/deep_research/steps/process_sub_question_step.py index bc1b12ac..7ff14ab1 100644 --- a/deep_research/steps/process_sub_question_step.py +++ b/deep_research/steps/process_sub_question_step.py @@ -1,8 +1,7 @@ -import copy import logging import time import warnings -from typing import Annotated +from typing import Annotated, Tuple # Suppress Pydantic serialization warnings from ZenML artifact metadata # These occur when ZenML stores timestamp metadata as floats but models expect ints @@ -10,21 +9,34 @@ "ignore", message=".*PydanticSerializationUnexpectedValue.*" ) -from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.search_data_materializer import SearchDataMaterializer +from materializers.synthesis_data_materializer import SynthesisDataMaterializer from utils.llm_utils import synthesize_information -from utils.pydantic_models import Prompt, ResearchState, SynthesizedInfo +from utils.pydantic_models import ( + Prompt, + QueryContext, + SearchCostDetail, + SearchData, + SynthesisData, + SynthesizedInfo, +) from utils.search_utils import ( generate_search_query, search_and_extract_results, ) -from zenml import add_tags, log_metadata, step +from zenml import log_metadata, step logger = logging.getLogger(__name__) -@step(output_materializers=ResearchStateMaterializer) +@step( + output_materializers={ + "search_data": SearchDataMaterializer, + "synthesis_data": SynthesisDataMaterializer, + } +) def process_sub_question_step( - state: ResearchState, + query_context: QueryContext, search_query_prompt: Prompt, synthesis_prompt: 
Prompt, question_index: int, @@ -35,14 +47,17 @@ def process_sub_question_step( search_provider: str = "tavily", search_mode: str = "auto", langfuse_project_name: str = "deep-research", -) -> Annotated[ResearchState, "output"]: +) -> Tuple[ + Annotated[SearchData, "search_data"], + Annotated[SynthesisData, "synthesis_data"], +]: """Process a single sub-question if it exists at the given index. This step combines the gathering and synthesis steps for a single sub-question. It's designed to be run in parallel for each sub-question. Args: - state: The original research state with all sub-questions + query_context: The query context with main query and sub-questions search_query_prompt: Prompt for generating search queries synthesis_prompt: Prompt for synthesizing search results question_index: The index of the sub-question to process @@ -52,25 +67,19 @@ def process_sub_question_step( cap_search_length: Maximum length of content to process from search results search_provider: Search provider to use (tavily, exa, or both) search_mode: Search mode for Exa provider (neural, keyword, or auto) + langfuse_project_name: Project name for tracing Returns: - A new ResearchState containing only the processed sub-question's results + Tuple of SearchData and SynthesisData for the processed sub-question """ start_time = time.time() - # Create a copy of the state to avoid modifying the original - sub_state = copy.deepcopy(state) - - # Clear all existing data except the main query - sub_state.search_results = {} - sub_state.synthesized_info = {} - sub_state.enhanced_info = {} - sub_state.viewpoint_analysis = None - sub_state.reflection_metadata = None - sub_state.final_report_html = "" + # Initialize empty artifacts + search_data = SearchData() + synthesis_data = SynthesisData() # Check if this index exists in sub-questions - if question_index >= len(state.sub_questions): + if question_index >= len(query_context.sub_questions): logger.info( f"No sub-question at index {question_index}, skipping processing" ) @@ -81,25 +90,25 @@ def process_sub_question_step( "question_index": question_index, "status": "skipped", "reason": "index_out_of_range", - "total_sub_questions": len(state.sub_questions), + "total_sub_questions": len(query_context.sub_questions), } } ) - # Return an empty state since there's no question to process - sub_state.sub_questions = [] - # Add tags to the artifact - add_tags(tags=["state", "sub-question"], artifact="output") - return sub_state + # Return empty artifacts + # add_tags( + # tags=["search", "synthesis", "skipped"], artifact_name="search_data", infer_artifact=True + # ) + # add_tags( + # tags=["search", "synthesis", "skipped"], artifact_name="synthesis_data", infer_artifact=True + # ) + return search_data, synthesis_data # Get the target sub-question - sub_question = state.sub_questions[question_index] + sub_question = query_context.sub_questions[question_index] logger.info( f"Processing sub-question {question_index + 1}: {sub_question}" ) - # Store only this sub-question in the sub-state - sub_state.sub_questions = [sub_question] - # === INFORMATION GATHERING === search_phase_start = time.time() @@ -126,6 +135,10 @@ def process_sub_question_step( search_mode=search_mode, ) + # Update search data + search_data.search_results[sub_question] = results_list + search_data.total_searches = 1 + # Track search costs if using Exa if ( search_provider @@ -133,29 +146,23 @@ def process_sub_question_step( and search_cost > 0 ): # Update total costs - sub_state.search_costs["exa"] = ( - 
sub_state.search_costs.get("exa", 0.0) + search_cost - ) + search_data.search_costs["exa"] = search_cost # Add detailed cost entry - sub_state.search_cost_details.append( - { - "provider": "exa", - "query": search_query, - "cost": search_cost, - "timestamp": time.time(), - "step": "process_sub_question", - "sub_question": sub_question, - "question_index": question_index, - } + search_data.search_cost_details.append( + SearchCostDetail( + provider="exa", + query=search_query, + cost=search_cost, + timestamp=time.time(), + step="process_sub_question", + sub_question=sub_question, + ) ) logger.info( f"Exa search cost for sub-question {question_index}: ${search_cost:.4f}" ) - search_results = {sub_question: results_list} - sub_state.update_search_results(search_results) - search_phase_time = time.time() - search_phase_start # === INFORMATION SYNTHESIS === @@ -184,23 +191,21 @@ def process_sub_question_step( ) # Create SynthesizedInfo object - synthesized_info = { - sub_question: SynthesizedInfo( - synthesized_answer=synthesis_result.get( - "synthesized_answer", f"Synthesis for '{sub_question}' failed." - ), - key_sources=synthesis_result.get("key_sources", sources[:1]), - confidence_level=synthesis_result.get("confidence_level", "low"), - information_gaps=synthesis_result.get( - "information_gaps", - "Synthesis process encountered technical difficulties.", - ), - improvements=synthesis_result.get("improvements", []), - ) - } + synthesized_info = SynthesizedInfo( + synthesized_answer=synthesis_result.get( + "synthesized_answer", f"Synthesis for '{sub_question}' failed." + ), + key_sources=synthesis_result.get("key_sources", sources[:1]), + confidence_level=synthesis_result.get("confidence_level", "low"), + information_gaps=synthesis_result.get( + "information_gaps", + "Synthesis process encountered technical difficulties.", + ), + improvements=synthesis_result.get("improvements", []), + ) - # Update the state with synthesized information - sub_state.update_synthesized_info(synthesized_info) + # Update synthesis data + synthesis_data.synthesized_info[sub_question] = synthesized_info synthesis_phase_time = time.time() - synthesis_phase_start total_execution_time = time.time() - start_time @@ -270,22 +275,41 @@ def process_sub_question_step( infer_model=True, ) - # Log artifact metadata for the output state + # Log artifact metadata for the output artifacts log_metadata( metadata={ - "sub_state_characteristics": { - "has_search_results": bool(sub_state.search_results), - "has_synthesized_info": bool(sub_state.synthesized_info), - "sub_question_processed": sub_question, + "search_data_characteristics": { + "sub_question": sub_question, + "num_results": len(results_list), + "search_provider": search_provider, + "search_cost": search_cost if search_cost > 0 else None, + } + }, + artifact_name="search_data", + infer_artifact=True, + ) + + log_metadata( + metadata={ + "synthesis_data_characteristics": { + "sub_question": sub_question, "confidence_level": synthesis_result.get( "confidence_level", "low" ), + "has_information_gaps": bool( + synthesis_result.get("information_gaps") + ), + "num_key_sources": len( + synthesis_result.get("key_sources", []) + ), } }, + artifact_name="synthesis_data", infer_artifact=True, ) - # Add tags to the artifact - add_tags(tags=["state", "sub-question"], artifact="output") + # Add tags to the artifacts + # add_tags(tags=["search", "sub-question"], artifact_name="search_data", infer_artifact=True) + # add_tags(tags=["synthesis", "sub-question"], 
artifact_name="synthesis_data", infer_artifact=True) - return sub_state + return search_data, synthesis_data diff --git a/deep_research/steps/pydantic_final_report_step.py b/deep_research/steps/pydantic_final_report_step.py index d61e848b..88cf2221 100644 --- a/deep_research/steps/pydantic_final_report_step.py +++ b/deep_research/steps/pydantic_final_report_step.py @@ -1,7 +1,7 @@ -"""Final report generation step using Pydantic models and materializers. +"""Final report generation step using artifact-based approach. This module provides a ZenML pipeline step for generating the final HTML research report -using Pydantic models and improved materializers. +using the new artifact-based approach. """ import html @@ -11,7 +11,7 @@ import time from typing import Annotated, Tuple -from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.final_report_materializer import FinalReportMaterializer from utils.helper_functions import ( extract_html_from_content, remove_reasoning_from_output, @@ -22,8 +22,15 @@ SUB_QUESTION_TEMPLATE, VIEWPOINT_ANALYSIS_TEMPLATE, ) -from utils.pydantic_models import Prompt, ResearchState -from zenml import add_tags, log_metadata, step +from utils.pydantic_models import ( + AnalysisData, + FinalReport, + Prompt, + QueryContext, + SearchData, + SynthesisData, +) +from zenml import log_metadata, step from zenml.types import HTMLString logger = logging.getLogger(__name__) @@ -99,53 +106,84 @@ def format_text_with_code_blocks(text: str) -> str: if not text: return "" - # First escape HTML - escaped_text = html.escape(text) - - # Handle code blocks (wrap content in ``` or ```) - pattern = r"```(?:\w*\n)?(.*?)```" - - def code_block_replace(match): - code_content = match.group(1) - # Strip extra newlines at beginning and end - code_content = code_content.strip("\n") - return f"
<pre><code>{code_content}</code></pre>"
-
-    # Replace code blocks
-    formatted_text = re.sub(
-        pattern, code_block_replace, escaped_text, flags=re.DOTALL
-    )
+    # Handle code blocks
+    lines = text.split("\n")
+    formatted_lines = []
+    in_code_block = False
+    code_language = ""
+    code_lines = []
+
+    for line in lines:
+        # Check for code block start
+        if line.strip().startswith("```"):
+            if in_code_block:
+                # End of code block
+                code_content = "\n".join(code_lines)
+                formatted_lines.append(
+                    f'<pre><code class="language-{code_language}">{html.escape(code_content)}</code></pre>'
+                )
+                code_lines = []
+                in_code_block = False
+                code_language = ""
+            else:
+                # Start of code block
+                in_code_block = True
+                # Extract language if specified
+                code_language = line.strip()[3:].strip() or "plaintext"
+        elif in_code_block:
+            code_lines.append(line)
+        else:
+            # Process inline code
+            line = re.sub(r"`([^`]+)`", r"<code>\1</code>", html.escape(line))
+            # Process bullet points
+            if line.strip().startswith("•") or line.strip().startswith("-"):
+                line = re.sub(r"^(\s*)[•-]\s*", r"\1", line)
+                formatted_lines.append(f"<li>{line.strip()}</li>")
+            elif line.strip():
+                formatted_lines.append(f"<p>{line}</p>")
+
+    # Handle case where code block wasn't closed
+    if in_code_block and code_lines:
+        code_content = "\n".join(code_lines)
+        formatted_lines.append(
+            f'<pre><code class="language-{code_language}">{html.escape(code_content)}</code></pre>'
+        )
-    # Convert regular newlines to <br/> tags (but not inside <pre> blocks)
-    parts = []
-    in_pre = False
-    for line in formatted_text.split("\n"):
-        if "<pre>" in line:
-            in_pre = True
-            parts.append(line)
-        elif "</pre>" in line:
-            in_pre = False
-            parts.append(line)
-        elif in_pre:
-            # Inside a code block, preserve newlines
-            parts.append(line)
+    # Wrap list items in ul tags
+    result = []
+    in_list = False
+    for line in formatted_lines:
+        if line.startswith("<li>"):
+            if not in_list:
+                result.append("<ul>")
+                in_list = True
+            result.append(line)
         else:
-            # Outside code blocks, convert newlines to <br/>
-            parts.append(line + "<br/>")
+            if in_list:
+                result.append("</ul>
    ") + in_list = False + result.append(line) + + if in_list: + result.append("") - return "".join(parts) + return "\n".join(result) def generate_executive_summary( - state: ResearchState, + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, executive_summary_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", ) -> str: - """Generate an executive summary using LLM based on research findings. + """Generate an executive summary using LLM based on the complete research findings. Args: - state: The current research state + query_context: The query context with main query and sub-questions + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis executive_summary_prompt: Prompt for generating executive summary llm_model: The model to use for generation langfuse_project_name: Name of the Langfuse project for tracking @@ -155,45 +193,48 @@ def generate_executive_summary( """ logger.info("Generating executive summary using LLM") - # Prepare the context with all research findings - context = f"Main Research Query: {state.main_query}\n\n" - - # Add synthesized findings for each sub-question - for i, sub_question in enumerate(state.sub_questions, 1): - info = state.enhanced_info.get( - sub_question - ) or state.synthesized_info.get(sub_question) - if info: - context += f"Sub-question {i}: {sub_question}\n" - context += f"Answer Summary: {info.synthesized_answer[:500]}...\n" - context += f"Confidence: {info.confidence_level}\n" - context += f"Key Sources: {', '.join(info.key_sources[:3]) if info.key_sources else 'N/A'}\n\n" - - # Add viewpoint analysis insights if available - if state.viewpoint_analysis: - context += "Key Areas of Agreement:\n" - for agreement in state.viewpoint_analysis.main_points_of_agreement[:3]: - context += f"- {agreement}\n" - context += "\nKey Tensions:\n" - for tension in state.viewpoint_analysis.areas_of_tension[:2]: - context += f"- {tension.topic}\n" - - # Use the executive summary prompt - try: - executive_summary_prompt_str = str(executive_summary_prompt) - logger.info("Successfully retrieved executive_summary_prompt") - except Exception as e: - logger.error(f"Failed to get executive_summary_prompt: {e}") - return generate_fallback_executive_summary(state) + # Prepare the context + summary_input = { + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, + "key_findings": {}, + "viewpoint_analysis": None, + } + + # Include key findings from synthesis data + # Prefer enhanced info if available + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + + for question in query_context.sub_questions: + if question in info_source: + info = info_source[question] + summary_input["key_findings"][question] = { + "answer": info.synthesized_answer, + "confidence": info.confidence_level, + "gaps": info.information_gaps, + } + + # Include viewpoint analysis if available + if analysis_data.viewpoint_analysis: + va = analysis_data.viewpoint_analysis + summary_input["viewpoint_analysis"] = { + "agreements": va.main_points_of_agreement, + "tensions": len(va.areas_of_tension), + "insights": va.integrative_insights, + } try: # Call LLM to generate executive summary result = run_llm_completion( - prompt=context, - system_prompt=executive_summary_prompt_str, + prompt=json.dumps(summary_input), + 
system_prompt=str(executive_summary_prompt), model=llm_model, temperature=0.7, - max_tokens=800, + max_tokens=600, project=langfuse_project_name, tags=["executive_summary_generation"], ) @@ -206,15 +247,19 @@ def generate_executive_summary( return content else: logger.warning("Failed to generate executive summary via LLM") - return generate_fallback_executive_summary(state) + return generate_fallback_executive_summary( + query_context, synthesis_data + ) except Exception as e: logger.error(f"Error generating executive summary: {e}") - return generate_fallback_executive_summary(state) + return generate_fallback_executive_summary( + query_context, synthesis_data + ) def generate_introduction( - state: ResearchState, + query_context: QueryContext, introduction_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", @@ -222,7 +267,7 @@ def generate_introduction( """Generate an introduction using LLM based on research query and sub-questions. Args: - state: The current research state + query_context: The query context with main query and sub-questions introduction_prompt: Prompt for generating introduction llm_model: The model to use for generation langfuse_project_name: Name of the Langfuse project for tracking @@ -233,24 +278,16 @@ def generate_introduction( logger.info("Generating introduction using LLM") # Prepare the context - context = f"Main Research Query: {state.main_query}\n\n" + context = f"Main Research Query: {query_context.main_query}\n\n" context += "Sub-questions being explored:\n" - for i, sub_question in enumerate(state.sub_questions, 1): + for i, sub_question in enumerate(query_context.sub_questions, 1): context += f"{i}. {sub_question}\n" - # Get the introduction prompt - try: - introduction_prompt_str = str(introduction_prompt) - logger.info("Successfully retrieved introduction_prompt") - except Exception as e: - logger.error(f"Failed to get introduction_prompt: {e}") - return generate_fallback_introduction(state) - try: # Call LLM to generate introduction result = run_llm_completion( prompt=context, - system_prompt=introduction_prompt_str, + system_prompt=str(introduction_prompt), model=llm_model, temperature=0.7, max_tokens=600, @@ -266,22 +303,29 @@ def generate_introduction( return content else: logger.warning("Failed to generate introduction via LLM") - return generate_fallback_introduction(state) + return generate_fallback_introduction(query_context) except Exception as e: logger.error(f"Error generating introduction: {e}") - return generate_fallback_introduction(state) + return generate_fallback_introduction(query_context) -def generate_fallback_executive_summary(state: ResearchState) -> str: +def generate_fallback_executive_summary( + query_context: QueryContext, synthesis_data: SynthesisData +) -> str: """Generate a fallback executive summary when LLM fails.""" - summary = f"

<p>This report examines the question: {html.escape(state.main_query)}</p>"
-    summary += f"<p>The research explored {len(state.sub_questions)} key dimensions of this topic, "
+    summary = f"<p>This report examines the question: {html.escape(query_context.main_query)}</p>"
+    summary += f"<p>The research explored {len(query_context.sub_questions)} key dimensions of this topic, "
     summary += "synthesizing findings from multiple sources to provide a comprehensive analysis.</p>
    " # Add confidence overview confidence_counts = {"high": 0, "medium": 0, "low": 0} - for info in state.enhanced_info.values(): + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + for info in info_source.values(): level = info.confidence_level.lower() if level in confidence_counts: confidence_counts[level] += 1 @@ -292,10 +336,10 @@ def generate_fallback_executive_summary(state: ResearchState) -> str: return summary -def generate_fallback_introduction(state: ResearchState) -> str: +def generate_fallback_introduction(query_context: QueryContext) -> str: """Generate a fallback introduction when LLM fails.""" - intro = f"

<p>This report addresses the research query: {html.escape(state.main_query)}</p>"
-    intro += f"<p>The research was conducted by breaking down the main query into {len(state.sub_questions)} "
+    intro = f"<p>This report addresses the research query: {html.escape(query_context.main_query)}</p>"
+    intro += f"<p>
    The research was conducted by breaking down the main query into {len(query_context.sub_questions)} " intro += ( "sub-questions to explore different aspects of the topic in depth. " ) @@ -304,7 +348,9 @@ def generate_fallback_introduction(state: ResearchState) -> str: def generate_conclusion( - state: ResearchState, + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, conclusion_generation_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", @@ -312,9 +358,12 @@ def generate_conclusion( """Generate a comprehensive conclusion using LLM based on all research findings. Args: - state: The ResearchState containing all research findings + query_context: The query context with main query and sub-questions + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis conclusion_generation_prompt: Prompt for generating conclusion llm_model: The model to use for conclusion generation + langfuse_project_name: Name of the Langfuse project for tracking Returns: str: HTML-formatted conclusion content @@ -323,15 +372,21 @@ def generate_conclusion( # Prepare input data for conclusion generation conclusion_input = { - "main_query": state.main_query, - "sub_questions": state.sub_questions, + "main_query": query_context.main_query, + "sub_questions": query_context.sub_questions, "enhanced_info": {}, } # Include enhanced information for each sub-question - for question in state.sub_questions: - if question in state.enhanced_info: - info = state.enhanced_info[question] + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + + for question in query_context.sub_questions: + if question in info_source: + info = info_source[question] conclusion_input["enhanced_info"][question] = { "synthesized_answer": info.synthesized_answer, "confidence_level": info.confidence_level, @@ -339,93 +394,99 @@ def generate_conclusion( "key_sources": info.key_sources, "improvements": getattr(info, "improvements", []), } - elif question in state.synthesized_info: - # Fallback to synthesized info if enhanced info not available - info = state.synthesized_info[question] - conclusion_input["enhanced_info"][question] = { - "synthesized_answer": info.synthesized_answer, - "confidence_level": info.confidence_level, - "information_gaps": info.information_gaps, - "key_sources": info.key_sources, - "improvements": [], - } - # Include viewpoint analysis if available - if state.viewpoint_analysis: + # Include viewpoint analysis + if analysis_data.viewpoint_analysis: + va = analysis_data.viewpoint_analysis conclusion_input["viewpoint_analysis"] = { - "main_points_of_agreement": state.viewpoint_analysis.main_points_of_agreement, + "main_points_of_agreement": va.main_points_of_agreement, "areas_of_tension": [ - {"topic": tension.topic, "viewpoints": tension.viewpoints} - for tension in state.viewpoint_analysis.areas_of_tension + {"topic": t.topic, "viewpoints": t.viewpoints} + for t in va.areas_of_tension ], - "perspective_gaps": state.viewpoint_analysis.perspective_gaps, - "integrative_insights": state.viewpoint_analysis.integrative_insights, + "integrative_insights": va.integrative_insights, } # Include reflection metadata if available - if state.reflection_metadata: - conclusion_input["reflection_metadata"] = { - "critique_summary": state.reflection_metadata.critique_summary, - 
"additional_questions_identified": state.reflection_metadata.additional_questions_identified, - "improvements_made": state.reflection_metadata.improvements_made, + if analysis_data.reflection_metadata: + rm = analysis_data.reflection_metadata + conclusion_input["reflection_insights"] = { + "improvements_made": rm.improvements_made, + "additional_questions_identified": rm.additional_questions_identified, } try: - # Use the conclusion generation prompt - conclusion_prompt_str = str(conclusion_generation_prompt) - - # Generate conclusion using LLM - conclusion_html = run_llm_completion( - prompt=json.dumps(conclusion_input, indent=2), - system_prompt=conclusion_prompt_str, + # Call LLM to generate conclusion + result = run_llm_completion( + prompt=json.dumps(conclusion_input), + system_prompt=str(conclusion_generation_prompt), model=llm_model, - clean_output=True, - max_tokens=1500, # Sufficient for comprehensive conclusion + temperature=0.7, + max_tokens=800, project=langfuse_project_name, + tags=["conclusion_generation"], ) - # Clean up any formatting issues - conclusion_html = conclusion_html.strip() + if result: + content = remove_reasoning_from_output(result) + # Clean up the HTML + content = extract_html_from_content(content) + logger.info("Successfully generated LLM-based conclusion") + return content + else: + logger.warning("Failed to generate conclusion via LLM") + return generate_fallback_conclusion(query_context, synthesis_data) + + except Exception as e: + logger.error(f"Error generating conclusion: {e}") + return generate_fallback_conclusion(query_context, synthesis_data) - # Remove any h2 tags with "Conclusion" text that LLM might have added - # Since we already have a Conclusion header in the template - conclusion_html = re.sub( - r"]*>\s*Conclusion\s*\s*", - "", - conclusion_html, - flags=re.IGNORECASE, - ) - conclusion_html = re.sub( - r"]*>\s*Conclusion\s*\s*", - "", - conclusion_html, - flags=re.IGNORECASE, - ) - # Also remove plain text "Conclusion" at the start if it exists - conclusion_html = re.sub( - r"^Conclusion\s*\n*", - "", - conclusion_html.strip(), - flags=re.IGNORECASE, - ) +def generate_fallback_conclusion( + query_context: QueryContext, synthesis_data: SynthesisData +) -> str: + """Generate a fallback conclusion when LLM fails. + + Args: + query_context: The query context with main query and sub-questions + synthesis_data: The synthesis data with all synthesized information + + Returns: + str: Basic HTML-formatted conclusion + """ + conclusion = f"

<p>This research has explored the question: {html.escape(query_context.main_query)}</p> "
+    conclusion += f"<p>Through systematic investigation of {len(query_context.sub_questions)} sub-questions, "
+    conclusion += (
+        "we have gathered insights from multiple sources and perspectives.</p> "
+    )
-        if not conclusion_html.startswith("<p>"):
-            # Wrap in paragraph tags if not already formatted
-            conclusion_html = f"<p>{conclusion_html}</p>
    " + # Add a summary of confidence levels + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) + high_confidence = sum( + 1 + for info in info_source.values() + if info.confidence_level.lower() == "high" + ) - logger.info("Successfully generated LLM-based conclusion") - return conclusion_html + if high_confidence > 0: + conclusion += f"

<p>The research yielded {high_confidence} high-confidence findings out of "
+        conclusion += f"{len(info_source)} total areas investigated.</p> "
-    except Exception as e:
-        logger.warning(f"Failed to generate LLM conclusion: {e}")
-        # Return a basic fallback conclusion
-        return f"""<p>This report has explored {html.escape(state.main_query)} through a structured research approach, examining {len(state.sub_questions)} focused sub-questions and synthesizing information from diverse sources. The findings provide a comprehensive understanding of the topic, highlighting key aspects, perspectives, and current knowledge.</p>
-
-        <p>While some information gaps remain, as noted in the respective sections, this research provides a solid foundation for understanding the topic and its implications.</p>"""
+    conclusion += "<p>Further research may be beneficial to address remaining information gaps "
+    conclusion += "and explore emerging questions identified during this investigation.</p>

    " + + return conclusion def generate_report_from_template( - state: ResearchState, + query_context: QueryContext, + search_data: SearchData, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, conclusion_generation_prompt: Prompt, executive_summary_prompt: Prompt, introduction_prompt: Prompt, @@ -435,25 +496,29 @@ def generate_report_from_template( """Generate a final HTML report from a static template. Instead of using an LLM to generate HTML, this function uses predefined HTML - templates and populates them with data from the research state. + templates and populates them with data from the research artifacts. Args: - state: The current research state + query_context: The query context with main query and sub-questions + search_data: The search data (for source information) + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis conclusion_generation_prompt: Prompt for generating conclusion executive_summary_prompt: Prompt for generating executive summary introduction_prompt: Prompt for generating introduction llm_model: The model to use for conclusion generation + langfuse_project_name: Name of the Langfuse project for tracking Returns: str: The HTML content of the report """ logger.info( - f"Generating templated HTML report for query: {state.main_query}" + f"Generating templated HTML report for query: {query_context.main_query}" ) # Generate table of contents for sub-questions sub_questions_toc = "" - for i, question in enumerate(state.sub_questions, 1): + for i, question in enumerate(query_context.sub_questions, 1): safe_id = f"question-{i}" sub_questions_toc += ( f'
<li><a href="#{safe_id}">{html.escape(question)}</a></li>\n'
@@ -461,7 +526,7 @@ def generate_report_from_template(
     # Add viewpoint analysis to TOC if available
     additional_sections_toc = ""
-    if state.viewpoint_analysis:
+    if analysis_data.viewpoint_analysis:
         additional_sections_toc += (
             '<li><a href="#viewpoint-analysis">Viewpoint Analysis
  • \n' ) @@ -470,11 +535,41 @@ def generate_report_from_template( sub_questions_html = "" all_sources = set() - for i, question in enumerate(state.sub_questions, 1): - info = state.enhanced_info.get(question, None) + # Determine which info source to use (merge original with enhanced) + # Start with the original synthesized info + info_source = synthesis_data.synthesized_info.copy() + + # Override with enhanced info where available + if synthesis_data.enhanced_info: + info_source.update(synthesis_data.enhanced_info) + + # Debug logging + logger.info( + f"Synthesis data has enhanced_info: {bool(synthesis_data.enhanced_info)}" + ) + logger.info( + f"Synthesis data has synthesized_info: {bool(synthesis_data.synthesized_info)}" + ) + logger.info(f"Info source has {len(info_source)} entries") + logger.info(f"Processing {len(query_context.sub_questions)} sub-questions") + + # Log the keys in info_source for debugging + if info_source: + logger.info( + f"Keys in info_source: {list(info_source.keys())[:3]}..." + ) # First 3 keys + logger.info( + f"Sub-questions from query_context: {query_context.sub_questions[:3]}..." + ) # First 3 + + for i, question in enumerate(query_context.sub_questions, 1): + info = info_source.get(question, None) # Skip if no information is available if not info: + logger.warning( + f"No synthesis info found for question {i}: {question}" + ) continue # Process confidence level @@ -527,65 +622,73 @@ def generate_report_from_template( confidence_upper=confidence_upper, confidence_icon=confidence_icon, answer=format_text_with_code_blocks(info.synthesized_answer), - info_gaps_html=info_gaps_html, key_sources_html=key_sources_html, + info_gaps_html=info_gaps_html, ) sub_questions_html += sub_question_html # Generate viewpoint analysis HTML if available viewpoint_analysis_html = "" - if state.viewpoint_analysis: - # Format points of agreement - agreements_html = "" - for point in state.viewpoint_analysis.main_points_of_agreement: - agreements_html += f"
  • {html.escape(point)}
  • \n" - - # Format areas of tension + if analysis_data.viewpoint_analysis: + va = analysis_data.viewpoint_analysis + # Format tensions tensions_html = "" - for tension in state.viewpoint_analysis.areas_of_tension: - viewpoints_html = "" - for title, content in tension.viewpoints.items(): - # Create category-specific styling - category_class = f"category-{title.lower()}" - category_title = title.capitalize() - - viewpoints_html += f""" -
    - {category_title} -

    {html.escape(content)}

    -
    - """ - + for tension in va.areas_of_tension: + viewpoints_list = "\n".join( + [ + f"
  • {html.escape(viewpoint)}: {html.escape(description)}
  • " + for viewpoint, description in tension.viewpoints.items() + ] + ) tensions_html += f""" -
    +

    {html.escape(tension.topic)}

    -
    - {viewpoints_html} -
    +
      + {viewpoints_list} +
    """ - # Format the viewpoint analysis section using the template + # Format agreements (just the list items) + agreements_html = "" + if va.main_points_of_agreement: + agreements_html = "\n".join( + [ + f"
  • {html.escape(point)}
  • " + for point in va.main_points_of_agreement + ] + ) + + # Get perspective gaps if available + perspective_gaps = "" + if hasattr(va, "perspective_gaps") and va.perspective_gaps: + perspective_gaps = va.perspective_gaps + else: + perspective_gaps = "No significant perspective gaps identified." + + # Get integrative insights + integrative_insights = "" + if va.integrative_insights: + integrative_insights = format_text_with_code_blocks( + va.integrative_insights + ) + viewpoint_analysis_html = VIEWPOINT_ANALYSIS_TEMPLATE.format( agreements_html=agreements_html, tensions_html=tensions_html, - perspective_gaps=format_text_with_code_blocks( - state.viewpoint_analysis.perspective_gaps - ), - integrative_insights=format_text_with_code_blocks( - state.viewpoint_analysis.integrative_insights - ), + perspective_gaps=perspective_gaps, + integrative_insights=integrative_insights, ) - # Generate references HTML - references_html = "
      " + # Generate references section + references_html = '
        ' if all_sources: for source in sorted(all_sources): if source.startswith(("http://", "https://")): - references_html += f'
      • {html.escape(source)}
      • \n' + references_html += f'
      • {html.escape(source)}
      • ' else: - references_html += f"
      • {html.escape(source)}
      • \n" + references_html += f"
      • {html.escape(source)}
      • " else: references_html += ( "
      • No external sources were referenced in this research.
      • " @@ -595,7 +698,12 @@ def generate_report_from_template( # Generate dynamic executive summary using LLM logger.info("Generating dynamic executive summary...") executive_summary = generate_executive_summary( - state, executive_summary_prompt, llm_model, langfuse_project_name + query_context, + synthesis_data, + analysis_data, + executive_summary_prompt, + llm_model, + langfuse_project_name, ) logger.info( f"Executive summary generated: {len(executive_summary)} characters" @@ -604,23 +712,28 @@ def generate_report_from_template( # Generate dynamic introduction using LLM logger.info("Generating dynamic introduction...") introduction_html = generate_introduction( - state, introduction_prompt, llm_model, langfuse_project_name + query_context, introduction_prompt, llm_model, langfuse_project_name ) logger.info(f"Introduction generated: {len(introduction_html)} characters") # Generate comprehensive conclusion using LLM conclusion_html = generate_conclusion( - state, conclusion_generation_prompt, llm_model, langfuse_project_name + query_context, + synthesis_data, + analysis_data, + conclusion_generation_prompt, + llm_model, + langfuse_project_name, ) # Generate complete HTML report html_content = STATIC_HTML_TEMPLATE.format( - main_query=html.escape(state.main_query), + main_query=html.escape(query_context.main_query), sub_questions_toc=sub_questions_toc, additional_sections_toc=additional_sections_toc, executive_summary=executive_summary, introduction_html=introduction_html, - num_sub_questions=len(state.sub_questions), + num_sub_questions=len(query_context.sub_questions), sub_questions_html=sub_questions_html, viewpoint_analysis_html=viewpoint_analysis_html, conclusion_html=conclusion_html, @@ -630,19 +743,20 @@ def generate_report_from_template( return html_content -def _generate_fallback_report(state: ResearchState) -> str: +def _generate_fallback_report( + query_context: QueryContext, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, +) -> str: """Generate a minimal fallback report when the main report generation fails. This function creates a simplified HTML report with a consistent structure when - the main report generation process encounters an error. The HTML includes: - - A header section with the main research query - - An error notice - - Introduction section - - Individual sections for each sub-question with available answers - - A references section if sources are available + the main report generation process encounters an error. 
Args: - state: The current research state containing query and answer information + query_context: The query context with main query and sub-questions + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis Returns: str: A basic HTML report with a standard research report structure @@ -688,302 +802,161 @@ def _generate_fallback_report(state: ResearchState) -> str: }} h3 {{ - color: #3498db; + color: #34495e; margin-top: 20px; }} p {{ - margin: 15px 0; + margin: 10px 0; }} /* Sections */ .section {{ - margin: 30px 0; + margin-bottom: 30px; padding: 20px; background-color: #f8f9fa; - border-left: 4px solid #3498db; - border-radius: 4px; - }} - - .content {{ - margin-top: 15px; - }} - - /* Notice/Error Styles */ - .notice {{ - padding: 15px; - margin: 20px 0; - border-radius: 4px; + border-radius: 8px; }} - .error {{ + .error-notice {{ background-color: #fee; - border-left: 4px solid #e74c3c; - color: #c0392b; + border: 1px solid #fcc; + color: #c33; + padding: 15px; + border-radius: 8px; + margin-bottom: 20px; }} - /* Confidence Level Indicators */ - .confidence-level {{ + /* Confidence badges */ + .confidence {{ display: inline-block; - padding: 5px 10px; - border-radius: 4px; + padding: 4px 12px; + border-radius: 20px; + font-size: 12px; font-weight: bold; - margin: 10px 0; + margin-left: 10px; }} - .confidence-high {{ + .confidence.high {{ background-color: #d4edda; color: #155724; - border-left: 4px solid #28a745; }} - .confidence-medium {{ + .confidence.medium {{ background-color: #fff3cd; color: #856404; - border-left: 4px solid #ffc107; }} - .confidence-low {{ + .confidence.low {{ background-color: #f8d7da; color: #721c24; - border-left: 4px solid #dc3545; }} /* Lists */ ul {{ - padding-left: 20px; + margin: 10px 0; + padding-left: 25px; }} li {{ - margin: 8px 0; + margin: 5px 0; }} - /* References Section */ + /* References */ .references {{ margin-top: 40px; padding-top: 20px; - border-top: 1px solid #eee; - }} - - .references ul {{ - list-style-type: none; - padding-left: 0; - }} - - .references li {{ - padding: 8px 0; - border-bottom: 1px dotted #ddd; - }} - - /* Table of Contents */ - .toc {{ - background-color: #f8f9fa; - padding: 15px; - border-radius: 4px; - margin: 20px 0; - }} - - .toc ul {{ - list-style-type: none; - padding-left: 10px; + border-top: 2px solid #eee; }} - .toc li {{ - margin: 5px 0; + .reference-list {{ + font-size: 14px; }} - .toc a {{ + .reference-list a {{ color: #3498db; text-decoration: none; + word-break: break-word; }} - .toc a:hover {{ + .reference-list a:hover {{ text-decoration: underline; }} - - /* Executive Summary */ - .executive-summary {{ - background-color: #e8f4f8; - padding: 20px; - border-radius: 4px; - margin: 20px 0; - border-left: 4px solid #3498db; - }} + Research Report - {html.escape(query_context.main_query)}
        -

        Research Report: {state.main_query}

        - -
        -

        Note: This is a fallback report generated due to an error in the report generation process.

        -
        +

        Research Report: {html.escape(query_context.main_query)}

        - -
        -

        Table of Contents

        -
          -
        • Introduction
        • -""" - - # Add TOC entries for each sub-question - for i, sub_question in enumerate(state.sub_questions): - safe_id = f"question-{i + 1}" - html += f'
        • {sub_question}
        • \n' - - html += """
        • References
        • -
        +
        + Note: This is a simplified version of the report generated due to processing limitations.
        - -
        -

        Executive Summary

        -

        This report presents findings related to the main research query. It explores multiple aspects of the topic through structured sub-questions and synthesizes information from various sources.

        -
        - -
        +

        Introduction

        -

        This report addresses the research query: "{state.main_query}"

        -

        The analysis is structured around {len(state.sub_questions)} sub-questions that explore different dimensions of this topic.

        +

        This report investigates the research query: {html.escape(query_context.main_query)}

        +

        The investigation was structured around {len(query_context.sub_questions)} key sub-questions to provide comprehensive coverage of the topic.

        + +
        +

        Research Findings

        """ - # Add each sub-question and its synthesized information - for i, sub_question in enumerate(state.sub_questions): - safe_id = f"question-{i + 1}" - info = state.enhanced_info.get(sub_question, None) + # Add findings for each sub-question + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) - if not info: - # Try to get from synthesized info if not in enhanced info - info = state.synthesized_info.get(sub_question, None) - - if info: - answer = info.synthesized_answer - confidence = info.confidence_level - - # Add appropriate confidence class - confidence_class = "" - if confidence == "high": - confidence_class = "confidence-high" - elif confidence == "medium": - confidence_class = "confidence-medium" - elif confidence == "low": - confidence_class = "confidence-low" + for i, question in enumerate(query_context.sub_questions, 1): + if question in info_source: + info = info_source[question] + confidence_class = info.confidence_level.lower() html += f""" -
        -

        {i + 1}. {sub_question}

        -

        Confidence Level: {confidence.upper()}

        -
        -

        {answer}

        -
        +
        +

        {i}. {html.escape(question)}

        + Confidence: {info.confidence_level.upper()} +

        {html.escape(info.synthesized_answer)}

        """ - # Add information gaps if available - if hasattr(info, "information_gaps") and info.information_gaps: - html += f""" -
        -

        Information Gaps

        -

        {info.information_gaps}

        -
        - """ - - # Add improvements if available - if hasattr(info, "improvements") and info.improvements: - html += """ -
        -

        Improvements Made

        -
          - """ - - for improvement in info.improvements: - html += f"
        • {improvement}
        • \n" - - html += """ -
        -
        - """ - - # Add key sources if available - if hasattr(info, "key_sources") and info.key_sources: - html += """ -
        -

        Key Sources

        -
          - """ + if info.information_gaps: + html += f"

          Information Gaps: {html.escape(info.information_gaps)}

          " - for source in info.key_sources: - html += f"
        • {source}
        • \n" + html += "
        " - html += """ -
      -
    - """ - - html += """ -
    - """ - else: - html += f""" -
    -

    {i + 1}. {sub_question}

    -

    No information available for this question.

    -
    - """ - - # Add conclusion section html += """ + +

    Conclusion

    -

    This report has explored the research query through multiple sub-questions, providing synthesized information based on available sources. While limitations exist in some areas, the report provides a structured analysis of the topic.

    +

    This research has provided insights into the various aspects of the main query through systematic investigation.

    +

    The findings represent a synthesis of available information, with varying levels of confidence across different areas.

    - """ - - # Add sources if available - sources_set = set() - for info in state.enhanced_info.values(): - if info.key_sources: - sources_set.update(info.key_sources) - - if sources_set: - html += """ -
    -

    References

    -
      - """ - - for source in sorted(sources_set): - html += f"
    • {source}
    • \n" - - html += """ -
    -
    - """ - else: - html += """ -
    + +

    References

    -

    No references available.

    +

    Sources were gathered from various search providers and synthesized to create this report.

    - """ - - # Close the HTML structure - html += """
    - - """ +""" return html @step( output_materializers={ - "state": ResearchStateMaterializer, + "final_report": FinalReportMaterializer, } ) def pydantic_final_report_step( - state: ResearchState, + query_context: QueryContext, + search_data: SearchData, + synthesis_data: SynthesisData, + analysis_data: AnalysisData, conclusion_generation_prompt: Prompt, executive_summary_prompt: Prompt, introduction_prompt: Prompt, @@ -991,34 +964,41 @@ def pydantic_final_report_step( llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", langfuse_project_name: str = "deep-research", ) -> Tuple[ - Annotated[ResearchState, "state"], + Annotated[FinalReport, "final_report"], Annotated[HTMLString, "report_html"], ]: - """Generate the final research report in HTML format using Pydantic models. + """Generate the final research report in HTML format using artifact-based approach. - This step uses the Pydantic models and materializers to generate a final - HTML report and return both the updated state and the HTML report as - separate artifacts. + This step uses the individual artifacts to generate a final HTML report. Args: - state: The current research state (Pydantic model) + query_context: The query context with main query and sub-questions + search_data: The search data (for source information) + synthesis_data: The synthesis data with all synthesized information + analysis_data: The analysis data with viewpoint analysis and reflection metadata conclusion_generation_prompt: Prompt for generating conclusions executive_summary_prompt: Prompt for generating executive summary introduction_prompt: Prompt for generating introduction use_static_template: Whether to use a static template instead of LLM generation llm_model: The model to use for report generation with provider prefix + langfuse_project_name: Name of the Langfuse project for tracking Returns: - A tuple containing the updated research state and the HTML report + A tuple containing the FinalReport artifact and the HTML report string """ start_time = time.time() - logger.info("Generating final research report using Pydantic models") + logger.info( + "Generating final research report using artifact-based approach" + ) if use_static_template: # Use the static HTML template approach logger.info("Using static HTML template for report generation") html_content = generate_report_from_template( - state, + query_context, + search_data, + synthesis_data, + analysis_data, conclusion_generation_prompt, executive_summary_prompt, introduction_prompt, @@ -1026,232 +1006,101 @@ def pydantic_final_report_step( langfuse_project_name, ) - # Update the state with the final report HTML - state.set_final_report(html_content) + # Create the FinalReport artifact + final_report = FinalReport( + report_html=html_content, + main_query=query_context.main_query, + ) - # Collect metadata about the report + # Calculate execution time execution_time = time.time() - start_time - # Count sources - all_sources = set() - for info in state.enhanced_info.values(): - if info.key_sources: - all_sources.update(info.key_sources) - - # Count confidence levels + # Calculate report metrics + info_source = ( + synthesis_data.enhanced_info + if synthesis_data.enhanced_info + else synthesis_data.synthesized_info + ) confidence_distribution = {"high": 0, "medium": 0, "low": 0} - for info in state.enhanced_info.values(): + for info in info_source.values(): level = info.confidence_level.lower() if level in confidence_distribution: confidence_distribution[level] += 1 - # Log metadata - 
log_metadata( - metadata={ - "report_generation": { - "execution_time_seconds": execution_time, - "generation_method": "static_template", - "llm_model": llm_model, - "report_length_chars": len(html_content), - "num_sub_questions": len(state.sub_questions), - "num_sources": len(all_sources), - "has_viewpoint_analysis": bool(state.viewpoint_analysis), - "has_reflection": bool(state.reflection_metadata), - "confidence_distribution": confidence_distribution, - "fallback_report": False, - } - } - ) - - # Log model metadata for cross-pipeline tracking - log_metadata( - metadata={ - "research_quality": { - "confidence_distribution": confidence_distribution, - } - }, - infer_model=True, - ) - - # Log artifact metadata for the HTML report - log_metadata( - metadata={ - "html_report_characteristics": { - "size_bytes": len(html_content.encode("utf-8")), - "has_toc": "toc" in html_content.lower(), - "has_executive_summary": "executive summary" - in html_content.lower(), - "has_conclusion": "conclusion" in html_content.lower(), - "has_references": "references" in html_content.lower(), - } - }, - infer_artifact=True, - artifact_name="report_html", - ) - - logger.info( - "Final research report generated successfully with static template" + # Count various elements in the report + num_sources = len( + set( + source + for info in info_source.values() + for source in info.key_sources + ) ) - # Add tags to the artifacts - add_tags(tags=["state", "final"], artifact="state") - add_tags(tags=["report", "html"], artifact="report_html") - return state, HTMLString(html_content) - - # Otherwise use the LLM-generated approach - # Convert Pydantic model to dict for LLM input - report_input = { - "main_query": state.main_query, - "sub_questions": state.sub_questions, - "synthesized_information": state.enhanced_info, - } - - if state.viewpoint_analysis: - report_input["viewpoint_analysis"] = state.viewpoint_analysis - - if state.reflection_metadata: - report_input["reflection_metadata"] = state.reflection_metadata - - # Generate the report - try: - logger.info(f"Calling {llm_model} to generate final report") - - # Use a default report generation prompt - report_prompt = "Generate a comprehensive HTML research report based on the provided research data. Include proper HTML structure with sections for executive summary, introduction, findings, and conclusion." 
- - # Use the utility function to run LLM completion - html_content = run_llm_completion( - prompt=json.dumps(report_input), - system_prompt=report_prompt, - model=llm_model, - clean_output=False, # Don't clean in case of breaking HTML formatting - max_tokens=4000, # Increased token limit for detailed report generation - project=langfuse_project_name, + has_viewpoint_analysis = analysis_data.viewpoint_analysis is not None + has_reflection_insights = ( + analysis_data.reflection_metadata is not None + and analysis_data.reflection_metadata.improvements_made > 0 ) - # Clean up any JSON wrapper or other artifacts - html_content = remove_reasoning_from_output(html_content) - - # Process the HTML content to remove code block markers and fix common issues - html_content = clean_html_output(html_content) - - # Basic validation of HTML content - if not html_content.strip().startswith("<"): - logger.warning( - "Generated content does not appear to be valid HTML" - ) - # Try to extract HTML if it might be wrapped in code blocks or JSON - html_content = extract_html_from_content(html_content) - - # Update the state with the final report HTML - state.set_final_report(html_content) - - # Collect metadata about the report - execution_time = time.time() - start_time - - # Count sources - all_sources = set() - for info in state.enhanced_info.values(): - if info.key_sources: - all_sources.update(info.key_sources) - - # Count confidence levels - confidence_distribution = {"high": 0, "medium": 0, "low": 0} - for info in state.enhanced_info.values(): - level = info.confidence_level.lower() - if level in confidence_distribution: - confidence_distribution[level] += 1 - - # Log metadata + # Log step metadata log_metadata( metadata={ - "report_generation": { + "final_report_generation": { "execution_time_seconds": execution_time, - "generation_method": "llm_generated", + "use_static_template": use_static_template, "llm_model": llm_model, - "report_length_chars": len(html_content), - "num_sub_questions": len(state.sub_questions), - "num_sources": len(all_sources), - "has_viewpoint_analysis": bool(state.viewpoint_analysis), - "has_reflection": bool(state.reflection_metadata), + "main_query_length": len(query_context.main_query), + "num_sub_questions": len(query_context.sub_questions), + "num_synthesized_answers": len(info_source), + "has_enhanced_info": bool(synthesis_data.enhanced_info), "confidence_distribution": confidence_distribution, - "fallback_report": False, + "num_unique_sources": num_sources, + "has_viewpoint_analysis": has_viewpoint_analysis, + "has_reflection_insights": has_reflection_insights, + "report_length_chars": len(html_content), + "report_generation_success": True, } } ) - # Log model metadata for cross-pipeline tracking + # Log artifact metadata log_metadata( metadata={ - "research_quality": { - "confidence_distribution": confidence_distribution, + "final_report_characteristics": { + "report_length": len(html_content), + "main_query": query_context.main_query, + "num_sections": len(query_context.sub_questions) + + (1 if has_viewpoint_analysis else 0), + "has_executive_summary": True, + "has_introduction": True, + "has_conclusion": True, } }, - infer_model=True, + artifact_name="final_report", + infer_artifact=True, ) - logger.info("Final research report generated successfully") - # Add tags to the artifacts - add_tags(tags=["state", "final"], artifact="state") - add_tags(tags=["report", "html"], artifact="report_html") - return state, HTMLString(html_content) - - except Exception as e: - 
logger.error(f"Error generating final report: {e}") - # Generate a minimal fallback report - fallback_html = _generate_fallback_report(state) - - # Process the fallback HTML to ensure it's clean - fallback_html = clean_html_output(fallback_html) - - # Update the state with the fallback report - state.set_final_report(fallback_html) - - # Collect metadata about the fallback report - execution_time = time.time() - start_time - - # Count sources - all_sources = set() - for info in state.enhanced_info.values(): - if info.key_sources: - all_sources.update(info.key_sources) - - # Count confidence levels - confidence_distribution = {"high": 0, "medium": 0, "low": 0} - for info in state.enhanced_info.values(): - level = info.confidence_level.lower() - if level in confidence_distribution: - confidence_distribution[level] += 1 + # Add tags to the artifact + # add_tags(tags=["report", "final", "html"], artifact_name="final_report", infer_artifact=True) - # Log metadata for fallback report - log_metadata( - metadata={ - "report_generation": { - "execution_time_seconds": execution_time, - "generation_method": "fallback", - "llm_model": llm_model, - "report_length_chars": len(fallback_html), - "num_sub_questions": len(state.sub_questions), - "num_sources": len(all_sources), - "has_viewpoint_analysis": bool(state.viewpoint_analysis), - "has_reflection": bool(state.reflection_metadata), - "confidence_distribution": confidence_distribution, - "fallback_report": True, - "error_message": str(e), - } - } + logger.info( + f"Successfully generated final report ({len(html_content)} characters)" ) + return final_report, HTMLString(html_content) - # Log model metadata for cross-pipeline tracking - log_metadata( - metadata={ - "research_quality": { - "confidence_distribution": confidence_distribution, - } - }, - infer_model=True, + else: + # Handle non-static template case (future implementation) + logger.warning( + "Non-static template generation not yet implemented, falling back to static template" + ) + return pydantic_final_report_step( + query_context=query_context, + search_data=search_data, + synthesis_data=synthesis_data, + analysis_data=analysis_data, + conclusion_generation_prompt=conclusion_generation_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + use_static_template=True, + llm_model=llm_model, + langfuse_project_name=langfuse_project_name, ) - - # Add tags to the artifacts - add_tags(tags=["state", "final"], artifact="state") - add_tags(tags=["report", "html"], artifact="report_html") - return state, HTMLString(fallback_html) diff --git a/deep_research/steps/query_decomposition_step.py b/deep_research/steps/query_decomposition_step.py index ec0aa0d5..053a8fd4 100644 --- a/deep_research/steps/query_decomposition_step.py +++ b/deep_research/steps/query_decomposition_step.py @@ -2,35 +2,36 @@ import time from typing import Annotated -from materializers.pydantic_materializer import ResearchStateMaterializer +from materializers.query_context_materializer import QueryContextMaterializer from utils.llm_utils import get_structured_llm_output -from utils.pydantic_models import Prompt, ResearchState -from zenml import add_tags, log_metadata, step +from utils.pydantic_models import Prompt, QueryContext +from zenml import log_metadata, step logger = logging.getLogger(__name__) -@step(output_materializers=ResearchStateMaterializer) +@step(output_materializers=QueryContextMaterializer) def initial_query_decomposition_step( - state: ResearchState, + 
main_query: str, query_decomposition_prompt: Prompt, llm_model: str = "sambanova/DeepSeek-R1-Distill-Llama-70B", max_sub_questions: int = 8, langfuse_project_name: str = "deep-research", -) -> Annotated[ResearchState, "updated_state"]: +) -> Annotated[QueryContext, "query_context"]: """Break down a complex research query into specific sub-questions. Args: - state: The current research state + main_query: The main research query to decompose query_decomposition_prompt: Prompt for query decomposition llm_model: The reasoning model to use with provider prefix max_sub_questions: Maximum number of sub-questions to generate + langfuse_project_name: Project name for tracing Returns: - Updated research state with sub-questions + QueryContext containing the main query and decomposed sub-questions """ start_time = time.time() - logger.info(f"Decomposing research query: {state.main_query}") + logger.info(f"Decomposing research query: {main_query}") # Get the prompt content system_prompt = str(query_decomposition_prompt) @@ -48,22 +49,22 @@ def initial_query_decomposition_step( # Define fallback questions fallback_questions = [ { - "sub_question": f"What is {state.main_query}?", + "sub_question": f"What is {main_query}?", "reasoning": "Basic understanding of the topic", }, { - "sub_question": f"What are the key aspects of {state.main_query}?", + "sub_question": f"What are the key aspects of {main_query}?", "reasoning": "Exploring important dimensions", }, { - "sub_question": f"What are the implications of {state.main_query}?", + "sub_question": f"What are the implications of {main_query}?", "reasoning": "Understanding broader impact", }, ] # Use utility function to get structured output decomposed_questions = get_structured_llm_output( - prompt=state.main_query, + prompt=main_query, system_prompt=updated_system_prompt, model=llm_model, fallback_response=fallback_questions, @@ -84,8 +85,10 @@ def initial_query_decomposition_step( for i, question in enumerate(sub_questions, 1): logger.info(f" {i}. 
{question}") - # Update the state with the new sub-questions - state.update_sub_questions(sub_questions) + # Create the QueryContext + query_context = QueryContext( + main_query=main_query, sub_questions=sub_questions + ) # Log step metadata execution_time = time.time() - start_time @@ -97,7 +100,7 @@ def initial_query_decomposition_step( "llm_model": llm_model, "max_sub_questions_requested": max_sub_questions, "fallback_used": False, - "main_query_length": len(state.main_query), + "main_query_length": len(main_query), "sub_questions": sub_questions, } } @@ -113,36 +116,40 @@ def initial_query_decomposition_step( infer_model=True, ) - # Log artifact metadata for the output state + # Log artifact metadata for the output query context log_metadata( metadata={ - "state_characteristics": { - "total_sub_questions": len(state.sub_questions), - "has_search_results": bool(state.search_results), - "has_synthesized_info": bool(state.synthesized_info), + "query_context_characteristics": { + "main_query": main_query, + "num_sub_questions": len(sub_questions), + "timestamp": query_context.decomposition_timestamp, } }, infer_artifact=True, ) # Add tags to the artifact - add_tags(tags=["state", "decomposed"], artifact="updated_state") + # add_tags(tags=["query", "decomposed"], artifact_name="query_context", infer_artifact=True) - return state + return query_context except Exception as e: logger.error(f"Error decomposing query: {e}") - # Return fallback questions in the state + # Return fallback questions fallback_questions = [ - f"What is {state.main_query}?", - f"What are the key aspects of {state.main_query}?", - f"What are the implications of {state.main_query}?", + f"What is {main_query}?", + f"What are the key aspects of {main_query}?", + f"What are the implications of {main_query}?", ] fallback_questions = fallback_questions[:max_sub_questions] logger.info(f"Using {len(fallback_questions)} fallback questions:") for i, question in enumerate(fallback_questions, 1): logger.info(f" {i}. 
{question}") - state.update_sub_questions(fallback_questions) + + # Create QueryContext with fallback questions + query_context = QueryContext( + main_query=main_query, sub_questions=fallback_questions + ) # Log metadata for fallback scenario execution_time = time.time() - start_time @@ -155,7 +162,7 @@ def initial_query_decomposition_step( "max_sub_questions_requested": max_sub_questions, "fallback_used": True, "error_message": str(e), - "main_query_length": len(state.main_query), + "main_query_length": len(main_query), "sub_questions": fallback_questions, } } @@ -172,6 +179,8 @@ def initial_query_decomposition_step( ) # Add tags to the artifact - add_tags(tags=["state", "decomposed"], artifact="updated_state") + # add_tags( + # tags=["query", "decomposed", "fallback"], artifact_name="query_context", infer_artifact=True + # ) - return state + return query_context diff --git a/deep_research/tests/test_approval_utils.py b/deep_research/tests/test_approval_utils.py index fe859b0f..f1dd15a5 100644 --- a/deep_research/tests/test_approval_utils.py +++ b/deep_research/tests/test_approval_utils.py @@ -6,9 +6,7 @@ format_critique_summary, format_query_list, parse_approval_response, - summarize_research_progress, ) -from utils.pydantic_models import ResearchState, SynthesizedInfo def test_parse_approval_responses(): @@ -77,34 +75,6 @@ def test_format_approval_request(): assert "Missing data" in message -def test_summarize_research_progress(): - """Test research progress summarization.""" - state = ResearchState( - main_query="test", - synthesized_info={ - "q1": SynthesizedInfo( - synthesized_answer="a1", confidence_level="high" - ), - "q2": SynthesizedInfo( - synthesized_answer="a2", confidence_level="medium" - ), - "q3": SynthesizedInfo( - synthesized_answer="a3", confidence_level="low" - ), - "q4": SynthesizedInfo( - synthesized_answer="a4", confidence_level="low" - ), - }, - ) - - summary = summarize_research_progress(state) - - assert summary["completed_count"] == 4 - # (1.0 + 0.5 + 0.0 + 0.0) / 4 = 1.5 / 4 = 0.375, rounded to 0.38 - assert summary["avg_confidence"] == 0.38 - assert summary["low_confidence_count"] == 2 - - def test_format_critique_summary(): """Test critique summary formatting.""" # Test with no critiques diff --git a/deep_research/tests/test_artifact_models.py b/deep_research/tests/test_artifact_models.py new file mode 100644 index 00000000..415862b5 --- /dev/null +++ b/deep_research/tests/test_artifact_models.py @@ -0,0 +1,210 @@ +"""Tests for the new artifact models.""" + +import time + +import pytest +from utils.pydantic_models import ( + AnalysisData, + FinalReport, + QueryContext, + ReflectionMetadata, + SearchCostDetail, + SearchData, + SearchResult, + SynthesisData, + SynthesizedInfo, + ViewpointAnalysis, +) + + +class TestQueryContext: + """Test the QueryContext artifact.""" + + def test_query_context_creation(self): + """Test creating a QueryContext.""" + query = QueryContext( + main_query="What is quantum computing?", + sub_questions=["What are qubits?", "How do quantum gates work?"], + ) + + assert query.main_query == "What is quantum computing?" 
+ assert len(query.sub_questions) == 2 + assert query.decomposition_timestamp > 0 + + def test_query_context_immutable(self): + """Test that QueryContext is immutable.""" + query = QueryContext(main_query="Test query", sub_questions=[]) + + # Should raise error when trying to modify + with pytest.raises(Exception): # Pydantic will raise validation error + query.main_query = "Modified query" + + def test_query_context_defaults(self): + """Test QueryContext with defaults.""" + query = QueryContext(main_query="Test") + assert query.sub_questions == [] + assert query.decomposition_timestamp > 0 + + +class TestSearchData: + """Test the SearchData artifact.""" + + def test_search_data_creation(self): + """Test creating SearchData.""" + search_data = SearchData() + + assert search_data.search_results == {} + assert search_data.search_costs == {} + assert search_data.search_cost_details == [] + assert search_data.total_searches == 0 + + def test_search_data_with_results(self): + """Test SearchData with actual results.""" + result = SearchResult( + url="https://example.com", + content="Test content", + title="Test Title", + ) + + cost_detail = SearchCostDetail( + provider="exa", + query="test query", + cost=0.01, + timestamp=time.time(), + step="process_sub_question", + ) + + search_data = SearchData( + search_results={"Question 1": [result]}, + search_costs={"exa": 0.01}, + search_cost_details=[cost_detail], + total_searches=1, + ) + + assert len(search_data.search_results) == 1 + assert search_data.search_costs["exa"] == 0.01 + assert len(search_data.search_cost_details) == 1 + assert search_data.total_searches == 1 + + def test_search_data_merge(self): + """Test merging SearchData instances.""" + # Create first instance + data1 = SearchData( + search_results={ + "Q1": [SearchResult(url="url1", content="content1")] + }, + search_costs={"exa": 0.01}, + total_searches=1, + ) + + # Create second instance + data2 = SearchData( + search_results={ + "Q1": [SearchResult(url="url2", content="content2")], + "Q2": [SearchResult(url="url3", content="content3")], + }, + search_costs={"exa": 0.02, "tavily": 0.01}, + total_searches=2, + ) + + # Merge + data1.merge(data2) + + # Check results + assert len(data1.search_results["Q1"]) == 2 # Merged Q1 results + assert "Q2" in data1.search_results # Added Q2 + assert data1.search_costs["exa"] == 0.03 # Combined costs + assert data1.search_costs["tavily"] == 0.01 # New provider + assert data1.total_searches == 3 + + +class TestSynthesisData: + """Test the SynthesisData artifact.""" + + def test_synthesis_data_creation(self): + """Test creating SynthesisData.""" + synthesis = SynthesisData() + + assert synthesis.synthesized_info == {} + assert synthesis.enhanced_info == {} + + def test_synthesis_data_with_info(self): + """Test SynthesisData with synthesized info.""" + synth_info = SynthesizedInfo( + synthesized_answer="Test answer", + key_sources=["source1", "source2"], + confidence_level="high", + ) + + synthesis = SynthesisData(synthesized_info={"Q1": synth_info}) + + assert "Q1" in synthesis.synthesized_info + assert synthesis.synthesized_info["Q1"].confidence_level == "high" + + def test_synthesis_data_merge(self): + """Test merging SynthesisData instances.""" + info1 = SynthesizedInfo(synthesized_answer="Answer 1") + info2 = SynthesizedInfo(synthesized_answer="Answer 2") + + data1 = SynthesisData(synthesized_info={"Q1": info1}) + data2 = SynthesisData(synthesized_info={"Q2": info2}) + + data1.merge(data2) + + assert "Q1" in data1.synthesized_info + assert "Q2" 
in data1.synthesized_info + + +class TestAnalysisData: + """Test the AnalysisData artifact.""" + + def test_analysis_data_creation(self): + """Test creating AnalysisData.""" + analysis = AnalysisData() + + assert analysis.viewpoint_analysis is None + assert analysis.reflection_metadata is None + + def test_analysis_data_with_viewpoint(self): + """Test AnalysisData with viewpoint analysis.""" + viewpoint = ViewpointAnalysis( + main_points_of_agreement=["Point 1", "Point 2"], + perspective_gaps="Some gaps", + ) + + analysis = AnalysisData(viewpoint_analysis=viewpoint) + + assert analysis.viewpoint_analysis is not None + assert len(analysis.viewpoint_analysis.main_points_of_agreement) == 2 + + def test_analysis_data_with_reflection(self): + """Test AnalysisData with reflection metadata.""" + reflection = ReflectionMetadata( + critique_summary=["Critique 1"], improvements_made=3.0 + ) + + analysis = AnalysisData(reflection_metadata=reflection) + + assert analysis.reflection_metadata is not None + assert analysis.reflection_metadata.improvements_made == 3.0 + + +class TestFinalReport: + """Test the FinalReport artifact.""" + + def test_final_report_creation(self): + """Test creating FinalReport.""" + report = FinalReport() + + assert report.report_html == "" + assert report.generated_at > 0 + assert report.main_query == "" + + def test_final_report_with_content(self): + """Test FinalReport with HTML content.""" + html = "Test Report" + report = FinalReport(report_html=html, main_query="What is AI?") + + assert report.report_html == html + assert report.main_query == "What is AI?" + assert report.generated_at > 0 diff --git a/deep_research/tests/test_pydantic_final_report_step.py b/deep_research/tests/test_pydantic_final_report_step.py index c0f13530..b4dcd956 100644 --- a/deep_research/tests/test_pydantic_final_report_step.py +++ b/deep_research/tests/test_pydantic_final_report_step.py @@ -9,9 +9,14 @@ import pytest from steps.pydantic_final_report_step import pydantic_final_report_step from utils.pydantic_models import ( + AnalysisData, + FinalReport, + Prompt, + QueryContext, ReflectionMetadata, - ResearchState, + SearchData, SearchResult, + SynthesisData, SynthesizedInfo, ViewpointAnalysis, ViewpointTension, @@ -20,15 +25,15 @@ @pytest.fixture -def sample_research_state() -> ResearchState: - """Create a sample research state for testing.""" - # Create a basic research state - state = ResearchState(main_query="What are the impacts of climate change?") - - # Add sub-questions - state.update_sub_questions(["Economic impacts", "Environmental impacts"]) +def sample_artifacts(): + """Create sample artifacts for testing.""" + # Create QueryContext + query_context = QueryContext( + main_query="What are the impacts of climate change?", + sub_questions=["Economic impacts", "Environmental impacts"], + ) - # Add search results + # Create SearchData search_results: Dict[str, List[SearchResult]] = { "Economic impacts": [ SearchResult( @@ -37,11 +42,19 @@ def sample_research_state() -> ResearchState: snippet="Overview of economic impacts", content="Detailed content about economic impacts of climate change", ) - ] + ], + "Environmental impacts": [ + SearchResult( + url="https://example.com/environment", + title="Environmental Impacts", + snippet="Environmental impact overview", + content="Content about environmental impacts", + ) + ], } - state.update_search_results(search_results) + search_data = SearchData(search_results=search_results) - # Add synthesized info + # Create SynthesisData synthesized_info: 
Dict[str, SynthesizedInfo] = { "Economic impacts": SynthesizedInfo( synthesized_answer="Climate change will have significant economic impacts...", @@ -54,12 +67,12 @@ def sample_research_state() -> ResearchState: confidence_level="high", ), } - state.update_synthesized_info(synthesized_info) - - # Add enhanced info (same as synthesized for this test) - state.enhanced_info = state.synthesized_info + synthesis_data = SynthesisData( + synthesized_info=synthesized_info, + enhanced_info=synthesized_info, # Same as synthesized for this test + ) - # Add viewpoint analysis + # Create AnalysisData viewpoint_analysis = ViewpointAnalysis( main_points_of_agreement=[ "Climate change is happening", @@ -77,9 +90,7 @@ def sample_research_state() -> ResearchState: perspective_gaps="Indigenous perspectives are underrepresented", integrative_insights="A balanced approach combining regulations and market incentives may be most effective", ) - state.update_viewpoint_analysis(viewpoint_analysis) - # Add reflection metadata reflection_metadata = ReflectionMetadata( critique_summary=["Need more sources for economic impacts"], additional_questions_identified=[ @@ -89,43 +100,111 @@ def sample_research_state() -> ResearchState: "economic impacts of climate change", "regional climate impacts", ], - improvements_made=2, + improvements_made=2.0, ) - state.reflection_metadata = reflection_metadata - return state + analysis_data = AnalysisData( + viewpoint_analysis=viewpoint_analysis, + reflection_metadata=reflection_metadata, + ) + + # Create prompts + conclusion_prompt = Prompt( + name="conclusion_generation", + content="Generate a conclusion based on the research findings.", + ) + executive_summary_prompt = Prompt( + name="executive_summary", content="Generate an executive summary." + ) + introduction_prompt = Prompt( + name="introduction", content="Generate an introduction." + ) + + return { + "query_context": query_context, + "search_data": search_data, + "synthesis_data": synthesis_data, + "analysis_data": analysis_data, + "conclusion_generation_prompt": conclusion_prompt, + "executive_summary_prompt": executive_summary_prompt, + "introduction_prompt": introduction_prompt, + } def test_pydantic_final_report_step_returns_tuple(): - """Test that the step returns a tuple with state and HTML.""" - # Create a simple state - state = ResearchState(main_query="What is climate change?") - state.update_sub_questions(["What causes climate change?"]) + """Test that the step returns a tuple with FinalReport and HTML.""" + # Create simple artifacts + query_context = QueryContext( + main_query="What is climate change?", + sub_questions=["What causes climate change?"], + ) + search_data = SearchData() + synthesis_data = SynthesisData( + synthesized_info={ + "What causes climate change?": SynthesizedInfo( + synthesized_answer="Climate change is caused by greenhouse gases.", + confidence_level="high", + key_sources=["https://example.com/causes"], + ) + } + ) + analysis_data = AnalysisData() + + # Create prompts + conclusion_prompt = Prompt( + name="conclusion_generation", content="Generate a conclusion." + ) + executive_summary_prompt = Prompt( + name="executive_summary", content="Generate summary." + ) + introduction_prompt = Prompt( + name="introduction", content="Generate intro." 
+ ) # Run the step - result = pydantic_final_report_step(state=state) + result = pydantic_final_report_step( + query_context=query_context, + search_data=search_data, + synthesis_data=synthesis_data, + analysis_data=analysis_data, + conclusion_generation_prompt=conclusion_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + ) # Assert that result is a tuple with 2 elements assert isinstance(result, tuple) assert len(result) == 2 - # Assert first element is ResearchState - assert isinstance(result[0], ResearchState) + # Assert first element is FinalReport + assert isinstance(result[0], FinalReport) # Assert second element is HTMLString assert isinstance(result[1], HTMLString) -def test_pydantic_final_report_step_with_complex_state(sample_research_state): - """Test that the step handles a complex state properly.""" - # Run the step with a complex state - result = pydantic_final_report_step(state=sample_research_state) +def test_pydantic_final_report_step_with_complex_artifacts(sample_artifacts): + """Test that the step handles complex artifacts properly.""" + # Run the step with complex artifacts + result = pydantic_final_report_step( + query_context=sample_artifacts["query_context"], + search_data=sample_artifacts["search_data"], + synthesis_data=sample_artifacts["synthesis_data"], + analysis_data=sample_artifacts["analysis_data"], + conclusion_generation_prompt=sample_artifacts[ + "conclusion_generation_prompt" + ], + executive_summary_prompt=sample_artifacts["executive_summary_prompt"], + introduction_prompt=sample_artifacts["introduction_prompt"], + ) # Unpack the results - updated_state, html_report = result + final_report, html_report = result - # Assert state contains final report HTML - assert updated_state.final_report_html != "" + # Assert FinalReport contains expected data + assert final_report.main_query == "What are the impacts of climate change?" + assert len(final_report.sub_questions) == 2 + assert final_report.report_html != "" # Assert HTML report contains key elements html_str = str(html_report) @@ -136,32 +215,51 @@ def test_pydantic_final_report_step_with_complex_state(sample_research_state): assert "Conservative" in html_str -def test_pydantic_final_report_step_updates_state(): - """Test that the step properly updates the state.""" - # Create an initial state without a final report - state = ResearchState( +def test_pydantic_final_report_step_creates_report(): + """Test that the step properly creates a final report.""" + # Create artifacts + query_context = QueryContext( main_query="What is climate change?", sub_questions=["What causes climate change?"], + ) + search_data = SearchData() + synthesis_data = SynthesisData( synthesized_info={ "What causes climate change?": SynthesizedInfo( synthesized_answer="Climate change is caused by greenhouse gases.", confidence_level="high", + key_sources=["https://example.com/causes"], ) - }, - enhanced_info={ - "What causes climate change?": SynthesizedInfo( - synthesized_answer="Climate change is caused by greenhouse gases.", - confidence_level="high", - ) - }, + } ) + analysis_data = AnalysisData() - # Verify initial state has no report - assert state.final_report_html == "" + # Create prompts + conclusion_prompt = Prompt( + name="conclusion_generation", content="Generate a conclusion." + ) + executive_summary_prompt = Prompt( + name="executive_summary", content="Generate summary." + ) + introduction_prompt = Prompt( + name="introduction", content="Generate intro." 
+ ) # Run the step - updated_state, _ = pydantic_final_report_step(state=state) + final_report, html_report = pydantic_final_report_step( + query_context=query_context, + search_data=search_data, + synthesis_data=synthesis_data, + analysis_data=analysis_data, + conclusion_generation_prompt=conclusion_prompt, + executive_summary_prompt=executive_summary_prompt, + introduction_prompt=introduction_prompt, + ) + + # Verify FinalReport was created with content + assert final_report.report_html != "" + assert "climate change" in final_report.report_html.lower() - # Verify state was updated with a report - assert updated_state.final_report_html != "" - assert "climate change" in updated_state.final_report_html.lower() + # Verify HTML report was created + assert str(html_report) != "" + assert "climate change" in str(html_report).lower() diff --git a/deep_research/tests/test_pydantic_materializer.py b/deep_research/tests/test_pydantic_materializer.py deleted file mode 100644 index 49cb17f8..00000000 --- a/deep_research/tests/test_pydantic_materializer.py +++ /dev/null @@ -1,161 +0,0 @@ -"""Tests for Pydantic-based materializer. - -This module contains tests for the Pydantic-based implementation of -ResearchStateMaterializer, verifying that it correctly serializes and -visualizes ResearchState objects. -""" - -import os -import tempfile -from typing import Dict, List - -import pytest -from materializers.pydantic_materializer import ResearchStateMaterializer -from utils.pydantic_models import ( - ResearchState, - SearchResult, - SynthesizedInfo, - ViewpointAnalysis, - ViewpointTension, -) - - -@pytest.fixture -def sample_state() -> ResearchState: - """Create a sample research state for testing.""" - # Create a basic research state - state = ResearchState(main_query="What are the impacts of climate change?") - - # Add sub-questions - state.update_sub_questions(["Economic impacts", "Environmental impacts"]) - - # Add search results - search_results: Dict[str, List[SearchResult]] = { - "Economic impacts": [ - SearchResult( - url="https://example.com/economy", - title="Economic Impacts of Climate Change", - snippet="Overview of economic impacts", - content="Detailed content about economic impacts of climate change", - ) - ] - } - state.update_search_results(search_results) - - # Add synthesized info - synthesized_info: Dict[str, SynthesizedInfo] = { - "Economic impacts": SynthesizedInfo( - synthesized_answer="Climate change will have significant economic impacts...", - key_sources=["https://example.com/economy"], - confidence_level="high", - ) - } - state.update_synthesized_info(synthesized_info) - - return state - - -def test_materializer_initialization(): - """Test that the materializer can be initialized.""" - # Create a temporary directory for artifact storage - with tempfile.TemporaryDirectory() as tmpdirname: - materializer = ResearchStateMaterializer(uri=tmpdirname) - assert materializer is not None - - -def test_materializer_save_and_load(sample_state: ResearchState): - """Test saving and loading a state using the materializer.""" - # Create a temporary directory for artifact storage - with tempfile.TemporaryDirectory() as tmpdirname: - # Initialize materializer with temporary artifact URI - materializer = ResearchStateMaterializer(uri=tmpdirname) - - # Save the state - materializer.save(sample_state) - - # Load the state - loaded_state = materializer.load(ResearchState) - - # Verify that the loaded state matches the original - assert loaded_state.main_query == sample_state.main_query - assert 
loaded_state.sub_questions == sample_state.sub_questions - assert len(loaded_state.search_results) == len( - sample_state.search_results - ) - assert ( - loaded_state.get_current_stage() - == sample_state.get_current_stage() - ) - - # Check that key fields were preserved - question = "Economic impacts" - assert ( - loaded_state.synthesized_info[question].synthesized_answer - == sample_state.synthesized_info[question].synthesized_answer - ) - assert ( - loaded_state.synthesized_info[question].confidence_level - == sample_state.synthesized_info[question].confidence_level - ) - - -def test_materializer_save_visualizations(sample_state: ResearchState): - """Test generating and saving visualizations.""" - # Create a temporary directory for artifact storage - with tempfile.TemporaryDirectory() as tmpdirname: - # Initialize materializer with temporary artifact URI - materializer = ResearchStateMaterializer(uri=tmpdirname) - - # Generate and save visualizations - viz_paths = materializer.save_visualizations(sample_state) - - # Verify visualization file exists - html_path = list(viz_paths.keys())[0] - assert os.path.exists(html_path) - - # Verify the file has content - with open(html_path, "r") as f: - content = f.read() - # Check for expected elements in the HTML - assert "Research State" in content - assert sample_state.main_query in content - assert "Economic impacts" in content - - -def test_html_generation_stages(sample_state: ResearchState): - """Test that HTML visualization reflects the correct research stage.""" - # Create the materializer - with tempfile.TemporaryDirectory() as tmpdirname: - materializer = ResearchStateMaterializer(uri=tmpdirname) - - # Generate visualization at initial state - html = materializer._generate_visualization_html(sample_state) - # Verify stage by checking for expected elements in the HTML - assert ( - "Synthesized Information" in html - ) # Should show synthesized info - - # Add viewpoint analysis - state_with_viewpoints = sample_state.model_copy(deep=True) - viewpoint_analysis = ViewpointAnalysis( - main_points_of_agreement=["There will be economic impacts"], - areas_of_tension=[ - ViewpointTension( - topic="Job impacts", - viewpoints={ - "Positive": "New green jobs", - "Negative": "Job losses", - }, - ) - ], - ) - state_with_viewpoints.update_viewpoint_analysis(viewpoint_analysis) - html = materializer._generate_visualization_html(state_with_viewpoints) - assert "Viewpoint Analysis" in html - assert "Points of Agreement" in html - - # Add final report - state_with_report = state_with_viewpoints.model_copy(deep=True) - state_with_report.set_final_report("Final report content") - html = materializer._generate_visualization_html(state_with_report) - assert "Final Report" in html diff --git a/deep_research/tests/test_pydantic_models.py b/deep_research/tests/test_pydantic_models.py index b900f8d8..21d25123 100644 --- a/deep_research/tests/test_pydantic_models.py +++ b/deep_research/tests/test_pydantic_models.py @@ -8,11 +8,9 @@ """ import json -from typing import Dict, List from utils.pydantic_models import ( ReflectionMetadata, - ResearchState, SearchResult, SynthesizedInfo, ViewpointAnalysis, @@ -199,105 +197,3 @@ def test_reflection_metadata_model(): new_metadata = ReflectionMetadata.model_validate(metadata_dict) assert new_metadata.improvements_made == metadata.improvements_made assert new_metadata.critique_summary == metadata.critique_summary - - -def test_research_state_model(): - """Test the main ResearchState model.""" - # Create with defaults - state = 
ResearchState() - assert state.main_query == "" - assert state.sub_questions == [] - assert state.search_results == {} - assert state.get_current_stage() == "empty" - - # Set main query - state.main_query = "What are the impacts of climate change?" - assert state.get_current_stage() == "initial" - - # Test update methods - state.update_sub_questions( - ["What are economic impacts?", "What are environmental impacts?"] - ) - assert len(state.sub_questions) == 2 - assert state.get_current_stage() == "after_query_decomposition" - - # Add search results - search_results: Dict[str, List[SearchResult]] = { - "What are economic impacts?": [ - SearchResult( - url="https://example.com/economy", - title="Economic Impacts", - snippet="Overview of economic impacts", - content="Detailed content about economic impacts", - ) - ] - } - state.update_search_results(search_results) - assert state.get_current_stage() == "after_search" - assert len(state.search_results["What are economic impacts?"]) == 1 - - # Add synthesized info - synthesized_info: Dict[str, SynthesizedInfo] = { - "What are economic impacts?": SynthesizedInfo( - synthesized_answer="Economic impacts include job losses and growth opportunities", - key_sources=["https://example.com/economy"], - confidence_level="high", - ) - } - state.update_synthesized_info(synthesized_info) - assert state.get_current_stage() == "after_synthesis" - - # Add viewpoint analysis - analysis = ViewpointAnalysis( - main_points_of_agreement=["Economic changes are happening"], - areas_of_tension=[ - ViewpointTension( - topic="Job impacts", - viewpoints={ - "Positive": "New green jobs", - "Negative": "Fossil fuel job losses", - }, - ) - ], - ) - state.update_viewpoint_analysis(analysis) - assert state.get_current_stage() == "after_viewpoint_analysis" - - # Add reflection results - enhanced_info = { - "What are economic impacts?": SynthesizedInfo( - synthesized_answer="Enhanced answer with more details", - key_sources=[ - "https://example.com/economy", - "https://example.com/new-source", - ], - confidence_level="high", - improvements=["Added more context", "Added more sources"], - ) - } - metadata = ReflectionMetadata( - critique_summary=["Needed more sources"], - improvements_made=2, - ) - state.update_after_reflection(enhanced_info, metadata) - assert state.get_current_stage() == "after_reflection" - - # Set final report - state.set_final_report("Final report content") - assert state.get_current_stage() == "final_report" - assert state.final_report_html == "Final report content" - - # Test serialization and deserialization - state_dict = state.model_dump() - new_state = ResearchState.model_validate(state_dict) - - # Verify key properties were preserved - assert new_state.main_query == state.main_query - assert len(new_state.sub_questions) == len(state.sub_questions) - assert new_state.get_current_stage() == state.get_current_stage() - assert new_state.viewpoint_analysis is not None - assert len(new_state.viewpoint_analysis.areas_of_tension) == 1 - assert ( - new_state.viewpoint_analysis.areas_of_tension[0].topic == "Job impacts" - ) - assert new_state.final_report_html == state.final_report_html diff --git a/deep_research/utils/approval_utils.py b/deep_research/utils/approval_utils.py index 56a8ff91..94cd5a47 100644 --- a/deep_research/utils/approval_utils.py +++ b/deep_research/utils/approval_utils.py @@ -5,28 +5,6 @@ from utils.pydantic_models import ApprovalDecision -def summarize_research_progress(state) -> Dict[str, Any]: - """Summarize the current research 
progress.""" - completed_count = len(state.synthesized_info) - confidence_levels = [ - info.confidence_level for info in state.synthesized_info.values() - ] - - # Calculate average confidence (high=1.0, medium=0.5, low=0.0) - confidence_map = {"high": 1.0, "medium": 0.5, "low": 0.0} - avg_confidence = sum( - confidence_map.get(c, 0.5) for c in confidence_levels - ) / max(len(confidence_levels), 1) - - low_confidence_count = sum(1 for c in confidence_levels if c == "low") - - return { - "completed_count": completed_count, - "avg_confidence": round(avg_confidence, 2), - "low_confidence_count": low_confidence_count, - } - - def format_critique_summary(critique_points: List[Dict[str, Any]]) -> str: """Format critique points for display.""" if not critique_points: diff --git a/deep_research/utils/pydantic_models.py b/deep_research/utils/pydantic_models.py index 02ecc3c0..822afe99 100644 --- a/deep_research/utils/pydantic_models.py +++ b/deep_research/utils/pydantic_models.py @@ -247,21 +247,6 @@ def set_final_report(self, html: str) -> None: self.final_report_html = html -class ReflectionOutput(BaseModel): - """Output from the reflection generation step.""" - - state: ResearchState - recommended_queries: List[str] = Field(default_factory=list) - critique_summary: List[Dict[str, Any]] = Field(default_factory=list) - additional_questions: List[str] = Field(default_factory=list) - - model_config = { - "extra": "ignore", - "frozen": False, - "validate_assignment": True, - } - - class ApprovalDecision(BaseModel): """Approval decision from human reviewer.""" @@ -360,3 +345,159 @@ class TracingMetadata(BaseModel): "frozen": False, "validate_assignment": True, } + + +# ============================================================================ +# New Artifact Classes for ResearchState Refactoring +# ============================================================================ + + +class QueryContext(BaseModel): + """Immutable context containing the research query and its decomposition. + + This artifact is created once at the beginning of the pipeline and + remains unchanged throughout execution. + """ + + main_query: str = Field( + ..., description="The main research question from the user" + ) + sub_questions: List[str] = Field( + default_factory=list, + description="Decomposed sub-questions for parallel processing", + ) + decomposition_timestamp: float = Field( + default_factory=lambda: time.time(), + description="When the query was decomposed", + ) + + model_config = { + "extra": "ignore", + "frozen": True, # Make immutable after creation + "validate_assignment": True, + } + + +class SearchCostDetail(BaseModel): + """Detailed information about a single search operation.""" + + provider: str + query: str + cost: float + timestamp: float + step: str + sub_question: Optional[str] = None + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class SearchData(BaseModel): + """Accumulates search results and cost tracking throughout the pipeline. + + This artifact grows as searches are performed and can be merged + when parallel searches complete. 
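# Illustrative sketch (not part of the patch): how a fan-in step might merge
# the SearchData artifacts returned by parallel sub-question branches. Field
# names match the SearchData model defined below; the provider name "tavily"
# and the example queries/URLs are assumptions used only for illustration.
from utils.pydantic_models import SearchData, SearchResult

branch_a = SearchData(
    search_results={
        "What are economic impacts?": [
            SearchResult(
                url="https://example.com/economy",
                title="Economic Impacts",
                snippet="Overview of economic impacts",
                content="Detailed content about economic impacts",
            )
        ]
    },
    search_costs={"tavily": 0.01},
    total_searches=1,
)
branch_b = SearchData(
    search_results={"What are environmental impacts?": []},
    search_costs={"tavily": 0.02},
    total_searches=1,
)

# merge() updates branch_a in place and returns it, accumulating results,
# per-provider costs, and the total search count across branches.
merged = branch_a.merge(branch_b)
assert merged.total_searches == 2
assert "What are environmental impacts?" in merged.search_results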
+ """ + + search_results: Dict[str, List[SearchResult]] = Field( + default_factory=dict, + description="Map of sub-question to search results", + ) + search_costs: Dict[str, float] = Field( + default_factory=dict, description="Total costs by provider" + ) + search_cost_details: List[SearchCostDetail] = Field( + default_factory=list, description="Detailed log of each search" + ) + total_searches: int = Field( + default=0, description="Total number of searches performed" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def merge(self, other: "SearchData") -> "SearchData": + """Merge another SearchData instance into this one.""" + # Merge search results + for sub_q, results in other.search_results.items(): + if sub_q in self.search_results: + self.search_results[sub_q].extend(results) + else: + self.search_results[sub_q] = results + + # Merge costs + for provider, cost in other.search_costs.items(): + self.search_costs[provider] = ( + self.search_costs.get(provider, 0.0) + cost + ) + + # Merge cost details + self.search_cost_details.extend(other.search_cost_details) + + # Update total searches + self.total_searches += other.total_searches + + return self + + +class SynthesisData(BaseModel): + """Contains synthesized information for all sub-questions.""" + + synthesized_info: Dict[str, SynthesizedInfo] = Field( + default_factory=dict, + description="Synthesized answers for each sub-question", + ) + enhanced_info: Dict[str, SynthesizedInfo] = Field( + default_factory=dict, + description="Enhanced information after reflection", + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + def merge(self, other: "SynthesisData") -> "SynthesisData": + """Merge another SynthesisData instance into this one.""" + self.synthesized_info.update(other.synthesized_info) + self.enhanced_info.update(other.enhanced_info) + return self + + +class AnalysisData(BaseModel): + """Contains viewpoint analysis and reflection metadata.""" + + viewpoint_analysis: Optional[ViewpointAnalysis] = None + reflection_metadata: Optional[ReflectionMetadata] = None + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } + + +class FinalReport(BaseModel): + """Contains the final HTML report.""" + + report_html: str = Field(default="", description="The final HTML report") + generated_at: float = Field( + default_factory=lambda: time.time(), + description="Timestamp when report was generated", + ) + main_query: str = Field( + default="", description="The original research query" + ) + + model_config = { + "extra": "ignore", + "frozen": False, + "validate_assignment": True, + } From c9d11e7025c25c8556f645bf2e8fd370a90b83d0 Mon Sep 17 00:00:00 2001 From: Alex Strick van Linschoten Date: Wed, 28 May 2025 09:29:31 +0200 Subject: [PATCH 07/11] Refactor CSS to shared styles and utility functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Created shared CSS file at assets/styles.css consolidating all common styles - Added CSS utility functions in utils/css_utils.py for consistent styling - Updated all materializers to use shared CSS instead of inline styles - Updated report template in prompts.py to use shared CSS - Fixed unused 'question' variable in approval_decision_materializer.py - Reduced code duplication across all visualization components This refactoring improves maintainability by centralizing style definitions and provides consistent visual 
styling across all materializers. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- deep_research/assets/styles.css | 672 +++++++++++++++ .../analysis_data_materializer.py | 217 +---- .../approval_decision_materializer.py | 184 ++--- .../final_report_materializer.py | 98 +-- .../materializers/prompt_materializer.py | 228 +---- .../query_context_materializer.py | 143 ++-- .../materializers/search_data_materializer.py | 275 ++----- .../synthesis_data_materializer.py | 287 ++----- .../tracing_metadata_materializer.py | 344 +++----- .../steps/pydantic_final_report_step.py | 2 + deep_research/utils/css_utils.py | 221 +++++ deep_research/utils/prompts.py | 779 +++++++++++++----- 12 files changed, 1938 insertions(+), 1512 deletions(-) create mode 100644 deep_research/assets/styles.css create mode 100644 deep_research/utils/css_utils.py diff --git a/deep_research/assets/styles.css b/deep_research/assets/styles.css new file mode 100644 index 00000000..130a0c3b --- /dev/null +++ b/deep_research/assets/styles.css @@ -0,0 +1,672 @@ +/* =================================== + Deep Research Pipeline Global Styles + =================================== */ + +/* 1. CSS Variables / Custom Properties */ +:root { + /* Color Palette */ + --color-primary: #3498db; + --color-primary-dark: #2980b9; + --color-secondary: #667eea; + --color-secondary-dark: #5a63d8; + --color-accent: #764ba2; + + /* Status Colors */ + --color-success: #27ae60; + --color-success-light: #d4edda; + --color-success-dark: #155724; + --color-warning: #f39c12; + --color-warning-light: #fff3cd; + --color-warning-dark: #856404; + --color-danger: #e74c3c; + --color-danger-light: #f8d7da; + --color-danger-dark: #721c24; + --color-info: #17a2b8; + --color-info-light: #d1ecf1; + --color-info-dark: #0c5460; + + /* Chart Colors */ + --color-chart-1: #5e72e4; + --color-chart-2: #2dce89; + --color-chart-3: #11cdef; + --color-chart-4: #f5365c; + --color-chart-5: #fb6340; + --color-chart-6: #ffd600; + + /* Neutrals */ + --color-text-primary: #333; + --color-text-secondary: #666; + --color-text-muted: #999; + --color-text-light: #7f8c8d; + --color-heading: #2c3e50; + --color-bg-primary: #f5f7fa; + --color-bg-secondary: #f8f9fa; + --color-bg-light: #f0f2f5; + --color-bg-white: #ffffff; + --color-border: #e9ecef; + --color-border-light: #dee2e6; + --color-border-dark: #ddd; + + /* Typography */ + --font-family-base: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; + --font-family-mono: Consolas, Monaco, 'Andale Mono', 'Courier New', monospace; + + /* Spacing */ + --spacing-xs: 5px; + --spacing-sm: 10px; + --spacing-md: 20px; + --spacing-lg: 30px; + --spacing-xl: 40px; + --spacing-xxl: 60px; + + /* Border Radius */ + --radius-sm: 4px; + --radius-md: 8px; + --radius-lg: 15px; + --radius-xl: 20px; + --radius-round: 50%; + + /* Shadows */ + --shadow-sm: 0 2px 4px rgba(0, 0, 0, 0.1); + --shadow-md: 0 5px 20px rgba(0, 0, 0, 0.08); + --shadow-lg: 0 10px 30px rgba(0, 0, 0, 0.1); + --shadow-xl: 0 20px 60px rgba(0, 0, 0, 0.3); + --shadow-hover: 0 8px 25px rgba(0, 0, 0, 0.12); + --shadow-hover-lg: 0 5px 15px rgba(0, 0, 0, 0.3); + + /* Transitions */ + --transition-base: all 0.3s ease; + --transition-fast: all 0.2s ease; +} + +/* 2. Base Styles */ +* { + box-sizing: border-box; +} + +body { + font-family: var(--font-family-base); + line-height: 1.6; + color: var(--color-text-primary); + background-color: var(--color-bg-primary); + margin: 0; + padding: var(--spacing-md); +} + +/* 3. 
Layout Components */ +.dr-container { + max-width: 1200px; + margin: 0 auto; + padding: var(--spacing-md); +} + +.dr-container--wide { + max-width: 1400px; +} + +.dr-container--narrow { + max-width: 900px; +} + +/* 4. Typography */ +.dr-h1, h1 { + color: var(--color-heading); + font-size: 2rem; + margin: 0 0 var(--spacing-md) 0; + padding-bottom: var(--spacing-sm); + border-bottom: 2px solid var(--color-primary); +} + +.dr-h1--no-border { + border-bottom: none; + padding-bottom: 0; +} + +.dr-h2, h2 { + color: var(--color-heading); + font-size: 1.5rem; + margin-top: var(--spacing-lg); + margin-bottom: var(--spacing-md); + border-bottom: 1px solid var(--color-border); + padding-bottom: var(--spacing-xs); +} + +.dr-h3, h3 { + color: var(--color-primary); + font-size: 1.2rem; + margin-top: var(--spacing-md); + margin-bottom: var(--spacing-sm); +} + +p { + margin: 15px 0; + line-height: 1.8; + color: var(--color-text-secondary); +} + +/* 5. Card Components */ +.dr-card { + background: var(--color-bg-white); + border-radius: var(--radius-lg); + padding: var(--spacing-lg); + box-shadow: var(--shadow-md); + margin-bottom: var(--spacing-lg); + transition: var(--transition-base); +} + +.dr-card:hover { + transform: translateY(-3px); + box-shadow: var(--shadow-hover); +} + +.dr-card--bordered { + border: 1px solid var(--color-border); +} + +.dr-card--no-hover:hover { + transform: none; + box-shadow: var(--shadow-md); +} + +/* Header Cards */ +.dr-header-card { + background: white; + border-radius: var(--radius-lg); + padding: var(--spacing-lg); + box-shadow: var(--shadow-md); + margin-bottom: var(--spacing-lg); +} + +/* 6. Grid System */ +.dr-grid { + display: grid; + gap: var(--spacing-md); +} + +.dr-grid--stats { + grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); +} + +.dr-grid--cards { + grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); +} + +.dr-grid--metrics { + grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); +} + +/* 7. Badges & Tags */ +.dr-badge { + display: inline-block; + padding: var(--spacing-xs) calc(var(--spacing-sm) * 1.5); + border-radius: 20px; + font-size: 0.875rem; + font-weight: bold; + text-transform: uppercase; + letter-spacing: 0.5px; +} + +.dr-badge--success { + background-color: var(--color-success-light); + color: var(--color-success-dark); +} + +.dr-badge--warning { + background-color: var(--color-warning-light); + color: var(--color-warning-dark); +} + +.dr-badge--danger { + background-color: var(--color-danger-light); + color: var(--color-danger-dark); +} + +.dr-badge--info { + background-color: var(--color-info-light); + color: var(--color-info-dark); +} + +.dr-badge--primary { + background-color: var(--color-primary); + color: white; +} + +/* Tag variations */ +.dr-tag { + display: inline-block; + background-color: #f0f0f0; + color: #555; + padding: 3px 12px; + border-radius: 15px; + font-size: 0.75rem; + font-weight: 500; + margin: 2px; +} + +.dr-tag--primary { + background-color: #e1f5fe; + color: #0277bd; +} + +/* 8. 
Stat Cards */ +.dr-stat-card { + background: var(--color-bg-secondary); + border-radius: var(--radius-md); + padding: var(--spacing-md) 25px; + text-align: center; + transition: var(--transition-base); + border: 2px solid var(--color-border); +} + +.dr-stat-card:hover { + transform: translateY(-5px); + box-shadow: var(--shadow-hover); +} + +.dr-stat-value { + font-size: 2.25rem; + font-weight: bold; + color: var(--color-chart-1); + margin-bottom: var(--spacing-xs); + display: block; +} + +.dr-stat-label { + color: var(--color-text-secondary); + font-size: 0.875rem; + text-transform: uppercase; + letter-spacing: 0.5px; + display: block; +} + +/* 9. Sections */ +.dr-section { + background: var(--color-bg-white); + border-radius: var(--radius-lg); + padding: var(--spacing-lg); + margin-bottom: var(--spacing-lg); + box-shadow: var(--shadow-md); +} + +.dr-section--bordered { + border-left: 4px solid var(--color-primary); +} + +.dr-section--info { + background-color: #e8f4f8; + border-left: 4px solid var(--color-primary); +} + +.dr-section--warning { + background-color: var(--color-warning-light); + border-left: 4px solid var(--color-warning); +} + +.dr-section--success { + background-color: var(--color-success-light); + border-left: 4px solid var(--color-success); +} + +.dr-section--danger { + background-color: var(--color-danger-light); + border-left: 4px solid var(--color-danger); +} + +/* 10. Tables */ +.dr-table { + width: 100%; + border-collapse: collapse; + margin: var(--spacing-md) 0; + background: var(--color-bg-white); + overflow: hidden; +} + +.dr-table th { + background-color: var(--color-primary); + color: white; + padding: var(--spacing-sm); + text-align: left; + font-weight: 600; +} + +.dr-table td { + padding: var(--spacing-sm); + border-bottom: 1px solid var(--color-border); +} + +.dr-table tr:last-child td { + border-bottom: none; +} + +.dr-table tr:hover { + background-color: var(--color-bg-secondary); +} + +.dr-table--striped tr:nth-child(even) { + background-color: #f2f2f2; +} + +/* 11. Buttons */ +.dr-button { + background: var(--color-primary); + color: white; + border: none; + padding: var(--spacing-sm) var(--spacing-md); + border-radius: var(--radius-md); + font-size: 1rem; + font-weight: 500; + cursor: pointer; + transition: var(--transition-base); + display: inline-flex; + align-items: center; + gap: var(--spacing-xs); + text-decoration: none; +} + +.dr-button:hover { + background: var(--color-primary-dark); + transform: translateY(-2px); + box-shadow: 0 5px 15px rgba(52, 152, 219, 0.3); +} + +.dr-button--secondary { + background: var(--color-secondary); +} + +.dr-button--secondary:hover { + background: var(--color-secondary-dark); + box-shadow: 0 5px 15px rgba(102, 126, 234, 0.3); +} + +.dr-button--success { + background: var(--color-success); +} + +.dr-button--small { + padding: 6px 15px; + font-size: 0.875rem; +} + +/* 12. Confidence Indicators */ +.dr-confidence { + display: inline-flex; + align-items: center; + padding: 6px 15px; + border-radius: 30px; + font-weight: bold; + gap: var(--spacing-xs); + box-shadow: var(--shadow-sm); +} + +.dr-confidence--high { + background: linear-gradient(to right, #d4edda, #c3e6cb); + color: var(--color-success-dark); +} + +.dr-confidence--medium { + background: linear-gradient(to right, #fff3cd, #ffeeba); + color: var(--color-warning-dark); +} + +.dr-confidence--low { + background: linear-gradient(to right, #f8d7da, #f5c6cb); + color: var(--color-danger-dark); +} + +/* 13. 
Chart Containers */ +.dr-chart-container { + position: relative; + height: 300px; + margin: var(--spacing-md) 0; +} + +/* 14. Code Blocks */ +.dr-code { + background-color: #f7f7f7; + border: 1px solid #e1e1e8; + border-radius: var(--radius-sm); + padding: var(--spacing-sm); + font-family: var(--font-family-mono); + overflow-x: auto; + white-space: pre-wrap; + word-wrap: break-word; +} + +/* 15. Lists */ +.dr-list { + margin: var(--spacing-sm) 0; + padding-left: 25px; +} + +.dr-list li { + margin: 8px 0; + line-height: 1.6; +} + +.dr-list--unstyled { + list-style-type: none; + padding-left: 0; +} + +/* 16. Notice Boxes */ +.dr-notice { + padding: 15px; + margin: 20px 0; + border-radius: var(--radius-sm); +} + +.dr-notice--info { + background-color: #e8f4f8; + border-left: 4px solid var(--color-primary); + color: var(--color-info-dark); +} + +.dr-notice--warning { + background-color: var(--color-warning-light); + border-left: 4px solid var(--color-warning); + color: var(--color-warning-dark); +} + +/* 17. Loading States */ +.dr-loading { + text-align: center; + padding: var(--spacing-xxl); + color: var(--color-text-secondary); + font-style: italic; +} + +/* 18. Empty States */ +.dr-empty { + text-align: center; + color: var(--color-text-muted); + font-style: italic; + padding: var(--spacing-xl); + background: var(--color-bg-white); + border-radius: var(--radius-lg); + box-shadow: var(--shadow-md); +} + +/* 19. Utility Classes */ +.dr-text-center { text-align: center; } +.dr-text-right { text-align: right; } +.dr-text-left { text-align: left; } +.dr-text-muted { color: var(--color-text-muted); } +.dr-text-secondary { color: var(--color-text-secondary); } +.dr-text-primary { color: var(--color-text-primary); } + +/* Margin utilities */ +.dr-mt-xs { margin-top: var(--spacing-xs); } +.dr-mt-sm { margin-top: var(--spacing-sm); } +.dr-mt-md { margin-top: var(--spacing-md); } +.dr-mt-lg { margin-top: var(--spacing-lg); } +.dr-mt-xl { margin-top: var(--spacing-xl); } + +.dr-mb-xs { margin-bottom: var(--spacing-xs); } +.dr-mb-sm { margin-bottom: var(--spacing-sm); } +.dr-mb-md { margin-bottom: var(--spacing-md); } +.dr-mb-lg { margin-bottom: var(--spacing-lg); } +.dr-mb-xl { margin-bottom: var(--spacing-xl); } + +/* Padding utilities */ +.dr-p-sm { padding: var(--spacing-sm); } +.dr-p-md { padding: var(--spacing-md); } +.dr-p-lg { padding: var(--spacing-lg); } + +/* Display utilities */ +.dr-d-none { display: none; } +.dr-d-block { display: block; } +.dr-d-flex { display: flex; } +.dr-d-grid { display: grid; } + +/* Flex utilities */ +.dr-flex-center { + display: flex; + align-items: center; + justify-content: center; +} + +.dr-flex-between { + display: flex; + align-items: center; + justify-content: space-between; +} + +/* 20. 
Special Components */ + +/* Mind Map Styles */ +.dr-mind-map { + position: relative; + margin: var(--spacing-xl) 0; +} + +.dr-mind-map-node { + background: linear-gradient(135deg, var(--color-secondary) 0%, var(--color-accent) 100%); + color: white; + padding: var(--spacing-lg); + border-radius: var(--radius-lg); + text-align: center; + font-size: 1.25rem; + font-weight: bold; + box-shadow: 0 10px 30px rgba(102, 126, 234, 0.3); + margin-bottom: var(--spacing-xl); +} + +/* Result Cards */ +.dr-result-item { + background: var(--color-bg-secondary); + border-radius: var(--radius-md); + padding: 15px; + margin-bottom: 15px; + border: 1px solid var(--color-border); + transition: var(--transition-base); +} + +.dr-result-item:hover { + box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1); + transform: translateY(-2px); +} + +.dr-result-title { + font-weight: bold; + color: var(--color-heading); + margin-bottom: 8px; +} + +.dr-result-snippet { + color: var(--color-text-secondary); + font-size: 0.875rem; + line-height: 1.6; + margin-bottom: 10px; +} + +.dr-result-link { + color: var(--color-chart-1); + text-decoration: none; + font-size: 0.875rem; + font-weight: 500; +} + +.dr-result-link:hover { + text-decoration: underline; +} + +/* Timestamp */ +.dr-timestamp { + text-align: right; + color: var(--color-text-light); + font-size: 0.875rem; + margin-top: var(--spacing-md); + padding-top: 15px; + border-top: 1px dashed var(--color-border-dark); +} + +/* 21. Gradients */ +.dr-gradient-primary { + background: linear-gradient(135deg, var(--color-secondary) 0%, var(--color-accent) 100%); +} + +.dr-gradient-header { + background: linear-gradient(90deg, #3498db, #2ecc71, #f1c40f, #e74c3c); + height: 5px; +} + +/* 22. Responsive Design */ +@media (max-width: 768px) { + body { + padding: var(--spacing-sm); + } + + .dr-container { + padding: var(--spacing-sm); + } + + .dr-grid--stats, + .dr-grid--cards, + .dr-grid--metrics { + grid-template-columns: 1fr; + } + + .dr-h1, h1 { + font-size: 1.5rem; + } + + .dr-h2, h2 { + font-size: 1.25rem; + } + + .dr-stat-value { + font-size: 2rem; + } + + .dr-section, + .dr-card { + padding: var(--spacing-md); + } + + .dr-table { + font-size: 0.875rem; + } + + .dr-table th, + .dr-table td { + padding: 8px; + } +} + +/* 23. Print Styles */ +@media print { + body { + background: white; + color: black; + } + + .dr-card, + .dr-section { + box-shadow: none; + border: 1px solid #ddd; + } + + .dr-button { + display: none; + } +} \ No newline at end of file diff --git a/deep_research/materializers/analysis_data_materializer.py b/deep_research/materializers/analysis_data_materializer.py index 79053a79..8b78aa91 100644 --- a/deep_research/materializers/analysis_data_materializer.py +++ b/deep_research/materializers/analysis_data_materializer.py @@ -3,6 +3,11 @@ import os from typing import Dict +from utils.css_utils import ( + get_card_class, + get_section_class, + get_shared_css_tag, +) from utils.pydantic_models import AnalysisData from zenml.enums import ArtifactType, VisualizationType from zenml.io import fileio @@ -51,7 +56,7 @@ def _generate_visualization_html(self, data: AnalysisData) -> str: # Points of agreement agreement_html = "" if va.main_points_of_agreement: - agreement_html = "

    Main Points of Agreement

      " + agreement_html = "

      Main Points of Agreement

        " for point in va.main_points_of_agreement: agreement_html += f"
      • {point}
      • " agreement_html += "
      " @@ -60,13 +65,13 @@ def _generate_visualization_html(self, data: AnalysisData) -> str: tensions_html = "" if va.areas_of_tension: tensions_html = ( - "

      Areas of Tension

      " + "

      Areas of Tension

      " ) for tension in va.areas_of_tension: viewpoints_html = "" for perspective, view in tension.viewpoints.items(): viewpoints_html += f""" -
      +
      {perspective}
      {view}
      @@ -86,7 +91,7 @@ def _generate_visualization_html(self, data: AnalysisData) -> str: gaps_html = "" if va.perspective_gaps: gaps_html = f""" -
      +

      Perspective Gaps

      {va.perspective_gaps}

      @@ -96,14 +101,14 @@ def _generate_visualization_html(self, data: AnalysisData) -> str: insights_html = "" if va.integrative_insights: insights_html = f""" -
      +

      Integrative Insights

      {va.integrative_insights}

      """ viewpoint_html = f""" -
      +

      Viewpoint Analysis

      {agreement_html} {tensions_html} @@ -120,7 +125,7 @@ def _generate_visualization_html(self, data: AnalysisData) -> str: # Critique summary critique_html = "" if rm.critique_summary: - critique_html = "

      Critique Summary

        " + critique_html = "

        Critique Summary

          " for critique in rm.critique_summary: critique_html += f"
        • {critique}
        • " critique_html += "
        " @@ -128,7 +133,7 @@ def _generate_visualization_html(self, data: AnalysisData) -> str: # Additional questions questions_html = "" if rm.additional_questions_identified: - questions_html = "

        Additional Questions Identified

          " + questions_html = "" @@ -136,7 +141,7 @@ def _generate_visualization_html(self, data: AnalysisData) -> str: # Searches performed searches_html = "" if rm.searches_performed: - searches_html = "

          Searches Performed

            " + searches_html = "

            Searches Performed

              " for search in rm.searches_performed: searches_html += f"
            • {search}
            • " searches_html += "
            " @@ -145,18 +150,18 @@ def _generate_visualization_html(self, data: AnalysisData) -> str: error_html = "" if rm.error: error_html = f""" -
            +

            Error Encountered

            {rm.error}

            """ reflection_html = f""" -
            +

            Reflection Metadata

            -
            - {int(rm.improvements_made)} - Improvements Made +
            + {int(rm.improvements_made)} + Improvements Made
            {critique_html} {questions_html} @@ -168,7 +173,7 @@ def _generate_visualization_html(self, data: AnalysisData) -> str: # Handle empty state if not viewpoint_html and not reflection_html: content_html = ( - '
            No analysis data available yet
            ' + '
            No analysis data available yet
            ' ) else: content_html = viewpoint_html + reflection_html @@ -178,79 +183,23 @@ def _generate_visualization_html(self, data: AnalysisData) -> str: Analysis Data Visualization + {get_shared_css_tag()} -
            -
            +
            +

            Research Analysis
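The materializer diffs that follow import helper functions from utils/css_utils.py, which this excerpt does not include. The sketch below is illustrative rather than the actual implementation: only the helper names, the get_shared_css_tag() and create_stat_card(value, label) call shapes, and the dr-* class names from assets/styles.css come from the patch, while the function bodies, path handling, and remaining signatures are assumptions.

import os


def get_shared_css_tag() -> str:
    # Inline the shared stylesheet so each visualization stays self-contained.
    # Path is an assumption: utils/ and assets/ are siblings per the diffstat.
    css_path = os.path.join(
        os.path.dirname(__file__), "..", "assets", "styles.css"
    )
    with open(css_path, "r") as f:
        return f"<style>\n{f.read()}\n</style>"


def get_card_class(modifier: str = "") -> str:
    # e.g. get_card_class("bordered") -> "dr-card dr-card--bordered"
    return f"dr-card dr-card--{modifier}" if modifier else "dr-card"


def get_section_class(variant: str = "") -> str:
    # e.g. get_section_class("info") -> "dr-section dr-section--info"
    return f"dr-section dr-section--{variant}" if variant else "dr-section"


def get_grid_class(layout: str = "cards") -> str:
    # e.g. get_grid_class("stats") -> "dr-grid dr-grid--stats"
    return f"dr-grid dr-grid--{layout}"


def create_stat_card(value, label: str) -> str:
    # Render a single stat card using the shared dr-stat-* classes.
    return (
        '<div class="dr-stat-card">'
        f'<span class="dr-stat-value">{value}</span>'
        f'<span class="dr-stat-label">{label}</span>'
        "</div>"
    )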

            diff --git a/deep_research/materializers/approval_decision_materializer.py b/deep_research/materializers/approval_decision_materializer.py index 0a4ed2c3..e17b9d9b 100644 --- a/deep_research/materializers/approval_decision_materializer.py +++ b/deep_research/materializers/approval_decision_materializer.py @@ -4,6 +4,12 @@ from datetime import datetime from typing import Dict +from utils.css_utils import ( + get_card_class, + get_grid_class, + get_section_class, + get_shared_css_tag, +) from utils.pydantic_models import ApprovalDecision from zenml.enums import ArtifactType, VisualizationType from zenml.io import fileio @@ -54,13 +60,11 @@ def _generate_visualization_html(self, decision: ApprovalDecision) -> str: "%Y-%m-%d %H:%M:%S" ) - # Determine status color and icon + # Determine status icon and text if decision.approved: - status_color = "#27ae60" status_icon = "✅" status_text = "APPROVED" else: - status_color = "#e74c3c" status_icon = "❌" status_text = "NOT APPROVED" @@ -71,93 +75,76 @@ def _generate_visualization_html(self, decision: ApprovalDecision) -> str: "SELECT_SPECIFIC": "Select Specific Queries", }.get(decision.approval_method, decision.approval_method or "Unknown") + # Build info cards + info_cards_html = f""" +
            +
            +
            Approval Method
            +
            {method_display}
            +
            +
            +
            Decision Time
            +
            {decision_time}
            +
            +
            +
            Queries Selected
            +
            {len(decision.selected_queries)}
            +
            +
            + """ + html = f""" Approval Decision + {get_shared_css_tag()} -
            -

            - 🔒 Approval Decision -
            - {status_icon} - {status_text} -
            -

            - -
            -
            -
            Approval Method
            -
            {method_display}
            -
            -
            -
            Decision Time
            -
            {decision_time}
            -
            -
            -
            Queries Selected
            -
            {len(decision.selected_queries)}
            -
            -
            +
            +
            +

            + 🔒 Approval Decision +
            + {status_icon} + {status_text} +
            +

            + + {info_cards_html} """ # Add selected queries section if any if decision.selected_queries: - html += """ -
            + html += f""" +

            📋Selected Queries

            """ @@ -248,10 +207,10 @@ def _generate_visualization_html(self, decision: ApprovalDecision) -> str:
            """ else: - html += """ -
            + html += f""" +

            📋Selected Queries

            -
            +
            No queries were selected for additional research
            @@ -260,7 +219,7 @@ def _generate_visualization_html(self, decision: ApprovalDecision) -> str: # Add reviewer notes if any if decision.reviewer_notes: html += f""" -
            +

            📝Reviewer Notes

            {decision.reviewer_notes} @@ -270,8 +229,9 @@ def _generate_visualization_html(self, decision: ApprovalDecision) -> str: # Add timestamp footer html += f""" -
            - Decision recorded at: {decision_time} +
            + Decision recorded at: {decision_time} +
            diff --git a/deep_research/materializers/final_report_materializer.py b/deep_research/materializers/final_report_materializer.py index 3eb848ba..960e4afc 100644 --- a/deep_research/materializers/final_report_materializer.py +++ b/deep_research/materializers/final_report_materializer.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import Dict +from utils.css_utils import get_shared_css_tag from utils.pydantic_models import FinalReport from zenml.enums import ArtifactType, VisualizationType from zenml.io import fileio @@ -68,74 +69,64 @@ def _generate_visualization_html(self, data: FinalReport) -> str: Final Research Report - {data.main_query[:50]}... + {get_shared_css_tag()} -
            -
            +
            +

            Final Research Report

            @@ -227,15 +193,15 @@ def _generate_visualization_html(self, data: FinalReport) -> str:
            - + Open in New Tab -
            -
            +
            Loading report...
            diff --git a/deep_research/materializers/prompt_materializer.py b/deep_research/materializers/prompt_materializer.py index 835c53a6..4e306419 100644 --- a/deep_research/materializers/prompt_materializer.py +++ b/deep_research/materializers/prompt_materializer.py @@ -7,6 +7,12 @@ import os from typing import Dict +from utils.css_utils import ( + create_stat_card, + get_card_class, + get_grid_class, + get_shared_css_tag, +) from utils.pydantic_models import Prompt from zenml.enums import ArtifactType, VisualizationType from zenml.io import fileio @@ -52,66 +58,45 @@ def _generate_visualization_html(self, prompt: Prompt) -> str: Returns: HTML string """ - # Determine tag colors + # Create tags HTML tag_html = "" if prompt.tags: - tag_colors = { - "search": "search", - "synthesis": "synthesis", - "analysis": "analysis", - "reflection": "reflection", - "report": "report", - "query": "query", - "decomposition": "decomposition", - "viewpoint": "viewpoint", - "conclusion": "conclusion", - "summary": "summary", - "introduction": "introduction", - } - tag_html = '
            ' for tag in prompt.tags: - tag_class = tag_colors.get(tag, "default") - tag_html += f'{tag}' + tag_html += ( + f'{tag}' + ) tag_html += "
            " + # Build stats HTML + stats_html = f""" +
            + {create_stat_card(len(prompt.content.split()), "Words")} + {create_stat_card(len(prompt.content), "Characters")} + {create_stat_card(len(prompt.content.splitlines()), "Lines")} +
            + """ + # Create HTML content html = f""" {prompt.name} - Prompt + {get_shared_css_tag()} -
            +
            -

            +

            🎯 {prompt.name} v{prompt.version}

            @@ -305,25 +180,12 @@ def _generate_visualization_html(self, prompt: Prompt) -> str: {tag_html}
            -
            -
            - {len(prompt.content.split())} - Words -
            -
            - {len(prompt.content)} - Characters -
            -
            - {len(prompt.content.splitlines())} - Lines -
            -
            + {stats_html}
            -

            📝 Prompt Content

            +

            📝Prompt Content

            -