diff --git a/.env.example b/.env.example
index 8558e9c..43e38c9 100644
--- a/.env.example
+++ b/.env.example
@@ -46,9 +46,14 @@ MAX_WORKERS=30
# API Keys and External Services
# =============================================================================
-# Serper API for web search and Google Scholar
-# Get your key from: https://serper.dev/
-SERPER_KEY_ID=your_key
+# Exa.ai API for semantic web search
+# Get your key from: https://exa.ai/
+# Exa provides AI-native neural search with:
+# - Semantic understanding (not just keyword matching)
+# - Built-in query optimization
+# - Direct content retrieval
+# - Better results for complex research queries
+EXA_API_KEY=your_key
# Jina API for web page reading
# Get your key from: https://jina.ai/
@@ -57,8 +62,8 @@ JINA_API_KEYS=your_key
# Summary model API (OpenAI-compatible) for page summarization
# Get your key from: https://platform.openai.com/
API_KEY=your_key
-API_BASE=your_api_base
-SUMMARY_MODEL_NAME=your_summary_model_name
+API_BASE=https://api.openai.com/v1
+SUMMARY_MODEL_NAME=gpt-4o-mini
# Dashscope API for file parsing (PDF, Office, etc.)
# Get your key from: https://dashscope.aliyun.com/
@@ -95,4 +100,24 @@ IDP_KEY_SECRET=your_idp_key_secret
# These are typically set by distributed training frameworks
# WORLD_SIZE=1
-# RANK=0
\ No newline at end of file
+# RANK=0
+
+# =============================================================================
+# MLX Configuration (Apple Silicon Only)
+# =============================================================================
+# For running on Apple Silicon Macs (M1/M2/M3/M4) using MLX framework
+# instead of CUDA/vLLM. Uses mlx-lm for efficient local inference.
+#
+# Requirements:
+# pip install mlx-lm
+#
+# Recommended models:
+# - abalogh/Tongyi-DeepResearch-30B-A3B-4bit (17GB, fits 32GB RAM)
+# - Original BF16 model requires 62GB+
+#
+# Usage:
+# bash inference/run_mlx_infer.sh
+#
+# MLX_MODEL=abalogh/Tongyi-DeepResearch-30B-A3B-4bit
+# MLX_HOST=127.0.0.1
+# MLX_PORT=8080
\ No newline at end of file
diff --git a/inference/interactive.py b/inference/interactive.py
new file mode 100644
index 0000000..38dd560
--- /dev/null
+++ b/inference/interactive.py
@@ -0,0 +1,218 @@
+#!/usr/bin/env python3
+"""
+Interactive CLI for DeepResearch on Apple Silicon (MLX)
+
+Usage:
+ python interactive.py [--model MODEL_PATH]
+
+Example:
+ python interactive.py
+ python interactive.py --model abalogh/Tongyi-DeepResearch-30B-A3B-4bit
+"""
+
+import argparse
+import json
+import os
+import sys
+import time
+
+# Load environment variables first
+from dotenv import load_dotenv
+load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env"))
+
+# Disable tokenizer parallelism warning
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+# Optional: rich for better formatting
+try:
+ from rich.console import Console
+ from rich.markdown import Markdown
+ from rich.panel import Panel
+ from rich.progress import Progress, SpinnerColumn, TextColumn
+ RICH_AVAILABLE = True
+ console = Console()
+except ImportError:
+ RICH_AVAILABLE = False
+ console = None
+
+
+def print_header():
+ """Print welcome header."""
+ header = """
+╔══════════════════════════════════════════════════════════════╗
+║ DeepResearch - Interactive Mode (MLX) ║
+║ Apple Silicon Optimized ║
+╚══════════════════════════════════════════════════════════════╝
+"""
+ if RICH_AVAILABLE:
+ console.print(header, style="bold blue")
+ else:
+ print(header)
+
+
+def print_help():
+ """Print help information."""
+ help_text = """
+Commands:
+ /help - Show this help message
+  /quit    - Exit the program (also /exit, /q, or Ctrl+D)
+ /clear - Clear conversation history (start fresh)
+ /status - Show model and memory status
+
+Just type your research question to begin!
+
+Examples:
+ > What is the current population of Tokyo?
+ > Who won the 2024 Nobel Prize in Physics?
+ > Explain the mechanism of CRISPR-Cas9 gene editing
+"""
+ if RICH_AVAILABLE:
+ console.print(Panel(help_text, title="Help", border_style="green"))
+ else:
+ print(help_text)
+
+
+def format_answer(answer: str):
+ """Format the answer for display."""
+ if RICH_AVAILABLE:
+ console.print("\n")
+ console.print(Panel(Markdown(answer), title="[bold green]Answer[/]", border_style="green"))
+ else:
+ print("\n" + "=" * 60)
+ print("ANSWER:")
+ print("=" * 60)
+ print(answer)
+ print("=" * 60)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Interactive DeepResearch CLI")
+ parser.add_argument("--model", type=str,
+ default="abalogh/Tongyi-DeepResearch-30B-A3B-4bit",
+ help="Model path or HuggingFace ID")
+ parser.add_argument("--temperature", type=float, default=0.7,
+ help="Sampling temperature")
+ parser.add_argument("--max_tokens", type=int, default=4096,
+ help="Max tokens per generation")
+ parser.add_argument("--max_rounds", type=int, default=15,
+ help="Max research rounds per question")
+ args = parser.parse_args()
+
+ print_header()
+
+ # Set max rounds via environment
+ os.environ['MAX_LLM_CALL_PER_RUN'] = str(args.max_rounds)
+
+ # Import agent after setting environment
+ print("Loading model (this may take a minute)...")
+
+ try:
+ from run_mlx_react import MLXReactAgent, TOOL_MAP
+ except ImportError as e:
+ print(f"Error importing agent: {e}")
+ print("Make sure you're running from the inference directory.")
+ return 1
+
+ if RICH_AVAILABLE:
+ with Progress(
+ SpinnerColumn(),
+ TextColumn("[progress.description]{task.description}"),
+ console=console
+ ) as progress:
+ progress.add_task("Loading MLX model...", total=None)
+ agent = MLXReactAgent(
+ model_path=args.model,
+ temperature=args.temperature,
+ max_tokens=args.max_tokens
+ )
+ else:
+ agent = MLXReactAgent(
+ model_path=args.model,
+ temperature=args.temperature,
+ max_tokens=args.max_tokens
+ )
+
+ print(f"\nTools available: {list(TOOL_MAP.keys())}")
+ print(f"Max rounds per question: {args.max_rounds}")
+ print_help()
+
+ while True:
+ try:
+ # Get user input
+ if RICH_AVAILABLE:
+ query = console.input("\n[bold cyan]Research Query>[/] ").strip()
+ else:
+ query = input("\nResearch Query> ").strip()
+
+ # Handle commands
+ if not query:
+ continue
+
+ if query.lower() in ('/quit', '/exit', '/q'):
+ print("Goodbye!")
+ break
+
+ if query.lower() == '/help':
+ print_help()
+ continue
+
+ if query.lower() == '/clear':
+ print("Ready for a new question.")
+ continue
+
+ if query.lower() == '/status':
+ try:
+ import mlx.core as mx
+ # Use new API (mlx >= 0.24) or fall back to deprecated
+ if hasattr(mx, 'get_active_memory'):
+ mem_gb = mx.get_active_memory() / (1024**3)
+ else:
+ mem_gb = mx.metal.get_active_memory() / (1024**3)
+ print(f"Model: {args.model}")
+ print(f"GPU Memory: {mem_gb:.1f} GB")
+ except Exception:
+ print(f"Model: {args.model}")
+ continue
+
+ if query.startswith('/'):
+ print(f"Unknown command: {query}. Type /help for available commands.")
+ continue
+
+ # Run research
+ print("\nResearching...\n")
+ start = time.time()
+
+ data = {'item': {'question': query, 'answer': ''}}
+ result = agent.run(data)
+
+ elapsed = time.time() - start
+
+ # Display result
+ prediction = result.get('prediction', 'No answer found.')
+ termination = result.get('termination', 'unknown')
+ num_rounds = len([m for m in result.get('messages', []) if m.get('role') == 'assistant'])
+
+ format_answer(prediction)
+
+ if RICH_AVAILABLE:
+ console.print(f"[dim]Completed in {elapsed:.1f}s | {num_rounds} rounds | Termination: {termination}[/]")
+ else:
+ print(f"\nCompleted in {elapsed:.1f}s | {num_rounds} rounds | Termination: {termination}")
+
+ except KeyboardInterrupt:
+ print("\n\nInterrupted. Type /quit to exit or continue with a new question.")
+ continue
+ except EOFError:
+ print("\nGoodbye!")
+ break
+ except Exception as e:
+ print(f"\nError: {e}")
+ import traceback
+ traceback.print_exc()
+ continue
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/inference/prompt.py b/inference/prompt.py
index 649e1e0..d63988b 100644
--- a/inference/prompt.py
+++ b/inference/prompt.py
@@ -1,4 +1,19 @@
-SYSTEM_PROMPT = """You are a deep research assistant. Your core function is to conduct thorough, multi-source investigations into any topic. You must handle both broad, open-domain inquiries and queries within specialized academic fields. For every request, synthesize information from credible, diverse sources to deliver a comprehensive, accurate, and objective response. When you have gathered sufficient information and are ready to provide the definitive response, you must enclose the entire final answer within tags.
+SYSTEM_PROMPT = """You are a deep research assistant. Your core function is to conduct thorough, multi-source investigations into any topic. You must handle both broad, open-domain inquiries and queries within specialized academic fields. For every request, synthesize information from credible, diverse sources to deliver a comprehensive, accurate, and objective response.
+
+# CRITICAL: Answer Behavior
+
+**You MUST provide a final answer after gathering sufficient information.** Do not continue researching indefinitely.
+
+Guidelines for when to provide your answer:
+1. After 2-3 search queries that return relevant results, you likely have enough information
+2. If multiple sources agree on key facts, you have sufficient confirmation
+3. If a webpage visit fails, use the search snippets you already have
+4. A good answer with available information is better than endless searching
+5. When uncertain, provide the best answer you can with appropriate caveats
+
+**When ready to answer, use this format:**
+<think>Final reasoning about the gathered information</think>
+<answer>Your comprehensive answer here</answer>
# Tools
@@ -6,25 +21,11 @@
You are provided with function signatures within <tools></tools> XML tags:
-{"type": "function", "function": {"name": "search", "description": "Perform Google web searches then returns a string of the top search results. Accepts multiple queries.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries."}}, "required": ["query"]}}}
-{"type": "function", "function": {"name": "visit", "description": "Visit webpage(s) and return the summary of the content.", "parameters": {"type": "object", "properties": {"url": {"type": "array", "items": {"type": "string"}, "description": "The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs."}, "goal": {"type": "string", "description": "The specific information goal for visiting webpage(s)."}}, "required": ["url", "goal"]}}}
-{"type": "function", "function": {"name": "PythonInterpreter", "description": "Executes Python code in a sandboxed environment. To use this tool, you must follow this format:
-1. The 'arguments' JSON object must be empty: {}.
-2. The Python code to be executed must be placed immediately after the JSON block, enclosed within <code> and </code> tags.
-
-IMPORTANT: Any output you want to see MUST be printed to standard output using the print() function.
-
-Example of a correct call:
-<tool_call>
-{"name": "PythonInterpreter", "arguments": {}}
-<code>
-import numpy as np
-# Your code here
-print(f"The result is: {np.mean([1,2,3])}")
-</code>
-", "parameters": {"type": "object", "properties": {}, "required": []}}}
-{"type": "function", "function": {"name": "google_scholar", "description": "Leverage Google Scholar to retrieve relevant information from academic publications. Accepts multiple queries. This tool will also return results from google search", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries for Google Scholar."}}, "required": ["query"]}}}
-{"type": "function", "function": {"name": "parse_file", "description": "This is a tool that can be used to parse multiple user uploaded local files such as PDF, DOCX, PPTX, TXT, CSV, XLSX, DOC, ZIP, MP4, MP3.", "parameters": {"type": "object", "properties": {"files": {"type": "array", "items": {"type": "string"}, "description": "The file name of the user uploaded local files to be parsed."}}, "required": ["files"]}}}
+{"type": "function", "function": {"name": "search", "description": "Perform web searches and return top results with snippets. Use this first to find relevant sources.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string"}, "minItems": 1, "description": "Search queries (1-3 queries recommended)."}}, "required": ["query"]}}}
+{"type": "function", "function": {"name": "visit", "description": "Visit webpage(s) to extract detailed content. Only visit if search snippets are insufficient.", "parameters": {"type": "object", "properties": {"url": {"type": "array", "items": {"type": "string"}, "description": "URL(s) to visit."}, "goal": {"type": "string", "description": "What specific information you need from the page."}}, "required": ["url", "goal"]}}}
+{"type": "function", "function": {"name": "google_scholar", "description": "Search academic publications. Use for scientific/research questions.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string"}, "minItems": 1, "description": "Academic search queries."}}, "required": ["query"]}}}
+{"type": "function", "function": {"name": "PythonInterpreter", "description": "Execute Python code for calculations or data processing.", "parameters": {"type": "object", "properties": {}, "required": []}}}
+{"type": "function", "function": {"name": "parse_file", "description": "Parse uploaded files (PDF, DOCX, etc.).", "parameters": {"type": "object", "properties": {"files": {"type": "array", "items": {"type": "string"}, "description": "File names to parse."}}, "required": ["files"]}}}
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
diff --git a/inference/run_mlx_infer.sh b/inference/run_mlx_infer.sh
new file mode 100755
index 0000000..dd45fd1
--- /dev/null
+++ b/inference/run_mlx_infer.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+# MLX Inference Script for Apple Silicon (M1/M2/M3/M4)
+# This script runs DeepResearch using Apple's MLX framework instead of vLLM/CUDA
+#
+# Uses native MLX Python API (no separate server needed)
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PROJECT_ROOT="$SCRIPT_DIR/.."
+ENV_FILE="$PROJECT_ROOT/.env"
+VENV_PATH="$PROJECT_ROOT/venv"
+
+# Activate virtual environment if it exists
+if [ -d "$VENV_PATH" ]; then
+ echo "Activating virtual environment..."
+ source "$VENV_PATH/bin/activate"
+else
+ echo "Warning: No venv found at $VENV_PATH"
+ echo "Create one with: python3 -m venv $VENV_PATH"
+ echo "Then install: pip install mlx-lm python-dotenv requests json5 tqdm qwen-agent"
+ exit 1
+fi
+
+# Load environment variables
+if [ ! -f "$ENV_FILE" ]; then
+ echo "Error: .env file not found at $ENV_FILE"
+ echo "Please copy .env.example to .env and configure your settings"
+ exit 1
+fi
+
+echo "Loading environment variables..."
+set -a
+source "$ENV_FILE"
+set +a
+
+# MLX-specific configuration
+MLX_MODEL="${MLX_MODEL:-abalogh/Tongyi-DeepResearch-30B-A3B-4bit}"
+
+# Default inference parameters
+TEMPERATURE="${TEMPERATURE:-0.85}"
+MAX_TOKENS="${MAX_TOKENS:-8192}"
+TOP_P="${TOP_P:-0.95}"
+
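+# Example override (illustrative values only):
+#   TEMPERATURE=0.7 MAX_TOKENS=4096 DATASET=./eval_data/my_questions.jsonl \
+#     bash inference/run_mlx_infer.sh
+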
+echo "============================================"
+echo "DeepResearch MLX Inference (Apple Silicon)"
+echo "============================================"
+echo "Model: $MLX_MODEL"
+echo "Temperature: $TEMPERATURE"
+echo "Top-P: $TOP_P"
+echo "Max Tokens: $MAX_TOKENS"
+echo "============================================"
+
+# Check if mlx-lm is installed
+python -c "import mlx_lm" 2>/dev/null || {
+ echo "Error: mlx-lm not installed. Install with: pip install mlx-lm"
+ exit 1
+}
+
+# Disable tokenizer parallelism warning
+export TOKENIZERS_PARALLELISM=false
+
+# Run inference using native MLX API (no server needed)
+cd "$SCRIPT_DIR"
+
+python -u run_mlx_react.py \
+ --dataset "${DATASET:-$PROJECT_ROOT/eval_data/sample_questions.jsonl}" \
+ --output "${OUTPUT_PATH:-./outputs}" \
+ --model "$MLX_MODEL" \
+ --temperature "$TEMPERATURE" \
+ --top_p "$TOP_P" \
+ --max_tokens "$MAX_TOKENS" \
+ --roll_out_count "${ROLLOUT_COUNT:-1}"
+
+echo "Inference complete!"
diff --git a/inference/run_mlx_react.py b/inference/run_mlx_react.py
new file mode 100644
index 0000000..c5bbbac
--- /dev/null
+++ b/inference/run_mlx_react.py
@@ -0,0 +1,636 @@
+"""
+MLX React Agent Runner for Apple Silicon
+
+This script runs DeepResearch using Apple's MLX framework instead of vLLM/CUDA.
+Uses native MLX Python API with proper chat template handling.
+
+Requirements:
+ pip install mlx-lm python-dotenv requests json5 tqdm qwen-agent
+
+Usage:
+ python run_mlx_react.py --dataset eval_data/test.jsonl --output ./outputs
+"""
+
+import argparse
+import json
+import os
+import signal
+import sys
+import time
+import threading
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+# Load environment variables before other imports
+from dotenv import load_dotenv
+load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env"))
+
+from tqdm import tqdm
+import json5
+from mlx_lm import load, generate
+from mlx_lm.sample_utils import make_sampler
+
+from prompt import SYSTEM_PROMPT
+
+# Disable tokenizer parallelism warning
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+# Tool registry - import tools with fallbacks
+TOOL_MAP: Dict[str, Any] = {}
+
+try:
+ from tool_search import Search
+ TOOL_MAP['search'] = Search()
+except ImportError as e:
+ print(f"Warning: Could not import Search tool: {e}")
+
+try:
+ from tool_visit import Visit
+ TOOL_MAP['visit'] = Visit()
+except ImportError as e:
+ print(f"Warning: Could not import Visit tool: {e}")
+
+try:
+ from tool_scholar import Scholar
+ TOOL_MAP['google_scholar'] = Scholar()
+except ImportError as e:
+ print(f"Warning: Could not import Scholar tool: {e}")
+
+try:
+ from tool_file import FileParser
+ TOOL_MAP['parse_file'] = FileParser()
+except ImportError as e:
+ print(f"Warning: Could not import FileParser tool: {e}")
+
+try:
+ from tool_python import PythonInterpreter
+ TOOL_MAP['PythonInterpreter'] = PythonInterpreter()
+except Exception as e:  # may fail at import or sandbox init time
+ print(f"Warning: Could not import PythonInterpreter tool: {e}")
+
+print(f"Loaded tools: {list(TOOL_MAP.keys())}")
+
+MAX_LLM_CALL_PER_RUN = int(os.getenv('MAX_LLM_CALL_PER_RUN', 100))
+
+# Graceful shutdown flag
+shutdown_requested = False
+
+
+def signal_handler(signum, frame):
+ """Handle interrupt signals gracefully."""
+ global shutdown_requested
+ if shutdown_requested:
+ print("\nForce quit...")
+ sys.exit(1)
+ shutdown_requested = True
+ print("\nShutdown requested. Finishing current task...")
+
+
+signal.signal(signal.SIGINT, signal_handler)
+signal.signal(signal.SIGTERM, signal_handler)
+
+
+def today_date() -> str:
+ return datetime.now().strftime("%Y-%m-%d")
+
+
+class MLXReactAgent:
+ """
+ React agent using native MLX Python API for inference on Apple Silicon.
+
+ Uses the model's built-in chat template for proper formatting.
+ """
+
+ def __init__(self, model_path: str, temperature: float = 0.85,
+ top_p: float = 0.95, max_tokens: int = 8192):
+ self.model_path = model_path
+ self.temperature = temperature
+ self.top_p = top_p
+ self.max_tokens = max_tokens
+
+ print(f"Loading model: {model_path}")
+ self.model, self.tokenizer = load(model_path)
+ print(f"Model loaded successfully (memory: {self._get_memory_usage():.1f} GB)")
+
+ def _get_memory_usage(self) -> float:
+ """Get current GPU memory usage in GB."""
+ try:
+ import mlx.core as mx
+ # Use new API (mlx >= 0.24) or fall back to deprecated
+ if hasattr(mx, 'get_active_memory'):
+ return mx.get_active_memory() / (1024**3)
+ return mx.metal.get_active_memory() / (1024**3)
+ except Exception:
+ return 0.0
+
+ def build_prompt(self, messages: List[Dict[str, str]]) -> str:
+ """
+ Build prompt using tokenizer's chat template.
+ Falls back to manual Qwen format if template unavailable.
+ """
+ # Try using tokenizer's built-in chat template
+ if hasattr(self.tokenizer, 'apply_chat_template'):
+ try:
+ prompt = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+ return prompt
+ except Exception as e:
+ print(f"Warning: apply_chat_template failed, using manual format: {e}")
+
+ # Fallback: Manual Qwen/ChatML format
+ prompt_parts = []
+ for msg in messages:
+ role = msg["role"]
+ content = msg["content"]
+ prompt_parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
+ prompt_parts.append("<|im_start|>assistant\n")
+ return "\n".join(prompt_parts)
+
+ def count_tokens(self, messages: List[Dict[str, str]]) -> int:
+ """Count tokens using the actual tokenizer."""
+ prompt = self.build_prompt(messages)
+ tokens = self.tokenizer.encode(prompt)
+ return len(tokens)
+
+ def generate_response(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str:
+ """Generate response using native MLX API."""
+ prompt = self.build_prompt(messages)
+ tokens = max_tokens or self.max_tokens
+
+ sampler = make_sampler(temp=self.temperature, top_p=self.top_p)
+
+ response = generate(
+ self.model,
+ self.tokenizer,
+ prompt=prompt,
+ max_tokens=tokens,
+ sampler=sampler,
+ verbose=False,
+ )
+
+ # Clean up response - remove trailing tokens
+ if "<|im_end|>" in response:
+ response = response.split("<|im_end|>")[0]
+ if "" in response:
+ response = response.split("")[0]
+
+ return response.strip()
+
+ def execute_tool(self, tool_name: str, tool_args: Dict[str, Any], timeout: int = 120) -> str:
+ """Execute a tool with timeout protection."""
+ if tool_name not in TOOL_MAP:
+ return f"Error: Tool '{tool_name}' not found. Available: {list(TOOL_MAP.keys())}"
+
+ # Copy args to avoid mutation
+ args = dict(tool_args)
+ result = ""
+ error = None
+
+ def run_tool():
+ nonlocal result, error
+ try:
+ if "python" in tool_name.lower():
+ result = str(TOOL_MAP['PythonInterpreter'].call(args))
+ elif tool_name == "parse_file":
+ import asyncio
+ params = {"files": args.get("files", [])}
+ r = asyncio.run(TOOL_MAP[tool_name].call(params, file_root_path="./eval_data/file_corpus"))
+ result = str(r) if not isinstance(r, str) else r
+ else:
+ result = str(TOOL_MAP[tool_name].call(args))
+ except Exception as e:
+ error = str(e)
+
+ thread = threading.Thread(target=run_tool)
+ thread.start()
+ thread.join(timeout=timeout)
+
+ if thread.is_alive():
+ return f"Error: Tool '{tool_name}' timed out after {timeout}s"
+
+ if error:
+ return f"Error executing tool '{tool_name}': {error}"
+
+ return result
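+    # Usage sketch (hypothetical arguments):
+    #   execute_tool("search", {"query": ["mlx quantization"]})
+    # returns the tool's formatted output, or an "Error: ..." string when the
+    # tool is missing, raises, or exceeds the timeout. On timeout the worker
+    # thread is abandoned, not killed, since Python threads cannot be forced
+    # to terminate.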
+
+ def run(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Run the react agent loop for a single question.
+
+ Args:
+ data: Dict with 'item' containing 'question' and optionally 'answer'
+
+ Returns:
+ Dict with question, answer, messages, prediction, and termination status
+ """
+ global shutdown_requested
+
+ # Extract question
+ item = data['item']
+ question = item.get('question', '')
+ if not question:
+ try:
+ raw_msg = item['messages'][1]["content"]
+ question = raw_msg.split("User:")[1].strip() if "User:" in raw_msg else raw_msg
+ except Exception as e:
+ print(f"Failed to extract question: {e}")
+ return {"question": "", "error": "Could not extract question"}
+
+ answer = item.get('answer', '')
+ start_time = time.time()
+
+ # Build initial messages
+ system_prompt = SYSTEM_PROMPT + str(today_date())
+ messages: List[Dict[str, str]] = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": question}
+ ]
+
+ num_calls_remaining = MAX_LLM_CALL_PER_RUN
+ round_num = 0
+ max_context_tokens = 100 * 1024 # 100K tokens (conservative for 128K model)
+ timeout_minutes = 120 # 2 hours
+ consecutive_errors = 0
+ last_tool_call = "" # For loop detection
+
+ while num_calls_remaining > 0:
+ # Check for shutdown
+ if shutdown_requested:
+ return {
+ "question": question,
+ "answer": answer,
+ "messages": messages,
+ "prediction": "Interrupted by user",
+ "termination": "interrupted"
+ }
+
+ # Check timeout
+ elapsed = time.time() - start_time
+ if elapsed > timeout_minutes * 60:
+ return {
+ "question": question,
+ "answer": answer,
+ "messages": messages,
+ "prediction": "No answer found after timeout",
+ "termination": "timeout"
+ }
+
+ round_num += 1
+ num_calls_remaining -= 1
+
+ print(f"--- Round {round_num} (calls left: {num_calls_remaining}) ---")
+
+ # Inject reminder at round 5 to encourage conclusion
+ if round_num == 5:
+ messages.append({
+ "role": "user",
+ "content": "REMINDER: You have made several searches. If you have enough information to answer the question, please provide your final answer now using tags. Only continue searching if absolutely necessary."
+ })
+
+ # Generate response
+ content = self.generate_response(messages)
+
+ preview = content[:200] + "..." if len(content) > 200 else content
+ print(f"Response: {preview}")
+
+ messages.append({"role": "assistant", "content": content})
+
+ # Check for tool calls
+            if '<tool_call>' in content and '</tool_call>' in content:
+                tool_call_str = content.split('<tool_call>')[1].split('</tool_call>')[0]
+
+ # Loop detection: check if same tool call as last time
+ if tool_call_str.strip() == last_tool_call:
+ print("Warning: Detected repeated tool call, forcing answer...")
+ messages.append({
+ "role": "user",
+ "content": "You are repeating the same action. Stop and provide your final answer NOW based on available information.\nyour answer"
+ })
+ content = self.generate_response(messages, max_tokens=2048)
+ messages.append({"role": "assistant", "content": content})
+                    if '<answer>' in content and '</answer>' in content:
+                        prediction = content.split('<answer>')[1].split('</answer>')[0]
+ else:
+ prediction = content
+ return {
+ "question": question,
+ "answer": answer,
+ "messages": messages,
+ "prediction": prediction.strip(),
+ "termination": "loop_detected"
+ }
+
+ last_tool_call = tool_call_str.strip()
+
+ try:
+ # Handle Python interpreter specially
+ if "python" in tool_call_str.lower() and "" in content:
+ code = content.split('')[1].split('')[0].strip()
+ result = self.execute_tool('PythonInterpreter', {"code": code})
+ else:
+ tool_call = json5.loads(tool_call_str.strip())
+ tool_name = tool_call.get('name', '')
+ tool_args = tool_call.get('arguments', {})
+ print(f"Tool: {tool_name} | Args: {json.dumps(tool_args)[:100]}...")
+ result = self.execute_tool(tool_name, tool_args)
+            except ValueError as e:  # json5 parse errors subclass ValueError
+ result = f'Error: Invalid JSON in tool call. {e}'
+ except Exception as e:
+ result = f'Error: Tool call failed. {e}'
+
+ # Track consecutive errors
+ if result.startswith('Error:'):
+ consecutive_errors += 1
+ if consecutive_errors >= 3:
+ print(f"Warning: {consecutive_errors} consecutive errors, forcing answer...")
+ messages.append({
+ "role": "user",
+ "content": f"Multiple tool errors occurred. Please provide your best answer based on the information you have gathered so far.\nyour answer"
+ })
+ content = self.generate_response(messages, max_tokens=2048)
+ messages.append({"role": "assistant", "content": content})
+                    if '<answer>' in content and '</answer>' in content:
+                        prediction = content.split('<answer>')[1].split('</answer>')[0]
+ else:
+ prediction = content
+ return {
+ "question": question,
+ "answer": answer,
+ "messages": messages,
+ "prediction": prediction.strip(),
+ "termination": "consecutive_errors"
+ }
+ else:
+ consecutive_errors = 0 # Reset on success
+
+ result_preview = result[:200] + "..." if len(result) > 200 else result
+ print(f"Result: {result_preview}")
+
+ tool_response = f"\n{result}\n"
+ messages.append({"role": "user", "content": tool_response})
+
+ # Check for final answer
+            if '<answer>' in content and '</answer>' in content:
+                prediction = content.split('<answer>')[1].split('</answer>')[0]
+ elapsed_mins = (time.time() - start_time) / 60
+ print(f"Answer found in {elapsed_mins:.1f} minutes")
+ return {
+ "question": question,
+ "answer": answer,
+ "messages": messages,
+ "prediction": prediction.strip(),
+ "termination": "answer"
+ }
+
+ # Check token limit
+ token_count = self.count_tokens(messages)
+ print(f"Tokens: {token_count:,}")
+
+ if token_count > max_context_tokens:
+ print(f"Token limit exceeded: {token_count:,} > {max_context_tokens:,}")
+
+ # Force final answer
+ messages.append({
+ "role": "user",
+ "content": "IMPORTANT: You have reached the maximum context length. "
+ "Stop making tool calls. Provide your final answer NOW based on all information above.\n"
+                           "Format: <think>final reasoning</think>\n<answer>your answer</answer>"
+ })
+
+ content = self.generate_response(messages, max_tokens=2048)
+ messages.append({"role": "assistant", "content": content})
+
+            if '<answer>' in content and '</answer>' in content:
+                prediction = content.split('<answer>')[1].split('</answer>')[0]
+ termination = "token_limit_answer"
+ else:
+ prediction = content
+ termination = "token_limit_no_answer"
+
+ return {
+ "question": question,
+ "answer": answer,
+ "messages": messages,
+ "prediction": prediction.strip(),
+ "termination": termination
+ }
+
+ # Max calls reached - try to get final answer
+ print("Max LLM calls reached, requesting final answer...")
+ messages.append({
+ "role": "user",
+ "content": "Maximum iterations reached. Provide your final answer NOW.\n"
+                   "<answer>your answer</answer>"
+ })
+
+ content = self.generate_response(messages, max_tokens=2048)
+ messages.append({"role": "assistant", "content": content})
+
+    if '<answer>' in content and '</answer>' in content:
+        prediction = content.split('<answer>')[1].split('</answer>')[0]
+ termination = "max_calls_answer"
+ else:
+ prediction = content if content else "No answer found."
+ termination = "max_calls_no_answer"
+
+ return {
+ "question": question,
+ "answer": answer,
+ "messages": messages,
+ "prediction": prediction.strip(),
+ "termination": termination
+ }
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Run DeepResearch with MLX on Apple Silicon",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument("--model", type=str, default="abalogh/Tongyi-DeepResearch-30B-A3B-4bit",
+ help="Model path or HuggingFace model ID")
+ parser.add_argument("--dataset", type=str, required=True,
+ help="Path to input dataset (JSON or JSONL)")
+ parser.add_argument("--output", type=str, default="./outputs",
+ help="Output directory")
+ parser.add_argument("--temperature", type=float, default=0.85,
+ help="Sampling temperature (0.0-2.0)")
+ parser.add_argument("--top_p", type=float, default=0.95,
+ help="Top-p (nucleus) sampling (0.0-1.0)")
+ parser.add_argument("--max_tokens", type=int, default=8192,
+ help="Maximum tokens per generation")
+ parser.add_argument("--roll_out_count", type=int, default=1,
+ help="Number of rollouts per question")
+ args = parser.parse_args()
+
+ # Validate args
+ if not 0.0 <= args.temperature <= 2.0:
+ print("Warning: temperature should be between 0.0 and 2.0")
+ if not 0.0 <= args.top_p <= 1.0:
+ print("Warning: top_p should be between 0.0 and 1.0")
+
+ # Setup output directory
+ model_name = os.path.basename(args.model.rstrip('/'))
+ model_dir = os.path.join(args.output, f"{model_name}_mlx")
+ dataset_name = os.path.splitext(os.path.basename(args.dataset))[0]
+ output_dir = os.path.join(model_dir, dataset_name)
+ os.makedirs(output_dir, exist_ok=True)
+
+ print("=" * 60)
+ print("DeepResearch MLX Inference (Apple Silicon)")
+ print("=" * 60)
+ print(f"Model: {args.model}")
+ print(f"Dataset: {args.dataset}")
+ print(f"Output: {output_dir}")
+ print(f"Temperature: {args.temperature}")
+ print(f"Top-P: {args.top_p}")
+ print(f"Max Tokens: {args.max_tokens}")
+ print(f"Rollouts: {args.roll_out_count}")
+ print("=" * 60)
+
+ # Load dataset
+ try:
+ if args.dataset.endswith(".json"):
+ with open(args.dataset, "r", encoding="utf-8") as f:
+ items = json.load(f)
+ if isinstance(items, dict):
+ items = [items]
+ elif args.dataset.endswith(".jsonl"):
+ with open(args.dataset, "r", encoding="utf-8") as f:
+ items = [json.loads(line) for line in f if line.strip()]
+ else:
+ print("Error: Dataset must be .json or .jsonl")
+ return 1
+ except FileNotFoundError:
+ print(f"Error: Dataset not found at {args.dataset}")
+ return 1
+ except json.JSONDecodeError as e:
+ print(f"Error: Invalid JSON in dataset: {e}")
+ return 1
+
+ print(f"Loaded {len(items)} items from dataset")
+
+ if not items:
+ print("Error: No items in dataset")
+ return 1
+
+ # Initialize agent
+ try:
+ agent = MLXReactAgent(
+ model_path=args.model,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ max_tokens=args.max_tokens,
+ )
+ except Exception as e:
+ print(f"Error loading model: {e}")
+ return 1
+
+ # Setup output files per rollout
+ output_files = {
+ i: os.path.join(output_dir, f"iter{i}.jsonl")
+ for i in range(1, args.roll_out_count + 1)
+ }
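+    # e.g. (illustrative): ./outputs/Tongyi-DeepResearch-30B-A3B-4bit_mlx/<dataset>/iter1.jsonl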
+
+ # Load already processed questions
+ processed_per_rollout: Dict[int, set] = {}
+ for rollout_idx in range(1, args.roll_out_count + 1):
+ processed: set = set()
+ output_file = output_files[rollout_idx]
+ if os.path.exists(output_file):
+ with open(output_file, "r", encoding="utf-8") as f:
+ for line in f:
+ try:
+ data = json.loads(line)
+ if "question" in data and "error" not in data:
+ processed.add(data["question"].strip())
+ except json.JSONDecodeError:
+ pass
+ processed_per_rollout[rollout_idx] = processed
+ if processed:
+ print(f"Rollout {rollout_idx}: {len(processed)} already processed")
+
+ # Build task list
+ tasks = []
+ for rollout_idx in range(1, args.roll_out_count + 1):
+ processed = processed_per_rollout[rollout_idx]
+ for item in items:
+ question = item.get("question", "").strip()
+ if not question:
+ try:
+ user_msg = item["messages"][1]["content"]
+ question = user_msg.split("User:")[1].strip() if "User:" in user_msg else user_msg
+ item["question"] = question
+ except Exception:
+ continue
+
+ if question and question not in processed:
+ tasks.append({
+ "item": item.copy(),
+ "rollout_idx": rollout_idx,
+ })
+
+ print(f"Tasks to run: {len(tasks)}")
+
+ if not tasks:
+ print("All tasks already completed!")
+ return 0
+
+ # Run tasks
+ write_lock = threading.Lock()
+ completed = 0
+ failed = 0
+
+ for task in tqdm(tasks, desc="Processing", disable=shutdown_requested):
+ if shutdown_requested:
+ print(f"\nStopped early. Completed: {completed}, Failed: {failed}")
+ break
+
+ rollout_idx = task["rollout_idx"]
+ output_file = output_files[rollout_idx]
+
+        try:
+            task_start = time.time()
+            result = agent.run(task)
+            result["rollout_idx"] = rollout_idx
+            result["elapsed_time"] = time.time() - task_start
+
+ with write_lock:
+ with open(output_file, "a", encoding="utf-8") as f:
+ f.write(json.dumps(result, ensure_ascii=False) + "\n")
+
+ completed += 1
+
+ except Exception as e:
+ failed += 1
+ print(f"\nError: {e}")
+ import traceback
+ traceback.print_exc()
+
+ error_result = {
+ "question": task["item"].get("question", ""),
+ "answer": task["item"].get("answer", ""),
+ "rollout_idx": rollout_idx,
+ "error": str(e),
+ "messages": [],
+ "prediction": "[Failed]"
+ }
+ with write_lock:
+ with open(output_file, "a", encoding="utf-8") as f:
+ f.write(json.dumps(error_result, ensure_ascii=False) + "\n")
+
+ print("\n" + "=" * 60)
+ print("Inference Complete")
+ print("=" * 60)
+ print(f"Completed: {completed}")
+ print(f"Failed: {failed}")
+ print(f"Output: {output_dir}")
+ print("=" * 60)
+
+ return 0 if failed == 0 else 1
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/inference/test_mlx_connection.py b/inference/test_mlx_connection.py
new file mode 100644
index 0000000..3064a7d
--- /dev/null
+++ b/inference/test_mlx_connection.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+"""
+Quick test script to verify MLX server connection and basic inference.
+Run this after starting the MLX server to verify everything works.
+
+Usage:
+ # Terminal 1: Start MLX server
+ mlx_lm.server --model abalogh/Tongyi-DeepResearch-30B-A3B-4bit --port 8080
+
+ # Terminal 2: Run this test
+ python test_mlx_connection.py
+"""
+
+import sys
+from openai import OpenAI
+
+MLX_HOST = "127.0.0.1"
+MLX_PORT = 8080
+
+
+def test_connection():
+ """Test basic connection to MLX server."""
+ print(f"Testing connection to MLX server at {MLX_HOST}:{MLX_PORT}...")
+
+ client = OpenAI(
+ api_key="mlx-local",
+ base_url=f"http://{MLX_HOST}:{MLX_PORT}/v1",
+ timeout=60.0,
+ )
+
+ # Test 1: List models
+ print("\n1. Listing available models...")
+ try:
+ models = client.models.list()
+ available = [m.id for m in models.data]
+ print(f" Available models: {available}")
+ except Exception as e:
+ print(f" FAILED: {e}")
+ return False
+
+ # Test 2: Simple completion
+ print("\n2. Testing simple completion...")
+ try:
+ response = client.chat.completions.create(
+ model=available[0] if available else "default",
+ messages=[
+ {"role": "user", "content": "What is 2+2? Answer with just the number."}
+ ],
+ max_tokens=10,
+ temperature=0.1,
+ )
+ answer = response.choices[0].message.content
+ print(f" Response: {answer}")
+ except Exception as e:
+ print(f" FAILED: {e}")
+ return False
+
+ # Test 3: Test with system prompt (like DeepResearch uses)
+ print("\n3. Testing with system prompt...")
+ try:
+ response = client.chat.completions.create(
+ model=available[0] if available else "default",
+ messages=[
+ {"role": "system", "content": "You are a helpful research assistant. Think step by step."},
+ {"role": "user", "content": "What is the capital of Japan?"}
+ ],
+ max_tokens=100,
+ temperature=0.7,
+ )
+ answer = response.choices[0].message.content or ""
+ print(f" Response: {answer[:200]}..." if len(answer) > 200 else f" Response: {answer}")
+ except Exception as e:
+ print(f" FAILED: {e}")
+ return False
+
+ print("\n" + "=" * 50)
+ print("All tests passed! MLX server is working correctly.")
+ print("You can now run: bash inference/run_mlx_infer.sh")
+ print("=" * 50)
+ return True
+
+
+if __name__ == "__main__":
+ success = test_connection()
+ sys.exit(0 if success else 1)
diff --git a/inference/test_mlx_tool_loop.py b/inference/test_mlx_tool_loop.py
new file mode 100644
index 0000000..e430d8e
--- /dev/null
+++ b/inference/test_mlx_tool_loop.py
@@ -0,0 +1,177 @@
+#!/usr/bin/env python3
+"""
+Diagnostic test for MLX tool response injection.
+
+This script tests the complete tool call loop:
+1. Send a question to MLX
+2. Model generates <tool_call>...</tool_call>
+3. We parse and execute the tool
+4. We inject <tool_response>...</tool_response>
+5. Model continues with the tool response
+
+Usage:
+ python test_mlx_tool_loop.py
+"""
+
+import os
+import sys
+import json
+from datetime import datetime
+from typing import Optional, Tuple, Dict, Any, List
+
+from dotenv import load_dotenv
+load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env"))
+
+from openai import OpenAI
+import json5
+
+# Import tools
+sys.path.insert(0, os.path.dirname(__file__))
+from tool_search import Search
+from prompt import SYSTEM_PROMPT
+
+TOOL_MAP: Dict[str, Any] = {"search": Search()}
+
+
+def today_date() -> str:
+ return datetime.now().strftime("%Y-%m-%d")
+
+
+def parse_tool_call(content: str) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
+ """Extract tool name and arguments from model output."""
+ if "" not in content or "" not in content:
+ return None, None
+
+    tool_call_str = content.split("<tool_call>")[1].split("</tool_call>")[0].strip()
+
+ try:
+ tool_call = json5.loads(tool_call_str)
+ name = tool_call.get("name") if isinstance(tool_call, dict) else None
+ args = tool_call.get("arguments", {}) if isinstance(tool_call, dict) else {}
+ return name, args
+ except Exception as e:
+ print(f"Failed to parse tool call JSON: {e}")
+ print(f"Raw tool call: {tool_call_str}")
+ return None, None
+
+
+def execute_tool(name: str, args: Dict[str, Any]) -> str:
+ """Execute a tool and return the result."""
+ if name not in TOOL_MAP:
+ return f"Error: Tool '{name}' not found. Available: {list(TOOL_MAP.keys())}"
+
+ args["params"] = args
+ return TOOL_MAP[name].call(args)
+
+
+def test_tool_loop():
+ """Test the complete tool call loop."""
+ print("=" * 60)
+ print("MLX Tool Response Injection Diagnostic Test")
+ print("=" * 60)
+
+ # Connect to MLX server
+ client = OpenAI(
+ api_key="mlx-local",
+ base_url="http://127.0.0.1:8080/v1",
+ timeout=300.0,
+ )
+
+ # Verify connection
+ try:
+ models = client.models.list()
+ print(f"Connected to MLX server. Model: {models.data[0].id}")
+ except Exception as e:
+ print(f"ERROR: Cannot connect to MLX server: {e}")
+ print("Make sure the MLX server is running:")
+ print(" mlx_lm.server --model abalogh/Tongyi-DeepResearch-30B-A3B-4bit --port 8080")
+ return
+
+ # Build messages
+ system_prompt = SYSTEM_PROMPT + str(today_date())
+ question = "What are the latest developments in quantum computing in 2024?"
+
+ messages: List[Dict[str, str]] = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": question}
+ ]
+
+ print(f"\nQuestion: {question}")
+ print("-" * 60)
+
+ max_rounds = 5
+ for round_num in range(1, max_rounds + 1):
+ print(f"\n--- Round {round_num} ---")
+ print(f"Messages count: {len(messages)}")
+ print(f"Last message role: {messages[-1]['role']}")
+
+ # Call MLX
+ response = client.chat.completions.create(
+ model=models.data[0].id,
+ messages=messages, # type: ignore
+            stop=["\n<tool_response>", "<tool_response>"],
+ temperature=0.85,
+ top_p=0.95,
+ max_tokens=8192,
+ )
+
+ raw_content = response.choices[0].message.content
+ finish_reason = response.choices[0].finish_reason
+ content = raw_content.strip() if raw_content else ""
+
+ print(f"\nFinish reason: {finish_reason}")
+
+        # Clean up leaked <tool_response> if the stop sequences missed it
+        if "<tool_response>" in content:
+            content = content[:content.find("<tool_response>")]
+
+ print(f"\nModel output ({len(content)} chars):")
+ print("-" * 40)
+ print(content[:1000] + "..." if len(content) > 1000 else content)
+ print("-" * 40)
+
+ # Add assistant message
+ messages.append({"role": "assistant", "content": content})
+
+ # Check for final answer
+ if "" in content and "" in content:
+ answer = content.split("")[1].split("")[0]
+ print(f"\nFINAL ANSWER: {answer}")
+ print(f"Total rounds: {round_num}")
+ break
+
+ # Check for tool call
+ tool_name, tool_args = parse_tool_call(content)
+
+ if tool_name and tool_args is not None:
+ print(f"\nTool call detected: {tool_name}")
+ print(f"Arguments: {json.dumps(tool_args, indent=2)}")
+
+ # Execute tool
+ result = execute_tool(tool_name, tool_args)
+ print(f"\nTool result ({len(result)} chars):")
+ print(result[:500] + "..." if len(result) > 500 else result)
+
+ # Inject tool response
+ tool_response = f"\n{result}\n"
+ messages.append({"role": "user", "content": tool_response})
+ print(f"\nInjected tool_response as user message")
+ else:
+ print("\nNo tool call detected in output")
+ if round_num < max_rounds:
+ print("Model may be stuck - no tool call and no answer")
+
+ # Print final message history
+ print("\n" + "=" * 60)
+ print("FULL MESSAGE HISTORY")
+ print("=" * 60)
+ for i, msg in enumerate(messages):
+ role = msg["role"]
+ content = msg["content"]
+ preview = content[:200] + "..." if len(content) > 200 else content
+ print(f"\n[{i}] {role.upper()}:")
+ print(preview)
+
+
+if __name__ == "__main__":
+ test_tool_loop()
diff --git a/inference/tool_search.py b/inference/tool_search.py
index 1a3f7b5..cb3b364 100644
--- a/inference/tool_search.py
+++ b/inference/tool_search.py
@@ -1,131 +1,223 @@
+"""
+Exa.ai Search Tool for DeepResearch
+AI-native semantic search with neural embeddings for superior research results.
+
+Exa.ai advantages:
+- Neural/semantic search (understands meaning, not just keywords)
+- Can retrieve full page contents directly
+- Better for research and complex queries
+- Built-in query optimization (autoprompt)
+- Supports date filtering and domain restrictions
+- Category filtering (research papers, news, company info, etc.)
+- AI-generated highlights for quick comprehension
+"""
+
import json
-from concurrent.futures import ThreadPoolExecutor
-from typing import List, Union
+import os
+from typing import Any, Dict, Optional, Union
import requests
from qwen_agent.tools.base import BaseTool, register_tool
-import asyncio
-from typing import Dict, List, Optional, Union
-import uuid
-import http.client
-import json
-
-import os
+EXA_API_KEY = os.environ.get('EXA_API_KEY')
+EXA_BASE_URL = "https://api.exa.ai"
-SERPER_KEY=os.environ.get('SERPER_KEY_ID')
+# Valid Exa categories for filtering results
+VALID_CATEGORIES = [
+ "company", "research paper", "news", "pdf",
+ "github", "tweet", "personal site", "linkedin profile"
+]
@register_tool("search", allow_overwrite=True)
class Search(BaseTool):
name = "search"
- description = "Performs batched web searches: supply an array 'query'; the tool retrieves the top 10 results for each query in one call."
+ description = "Performs semantic web searches using Exa.ai: supply an array 'query'; retrieves top results with AI-powered understanding. Supports category filtering for research papers, news, etc."
parameters = {
"type": "object",
"properties": {
"query": {
"type": "array",
- "items": {
- "type": "string"
- },
- "description": "Array of query strings. Include multiple complementary search queries in a single call."
+ "items": {"type": "string"},
+ "description": "Array of query strings. Exa understands natural language queries well."
+ },
+ "num_results": {
+ "type": "integer",
+ "description": "Number of results per query (default: 10, max: 100)",
+ "default": 10
},
+ "include_contents": {
+ "type": "boolean",
+ "description": "Whether to include page text content and highlights",
+ "default": False
+ },
+ "category": {
+ "type": "string",
+ "description": "Filter by category: 'research paper', 'news', 'company', 'pdf', 'github', 'tweet', 'personal site', 'linkedin profile'",
+ "enum": ["company", "research paper", "news", "pdf", "github", "tweet", "personal site", "linkedin profile"]
+ }
},
"required": ["query"],
}
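+    # Example arguments (illustrative): {"query": ["mixture-of-experts inference"],
+    #   "num_results": 5, "include_contents": true, "category": "research paper"}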
def __init__(self, cfg: Optional[dict] = None):
super().__init__(cfg)
- def google_search_with_serp(self, query: str):
- def contains_chinese_basic(text: str) -> bool:
- return any('\u4E00' <= char <= '\u9FFF' for char in text)
- conn = http.client.HTTPSConnection("google.serper.dev")
- if contains_chinese_basic(query):
- payload = json.dumps({
- "q": query,
- "location": "China",
- "gl": "cn",
- "hl": "zh-cn"
- })
-
- else:
- payload = json.dumps({
- "q": query,
- "location": "United States",
- "gl": "us",
- "hl": "en"
- })
+ self.api_key = EXA_API_KEY
+ if not self.api_key:
+ raise ValueError("EXA_API_KEY environment variable not set. Get your key from https://exa.ai/")
+
+ def exa_search(
+ self,
+ query: str,
+ num_results: int = 10,
+ include_contents: bool = False,
+ category: Optional[str] = None
+ ) -> str:
+ """
+ Perform a search using Exa.ai API.
+
+ Exa supports multiple search types:
+ - "auto": Intelligently combines neural and other methods (default)
+ - "neural": AI-powered semantic search
+ - "deep": Comprehensive search with query expansion
+
+ Categories available:
+ - "research paper": Academic papers and publications
+ - "news": News articles
+ - "company": Company websites and info
+ - "pdf": PDF documents
+ - "github": GitHub repositories
+ - "tweet": Twitter/X posts
+ - "personal site": Personal websites/blogs
+ - "linkedin profile": LinkedIn profiles
+ """
headers = {
- 'X-API-KEY': SERPER_KEY,
- 'Content-Type': 'application/json'
- }
+ "Content-Type": "application/json",
+ "x-api-key": self.api_key
+ }
+
+ payload: Dict[str, Any] = {
+ "query": query,
+ "numResults": num_results,
+ "type": "auto",
+ "useAutoprompt": True,
+ }
+
+ # Add category filter if specified
+ if category and category in VALID_CATEGORIES:
+ payload["category"] = category
+ if include_contents:
+ payload["contents"] = {
+ "text": {"maxCharacters": 2000},
+ "highlights": True
+ }
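+
+        # Example payload (illustrative values; field names as constructed above):
+        #   {"query": "CRISPR delivery methods", "numResults": 10, "type": "auto",
+        #    "useAutoprompt": true, "category": "research paper",
+        #    "contents": {"text": {"maxCharacters": 2000}, "highlights": true}}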
- for i in range(5):
+ response = None
+ for attempt in range(3):
try:
- conn.request("POST", "/search", payload, headers)
- res = conn.getresponse()
+ response = requests.post(
+ f"{EXA_BASE_URL}/search",
+ headers=headers,
+ json=payload,
+ timeout=30
+ )
+ response.raise_for_status()
break
- except Exception as e:
- print(e)
- if i == 4:
- return f"Google search Timeout, return None, Please try again later."
+ except requests.exceptions.HTTPError as e:
+            if response is not None and response.status_code == 429:
+                return "Exa search rate limited. Please wait and try again."
+            if response is not None and response.status_code == 401:
+                return "Exa API key invalid. Check your EXA_API_KEY environment variable."
+ if attempt == 2:
+ return f"Exa search failed after 3 attempts: {str(e)}"
+ except requests.exceptions.RequestException as e:
+ if attempt == 2:
+ return f"Exa search failed after 3 attempts: {str(e)}"
continue
-
- data = res.read()
- results = json.loads(data.decode("utf-8"))
-
- try:
- if "organic" not in results:
- raise Exception(f"No results found for query: '{query}'. Use a less specific query.")
-
- web_snippets = list()
- idx = 0
- if "organic" in results:
- for page in results["organic"]:
- idx += 1
- date_published = ""
- if "date" in page:
- date_published = "\nDate published: " + page["date"]
-
- source = ""
- if "source" in page:
- source = "\nSource: " + page["source"]
-
- snippet = ""
- if "snippet" in page:
- snippet = "\n" + page["snippet"]
-
- redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"
- redacted_version = redacted_version.replace("Your browser can't play this video.", "")
- web_snippets.append(redacted_version)
-
- content = f"A Google search for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n" + "\n\n".join(web_snippets)
- return content
- except:
- return f"No results found for '{query}'. Try with a more general query."
-
-
-
- def search_with_serp(self, query: str):
- result = self.google_search_with_serp(query)
- return result
-
- def call(self, params: Union[str, dict], **kwargs) -> str:
- try:
- query = params["query"]
- except:
- return "[Search] Invalid request format: Input must be a JSON object containing 'query' field"
- if isinstance(query, str):
- # 单个查询
- response = self.search_with_serp(query)
+ if response is None:
+ return "Exa search failed: no response received"
+
+ results = response.json()
+
+ if "results" not in results or not results["results"]:
+ return f"No results found for '{query}'. Try a different query."
+
+ web_snippets = []
+ for idx, result in enumerate(results["results"], 1):
+ title = result.get("title", "No title")
+ url = result.get("url", "")
+ published_date = result.get("publishedDate", "")
+ author = result.get("author", "")
+
+ snippet_parts = [f"{idx}. [{title}]({url})"]
+
+ if author:
+ snippet_parts.append(f"Author: {author}")
+ if published_date:
+ snippet_parts.append(f"Date: {published_date[:10]}")
+
+ # Prefer highlights (AI-generated key points), then text, then snippet
+ if include_contents:
+ highlights = result.get("highlights", [])
+ if highlights:
+ snippet_parts.append("\nKey points:")
+ for h in highlights[:3]:
+ snippet_parts.append(f" • {h}")
+ elif "text" in result:
+ text = result["text"][:500]
+ snippet_parts.append(f"\n{text}...")
+ elif "snippet" in result:
+ snippet_parts.append(f"\n{result['snippet']}")
+
+ web_snippets.append("\n".join(snippet_parts))
+
+ search_type = results.get("resolvedSearchType", "neural")
+ category_info = f" (category: {category})" if category else ""
+ content = f"Exa {search_type} search{category_info} for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n\n" + "\n\n".join(web_snippets)
+ return content
+
+ def call(self, params: Union[str, dict], **kwargs: Any) -> str:
+ params_dict: Dict[str, Any]
+ if isinstance(params, str):
+ try:
+ params_dict = json.loads(params)
+ except json.JSONDecodeError:
+ return "[Search] Invalid JSON input"
else:
- # 多个查询
- assert isinstance(query, List)
+ params_dict = dict(params)
+
+ query = params_dict.get("query")
+ if not query:
+ return "[Search] Invalid request: 'query' field is required"
+
+ raw_num = params_dict.get("num_results", 10)
+ num_results = int(raw_num) if raw_num is not None else 10
+ include_contents = bool(params_dict.get("include_contents", False))
+ category = params_dict.get("category")
+
+ # Validate category if provided
+ if category and category not in VALID_CATEGORIES:
+ category = None
+
+ if isinstance(query, str):
+ return self.exa_search(query, num_results, include_contents, category)
+
+ if isinstance(query, list):
responses = []
for q in query:
- responses.append(self.search_with_serp(q))
- response = "\n=======\n".join(responses)
-
- return response
+ responses.append(self.exa_search(q, num_results, include_contents, category))
+ return "\n=======\n".join(responses)
+
+ return "[Search] Invalid query format: must be string or array of strings"
+
+if __name__ == "__main__":
+ from dotenv import load_dotenv
+
+ env_path = os.path.join(os.path.dirname(__file__), "..", ".env")
+ load_dotenv(env_path)
+
+ searcher = Search()
+ result = searcher.call({"query": ["What is retrieval augmented generation?"]})
+ print(result)
diff --git a/inference/tool_visit.py b/inference/tool_visit.py
index 92e4e3a..7892244 100644
--- a/inference/tool_visit.py
+++ b/inference/tool_visit.py
@@ -1,256 +1,276 @@
import json
import os
-import signal
-import threading
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import List, Union
+import re
+from typing import List, Union, Optional
import requests
from qwen_agent.tools.base import BaseTool, register_tool
from prompt import EXTRACTOR_PROMPT
from openai import OpenAI
-import random
-from urllib.parse import urlparse, unquote
+from urllib.parse import urlparse
import time
-from transformers import AutoTokenizer
import tiktoken
VISIT_SERVER_TIMEOUT = int(os.getenv("VISIT_SERVER_TIMEOUT", 200))
WEBCONTENT_MAXLENGTH = int(os.getenv("WEBCONTENT_MAXLENGTH", 150000))
-
JINA_API_KEYS = os.getenv("JINA_API_KEYS", "")
+# Maximum content length to return when summarization fails
+RAW_CONTENT_MAX_CHARS = int(os.getenv("RAW_CONTENT_MAX_CHARS", 8000))
+
-@staticmethod
def truncate_to_tokens(text: str, max_tokens: int = 95000) -> str:
- encoding = tiktoken.get_encoding("cl100k_base")
+ """Truncate text to a maximum number of tokens."""
+ try:
+ encoding = tiktoken.get_encoding("cl100k_base")
+ tokens = encoding.encode(text)
+ if len(tokens) <= max_tokens:
+ return text
+ return encoding.decode(tokens[:max_tokens])
+ except Exception:
+ # Fallback: rough char estimate (4 chars per token)
+ max_chars = max_tokens * 4
+ return text[:max_chars] if len(text) > max_chars else text
+
+
+def extract_main_content(html_text: str, max_chars: int = 8000) -> str:
+ """
+ Extract main content from HTML/markdown text.
+ Removes boilerplate like navigation, footers, etc.
+ """
+ lines = html_text.split('\n')
+ content_lines = []
+ total_chars = 0
- tokens = encoding.encode(text)
- if len(tokens) <= max_tokens:
- return text
+ skip_patterns = [
+ r'^#{1,2}\s*(navigation|menu|footer|sidebar|cookie|privacy|terms)',
+ r'^\s*(\||---)', # Table separators
+ r'^\s*\[.*\]\(.*\)\s*$', # Standalone links
+ r'^(copyright|©|\d{4}\s*[-–]\s*\d{4})',
+ ]
- truncated_tokens = tokens[:max_tokens]
- return encoding.decode(truncated_tokens)
-
-OSS_JSON_FORMAT = """# Response Formats
-## visit_content
-{"properties":{"rational":{"type":"string","description":"Locate the **specific sections/data** directly related to the user's goal within the webpage content"},"evidence":{"type":"string","description":"Identify and extract the **most relevant information** from the content, never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs.","summary":{"type":"string","description":"Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal."}}}}"""
+ for line in lines:
+ line_lower = line.lower().strip()
+
+ # Skip empty lines at start
+ if not content_lines and not line.strip():
+ continue
+
+ # Skip navigation/boilerplate patterns
+ skip = False
+ for pattern in skip_patterns:
+ if re.match(pattern, line_lower):
+ skip = True
+ break
+
+ if skip:
+ continue
+
+ content_lines.append(line)
+ total_chars += len(line) + 1
+
+ if total_chars >= max_chars:
+ break
+
+ return '\n'.join(content_lines)
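+
+# Usage sketch (hypothetical input): extract_main_content("## Menu\n[Home](/)\nArticle body")
+# drops the "## Menu" heading and the standalone link line and returns "Article body",
+# truncated to max_chars.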
@register_tool('visit', allow_overwrite=True)
class Visit(BaseTool):
- # The `description` tells the agent the functionality of this tool.
name = 'visit'
description = 'Visit webpage(s) and return the summary of the content.'
- # The `parameters` tell the agent what input parameters the tool has.
parameters = {
"type": "object",
"properties": {
"url": {
"type": ["string", "array"],
- "items": {
- "type": "string"
- },
+ "items": {"type": "string"},
"minItems": 1,
- "description": "The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs."
- },
- "goal": {
+ "description": "The URL(s) of the webpage(s) to visit."
+ },
+ "goal": {
"type": "string",
"description": "The goal of the visit for webpage(s)."
- }
+ }
},
"required": ["url", "goal"]
}
- # The `call` method is the main function of the tool.
+
+ def _validate_url(self, url: str) -> bool:
+ """Check if URL is valid and has a proper scheme."""
+ try:
+ parsed = urlparse(url)
+ return parsed.scheme in ('http', 'https') and bool(parsed.netloc)
+ except Exception:
+ return False
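+
+    # Illustrative: _validate_url("https://example.com") is True, while
+    # _validate_url("ftp://example.com") and _validate_url("example.com") are False.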
+
def call(self, params: Union[str, dict], **kwargs) -> str:
try:
url = params["url"]
goal = params["goal"]
- except:
- return "[Visit] Invalid request format: Input must be a JSON object containing 'url' and 'goal' fields"
-
- start_time = time.time()
-
- # Create log folder if it doesn't exist
- log_folder = "log"
- os.makedirs(log_folder, exist_ok=True)
+        except (KeyError, TypeError):
+            return "[Visit] Invalid request: params must include 'url' and 'goal' fields"
if isinstance(url, str):
- response = self.readpage_jina(url, goal)
- else:
- response = []
- assert isinstance(url, List)
- start_time = time.time()
- for u in url:
- if time.time() - start_time > 900:
- cur_response = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal)
- cur_response += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n"
- cur_response += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." + "\n\n"
- else:
- try:
- cur_response = self.readpage_jina(u, goal)
- except Exception as e:
- cur_response = f"Error fetching {u}: {str(e)}"
- response.append(cur_response)
- response = "\n=======\n".join(response)
-
- print(f'Summary Length {len(response)}; Summary Content {response}')
- return response.strip()
+ if not self._validate_url(url):
+ return f"[Visit] Invalid URL: {url}. URL must start with http:// or https://"
+ return self.readpage(url, goal)
- def call_server(self, msgs, max_retries=2):
- api_key = os.environ.get("API_KEY")
- url_llm = os.environ.get("API_BASE")
- model_name = os.environ.get("SUMMARY_MODEL_NAME", "")
- client = OpenAI(
- api_key=api_key,
- base_url=url_llm,
- )
- for attempt in range(max_retries):
+ # Multiple URLs
+ responses = []
+ start = time.time()
+ for u in url:
+ if time.time() - start > 300: # 5 min timeout for batch
+ responses.append(f"[Timeout] Skipped: {u}")
+ continue
+ if not self._validate_url(u):
+ responses.append(f"[Visit] Invalid URL: {u}")
+ continue
try:
- chat_response = client.chat.completions.create(
- model=model_name,
- messages=msgs,
- temperature=0.7
- )
- content = chat_response.choices[0].message.content
- if content:
- try:
- json.loads(content)
- except:
- # extract json from string
- left = content.find('{')
- right = content.rfind('}')
- if left != -1 and right != -1 and left <= right:
- content = content[left:right+1]
- return content
+ responses.append(self.readpage(u, goal))
except Exception as e:
- # print(e)
- if attempt == (max_retries - 1):
- return ""
- continue
-
-
- def jina_readpage(self, url: str) -> str:
- """
- Read webpage content using Jina service.
+ responses.append(f"[Error] {u}: {e}")
- Args:
- url: The URL to read
- goal: The goal/purpose of reading the page
-
- Returns:
- str: The webpage content or error message
- """
- max_retries = 3
- timeout = 50
+ return "\n\n---\n\n".join(responses)
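+
+    # Example invocation (illustrative): the agent supplies JSON params such as
+    #   {"url": ["https://example.com", "https://example.org"], "goal": "Identify the site owners"}
+    # and receives each URL's result joined by "---" separators.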
+
+ def jina_fetch(self, url: str, timeout: int = 30) -> Optional[str]:
+ """Fetch webpage content using Jina Reader API."""
+ headers = {}
+ if JINA_API_KEYS:
+ headers["Authorization"] = f"Bearer {JINA_API_KEYS}"
- for attempt in range(max_retries):
- headers = {
- "Authorization": f"Bearer {JINA_API_KEYS}",
- }
+        for _ in range(3):
try:
- response = requests.get(
+ resp = requests.get(
f"https://r.jina.ai/{url}",
headers=headers,
timeout=timeout
)
- if response.status_code == 200:
- webpage_content = response.text
- return webpage_content
- else:
- print(response.text)
- raise ValueError("jina readpage error")
- except Exception as e:
- time.sleep(0.5)
- if attempt == max_retries - 1:
- return "[visit] Failed to read page."
-
- return "[visit] Failed to read page."
+ if resp.status_code == 200 and len(resp.text) > 100:
+ return resp.text
+ except requests.RequestException:
+ pass
+ time.sleep(0.5)
+
+ return None
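+
+    # Illustrative: jina_fetch("https://example.com") issues a GET to
+    # https://r.jina.ai/https://example.com and returns the page text on
+    # success, or None after three attempts with a short back-off.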
- def html_readpage_jina(self, url: str) -> str:
- max_attempts = 8
- for attempt in range(max_attempts):
- content = self.jina_readpage(url)
- service = "jina"
- print(service)
- if content and not content.startswith("[visit] Failed to read page.") and content != "[visit] Empty content." and not content.startswith("[document_parser]"):
- return content
- return "[visit] Failed to read page."
+ def direct_fetch(self, url: str, timeout: int = 20) -> Optional[str]:
+ """Fallback: fetch directly with requests."""
+ headers = {
+ "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"
+ }
+ try:
+ resp = requests.get(url, headers=headers, timeout=timeout)
+ resp.raise_for_status()
+
+ # Basic HTML to text conversion
+ text = resp.text
+ # Remove script/style tags
+            text = re.sub(r'<script[^>]*>.*?</script>', '', text, flags=re.DOTALL | re.IGNORECASE)
+            text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)
+ # Remove HTML tags
+ text = re.sub(r'<[^>]+>', ' ', text)
+ # Clean whitespace
+ text = re.sub(r'\s+', ' ', text).strip()
+
+ return text if len(text) > 100 else None
+ except Exception:
+ return None
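+
+    # Illustrative: the tag-stripping pass turns "<p>Hi <b>there</b></p>" into
+    # "Hi there"; it is deliberately crude, so script-heavy pages may yield
+    # little text and fall through to the error path in readpage().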
- def readpage_jina(self, url: str, goal: str) -> str:
- """
- Attempt to read webpage content by alternating between jina and aidata services.
+ def summarize_content(self, content: str, goal: str) -> Optional[dict]:
+ """Use LLM API to summarize content. Returns None if unavailable."""
+ api_key = os.environ.get("API_KEY")
+ api_base = os.environ.get("API_BASE")
+ model = os.environ.get("SUMMARY_MODEL_NAME", "")
- Args:
- url: The URL to read
- goal: The goal/purpose of reading the page
+ if not api_key or not api_base:
+ return None
+
+ try:
+ client = OpenAI(api_key=api_key, base_url=api_base)
+
+ # Truncate content for summarization
+ content = truncate_to_tokens(content, max_tokens=30000)
+
+ prompt = EXTRACTOR_PROMPT.format(webpage_content=content, goal=goal)
+
+ resp = client.chat.completions.create(
+ model=model,
+ messages=[{"role": "user", "content": prompt}],
+ temperature=0.7,
+ max_tokens=2000
+ )
+
+ result = resp.choices[0].message.content
+ if not result:
+ return None
+
+ # Parse JSON response
+ result = result.replace("```json", "").replace("```", "").strip()
+
+ # Try to extract JSON
+ left = result.find('{')
+ right = result.rfind('}')
+ if left != -1 and right > left:
+ result = result[left:right+1]
- Returns:
- str: The webpage content or error message
+ return json.loads(result)
+ except Exception as e:
+ print(f"[visit] Summarization failed: {e}")
+ return None
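+
+    # Expected extractor output (illustrative values, matching EXTRACTOR_PROMPT):
+    #   {"rational": "...", "evidence": "...", "summary": "..."}
+    # readpage() only consumes the "evidence" and "summary" keys.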
+
+ def readpage(self, url: str, goal: str) -> str:
+ """
+ Read and process a webpage.
+
+ Strategy:
+ 1. Try Jina Reader first (best for complex pages)
+        2. Fall back to a direct fetch if Jina fails
+        3. Summarize with the LLM API if one is configured
+        4. Otherwise return the extracted raw content
"""
-
- summary_page_func = self.call_server
- max_retries = int(os.getenv('VISIT_SERVER_MAX_RETRIES', 1))
+ # Step 1: Fetch content
+ content = self.jina_fetch(url)
+
+ if not content:
+ content = self.direct_fetch(url)
+
+ if not content:
+ return self._format_error(url, goal, "Failed to fetch webpage content")
+
+ # Step 2: Try summarization
+ summary = self.summarize_content(content, goal)
+
+ if summary and summary.get("evidence") and summary.get("summary"):
+ return self._format_success(url, goal, summary["evidence"], summary["summary"])
+
+ # Step 3: Fallback - return extracted raw content
+ extracted = extract_main_content(content, RAW_CONTENT_MAX_CHARS)
+
+ if len(extracted) < 100:
+ return self._format_error(url, goal, "Page content too short or empty")
+
+ return self._format_raw(url, goal, extracted)
- content = self.html_readpage_jina(url)
+ def _format_success(self, url: str, goal: str, evidence: str, summary: str) -> str:
+        return f"""Content from {url} for goal: {goal}
+
- if content and not content.startswith("[visit] Failed to read page.") and content != "[visit] Empty content." and not content.startswith("[document_parser]"):
- content = truncate_to_tokens(content, max_tokens=95000)
- messages = [{"role":"user","content": EXTRACTOR_PROMPT.format(webpage_content=content, goal=goal)}]
- parse_retry_times = 0
- raw = summary_page_func(messages, max_retries=max_retries)
- summary_retries = 3
- while len(raw) < 10 and summary_retries >= 0:
- truncate_length = int(0.7 * len(content)) if summary_retries > 0 else 25000
- status_msg = (
- f"[visit] Summary url[{url}] "
- f"attempt {3 - summary_retries + 1}/3, "
- f"content length: {len(content)}, "
- f"truncating to {truncate_length} chars"
- ) if summary_retries > 0 else (
- f"[visit] Summary url[{url}] failed after 3 attempts, "
- f"final truncation to 25000 chars"
- )
- print(status_msg)
- content = content[:truncate_length]
- extraction_prompt = EXTRACTOR_PROMPT.format(
- webpage_content=content,
- goal=goal
- )
- messages = [{"role": "user", "content": extraction_prompt}]
- raw = summary_page_func(messages, max_retries=max_retries)
- summary_retries -= 1
+**Evidence:**
+{evidence}
- parse_retry_times = 2
- if isinstance(raw, str):
- raw = raw.replace("```json", "").replace("```", "").strip()
- while parse_retry_times < 3:
- try:
- raw = json.loads(raw)
- break
- except:
- raw = summary_page_func(messages, max_retries=max_retries)
- parse_retry_times += 1
-
- if parse_retry_times >= 3:
- useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal)
- useful_information += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n"
- useful_information += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." + "\n\n"
- else:
- useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal)
- useful_information += "Evidence in page: \n" + str(raw["evidence"]) + "\n\n"
- useful_information += "Summary: \n" + str(raw["summary"]) + "\n\n"
+
+**Summary:**
+{summary}"""
- if len(useful_information) < 10 and summary_retries < 0:
- print("[visit] Could not generate valid summary after maximum retries")
- useful_information = "[visit] Failed to read page"
-
- return useful_information
+ def _format_raw(self, url: str, goal: str, content: str) -> str:
+ return f"""Content from {url} for goal: {goal}
+
+**Raw Content (summarization unavailable):**
+{content}
+
+Note: Please extract the relevant information for your goal from the content above."""
- # If no valid content was obtained after all retries
- else:
- useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal)
- useful_information += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n"
- useful_information += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." + "\n\n"
- return useful_information
+ def _format_error(self, url: str, goal: str, reason: str) -> str:
+ return f"""Could not retrieve content from {url}
+Goal: {goal}
+Reason: {reason}
-
\ No newline at end of file
+
+Please try a different source or search query."""
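+
+
+# Minimal smoke test (illustrative, not used by the agent). With API_KEY/API_BASE
+# unset, this exercises the raw-content fallback path. Assumes the BaseTool
+# base class allows no-arg construction:
+if __name__ == "__main__":
+    _tool = Visit()
+    print(_tool.call({"url": "https://example.com", "goal": "What is this page about?"}))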