diff --git a/.env.example b/.env.example index 8558e9c..43e38c9 100644 --- a/.env.example +++ b/.env.example @@ -46,9 +46,14 @@ MAX_WORKERS=30 # API Keys and External Services # ============================================================================= -# Serper API for web search and Google Scholar -# Get your key from: https://serper.dev/ -SERPER_KEY_ID=your_key +# Exa.ai API for semantic web search +# Get your key from: https://exa.ai/ +# Exa provides AI-native neural search with: +# - Semantic understanding (not just keyword matching) +# - Built-in query optimization +# - Direct content retrieval +# - Better results for complex research queries +EXA_API_KEY=your_key # Jina API for web page reading # Get your key from: https://jina.ai/ @@ -57,8 +62,8 @@ JINA_API_KEYS=your_key # Summary model API (OpenAI-compatible) for page summarization # Get your key from: https://platform.openai.com/ API_KEY=your_key -API_BASE=your_api_base -SUMMARY_MODEL_NAME=your_summary_model_name +API_BASE=https://api.openai.com/v1 +SUMMARY_MODEL_NAME=gpt-4o-mini # Dashscope API for file parsing (PDF, Office, etc.) # Get your key from: https://dashscope.aliyun.com/ @@ -95,4 +100,24 @@ IDP_KEY_SECRET=your_idp_key_secret # These are typically set by distributed training frameworks # WORLD_SIZE=1 -# RANK=0 \ No newline at end of file +# RANK=0 + +# ============================================================================= +# MLX Configuration (Apple Silicon Only) +# ============================================================================= +# For running on Apple Silicon Macs (M1/M2/M3/M4) using MLX framework +# instead of CUDA/vLLM. Uses mlx-lm for efficient local inference. +# +# Requirements: +# pip install mlx-lm +# +# Recommended models: +# - abalogh/Tongyi-DeepResearch-30B-A3B-4bit (17GB, fits 32GB RAM) +# - Original BF16 model requires 62GB+ +# +# Usage: +# bash inference/run_mlx_infer.sh +# +# MLX_MODEL=abalogh/Tongyi-DeepResearch-30B-A3B-4bit +# MLX_HOST=127.0.0.1 +# MLX_PORT=8080 \ No newline at end of file diff --git a/inference/interactive.py b/inference/interactive.py new file mode 100644 index 0000000..38dd560 --- /dev/null +++ b/inference/interactive.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +""" +Interactive CLI for DeepResearch on Apple Silicon (MLX) + +Usage: + python interactive.py [--model MODEL_PATH] + +Example: + python interactive.py + python interactive.py --model abalogh/Tongyi-DeepResearch-30B-A3B-4bit +""" + +import argparse +import json +import os +import sys +import time + +# Load environment variables first +from dotenv import load_dotenv +load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env")) + +# Disable tokenizer parallelism warning +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Optional: rich for better formatting +try: + from rich.console import Console + from rich.markdown import Markdown + from rich.panel import Panel + from rich.progress import Progress, SpinnerColumn, TextColumn + RICH_AVAILABLE = True + console = Console() +except ImportError: + RICH_AVAILABLE = False + console = None + + +def print_header(): + """Print welcome header.""" + header = """ +╔══════════════════════════════════════════════════════════════╗ +║ DeepResearch - Interactive Mode (MLX) ║ +║ Apple Silicon Optimized ║ +╚══════════════════════════════════════════════════════════════╝ +""" + if RICH_AVAILABLE: + console.print(header, style="bold blue") + else: + print(header) + + +def print_help(): + """Print help information.""" + help_text = """ +Commands: + /help - Show this 
help message + /quit - Exit the program (or Ctrl+C) + /clear - Clear conversation history (start fresh) + /status - Show model and memory status + +Just type your research question to begin! + +Examples: + > What is the current population of Tokyo? + > Who won the 2024 Nobel Prize in Physics? + > Explain the mechanism of CRISPR-Cas9 gene editing +""" + if RICH_AVAILABLE: + console.print(Panel(help_text, title="Help", border_style="green")) + else: + print(help_text) + + +def format_answer(answer: str): + """Format the answer for display.""" + if RICH_AVAILABLE: + console.print("\n") + console.print(Panel(Markdown(answer), title="[bold green]Answer[/]", border_style="green")) + else: + print("\n" + "=" * 60) + print("ANSWER:") + print("=" * 60) + print(answer) + print("=" * 60) + + +def main(): + parser = argparse.ArgumentParser(description="Interactive DeepResearch CLI") + parser.add_argument("--model", type=str, + default="abalogh/Tongyi-DeepResearch-30B-A3B-4bit", + help="Model path or HuggingFace ID") + parser.add_argument("--temperature", type=float, default=0.7, + help="Sampling temperature") + parser.add_argument("--max_tokens", type=int, default=4096, + help="Max tokens per generation") + parser.add_argument("--max_rounds", type=int, default=15, + help="Max research rounds per question") + args = parser.parse_args() + + print_header() + + # Set max rounds via environment + os.environ['MAX_LLM_CALL_PER_RUN'] = str(args.max_rounds) + + # Import agent after setting environment + print("Loading model (this may take a minute)...") + + try: + from run_mlx_react import MLXReactAgent, TOOL_MAP + except ImportError as e: + print(f"Error importing agent: {e}") + print("Make sure you're running from the inference directory.") + return 1 + + if RICH_AVAILABLE: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console + ) as progress: + progress.add_task("Loading MLX model...", total=None) + agent = MLXReactAgent( + model_path=args.model, + temperature=args.temperature, + max_tokens=args.max_tokens + ) + else: + agent = MLXReactAgent( + model_path=args.model, + temperature=args.temperature, + max_tokens=args.max_tokens + ) + + print(f"\nTools available: {list(TOOL_MAP.keys())}") + print(f"Max rounds per question: {args.max_rounds}") + print_help() + + while True: + try: + # Get user input + if RICH_AVAILABLE: + query = console.input("\n[bold cyan]Research Query>[/] ").strip() + else: + query = input("\nResearch Query> ").strip() + + # Handle commands + if not query: + continue + + if query.lower() in ('/quit', '/exit', '/q'): + print("Goodbye!") + break + + if query.lower() == '/help': + print_help() + continue + + if query.lower() == '/clear': + print("Ready for a new question.") + continue + + if query.lower() == '/status': + try: + import mlx.core as mx + # Use new API (mlx >= 0.24) or fall back to deprecated + if hasattr(mx, 'get_active_memory'): + mem_gb = mx.get_active_memory() / (1024**3) + else: + mem_gb = mx.metal.get_active_memory() / (1024**3) + print(f"Model: {args.model}") + print(f"GPU Memory: {mem_gb:.1f} GB") + except Exception: + print(f"Model: {args.model}") + continue + + if query.startswith('/'): + print(f"Unknown command: {query}. 
Type /help for available commands.")
+                continue
+
+            # Run research
+            print("\nResearching...\n")
+            start = time.time()
+
+            data = {'item': {'question': query, 'answer': ''}}
+            result = agent.run(data)
+
+            elapsed = time.time() - start
+
+            # Display result
+            prediction = result.get('prediction', 'No answer found.')
+            termination = result.get('termination', 'unknown')
+            num_rounds = len([m for m in result.get('messages', []) if m.get('role') == 'assistant'])
+
+            format_answer(prediction)
+
+            if RICH_AVAILABLE:
+                console.print(f"[dim]Completed in {elapsed:.1f}s | {num_rounds} rounds | Termination: {termination}[/]")
+            else:
+                print(f"\nCompleted in {elapsed:.1f}s | {num_rounds} rounds | Termination: {termination}")
+
+        except KeyboardInterrupt:
+            print("\n\nInterrupted. Type /quit to exit or continue with a new question.")
+            continue
+        except EOFError:
+            print("\nGoodbye!")
+            break
+        except Exception as e:
+            print(f"\nError: {e}")
+            import traceback
+            traceback.print_exc()
+            continue
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/inference/prompt.py b/inference/prompt.py
index 649e1e0..d63988b 100644
--- a/inference/prompt.py
+++ b/inference/prompt.py
@@ -1,4 +1,19 @@
-SYSTEM_PROMPT = """You are a deep research assistant. Your core function is to conduct thorough, multi-source investigations into any topic. You must handle both broad, open-domain inquiries and queries within specialized academic fields. For every request, synthesize information from credible, diverse sources to deliver a comprehensive, accurate, and objective response. When you have gathered sufficient information and are ready to provide the definitive response, you must enclose the entire final answer within <answer></answer> tags.
+SYSTEM_PROMPT = """You are a deep research assistant. Your core function is to conduct thorough, multi-source investigations into any topic. You must handle both broad, open-domain inquiries and queries within specialized academic fields. For every request, synthesize information from credible, diverse sources to deliver a comprehensive, accurate, and objective response.
+
+# CRITICAL: Answer Behavior
+
+**You MUST provide a final answer after gathering sufficient information.** Do not continue researching indefinitely.
+
+Guidelines for when to provide your answer:
+1. After 2-3 search queries that return relevant results, you likely have enough information
+2. If multiple sources agree on key facts, you have sufficient confirmation
+3. If a webpage visit fails, use the search snippets you already have
+4. A good answer with available information is better than endless searching
+5. When uncertain, provide the best answer you can with appropriate caveats
+
+**When ready to answer, use this format:**
+<think>Final reasoning about the gathered information</think>
+<answer>Your comprehensive answer here</answer>
 
 # Tools
 
@@ -6,25 +21,11 @@
 
 You are provided with function signatures within <tools></tools> XML tags:
 
-{"type": "function", "function": {"name": "search", "description": "Perform Google web searches then returns a string of the top search results. 
Accepts multiple queries.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries."}}, "required": ["query"]}}} -{"type": "function", "function": {"name": "visit", "description": "Visit webpage(s) and return the summary of the content.", "parameters": {"type": "object", "properties": {"url": {"type": "array", "items": {"type": "string"}, "description": "The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs."}, "goal": {"type": "string", "description": "The specific information goal for visiting webpage(s)."}}, "required": ["url", "goal"]}}} -{"type": "function", "function": {"name": "PythonInterpreter", "description": "Executes Python code in a sandboxed environment. To use this tool, you must follow this format: -1. The 'arguments' JSON object must be empty: {}. -2. The Python code to be executed must be placed immediately after the JSON block, enclosed within and tags. - -IMPORTANT: Any output you want to see MUST be printed to standard output using the print() function. - -Example of a correct call: - -{"name": "PythonInterpreter", "arguments": {}} - -import numpy as np -# Your code here -print(f"The result is: {np.mean([1,2,3])}") - -", "parameters": {"type": "object", "properties": {}, "required": []}}} -{"type": "function", "function": {"name": "google_scholar", "description": "Leverage Google Scholar to retrieve relevant information from academic publications. Accepts multiple queries. This tool will also return results from google search", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries for Google Scholar."}}, "required": ["query"]}}} -{"type": "function", "function": {"name": "parse_file", "description": "This is a tool that can be used to parse multiple user uploaded local files such as PDF, DOCX, PPTX, TXT, CSV, XLSX, DOC, ZIP, MP4, MP3.", "parameters": {"type": "object", "properties": {"files": {"type": "array", "items": {"type": "string"}, "description": "The file name of the user uploaded local files to be parsed."}}, "required": ["files"]}}} +{"type": "function", "function": {"name": "search", "description": "Perform web searches and return top results with snippets. Use this first to find relevant sources.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string"}, "minItems": 1, "description": "Search queries (1-3 queries recommended)."}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "visit", "description": "Visit webpage(s) to extract detailed content. Only visit if search snippets are insufficient.", "parameters": {"type": "object", "properties": {"url": {"type": "array", "items": {"type": "string"}, "description": "URL(s) to visit."}, "goal": {"type": "string", "description": "What specific information you need from the page."}}, "required": ["url", "goal"]}}} +{"type": "function", "function": {"name": "google_scholar", "description": "Search academic publications. 
Use for scientific/research questions.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string"}, "minItems": 1, "description": "Academic search queries."}}, "required": ["query"]}}} +{"type": "function", "function": {"name": "PythonInterpreter", "description": "Execute Python code for calculations or data processing.", "parameters": {"type": "object", "properties": {}, "required": []}}} +{"type": "function", "function": {"name": "parse_file", "description": "Parse uploaded files (PDF, DOCX, etc.).", "parameters": {"type": "object", "properties": {"files": {"type": "array", "items": {"type": "string"}, "description": "File names to parse."}}, "required": ["files"]}}} For each function call, return a json object with function name and arguments within XML tags: diff --git a/inference/run_mlx_infer.sh b/inference/run_mlx_infer.sh new file mode 100755 index 0000000..dd45fd1 --- /dev/null +++ b/inference/run_mlx_infer.sh @@ -0,0 +1,75 @@ +#!/bin/bash +# MLX Inference Script for Apple Silicon (M1/M2/M3/M4) +# This script runs DeepResearch using Apple's MLX framework instead of vLLM/CUDA +# +# Uses native MLX Python API (no separate server needed) + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$SCRIPT_DIR/.." +ENV_FILE="$PROJECT_ROOT/.env" +VENV_PATH="$PROJECT_ROOT/venv" + +# Activate virtual environment if it exists +if [ -d "$VENV_PATH" ]; then + echo "Activating virtual environment..." + source "$VENV_PATH/bin/activate" +else + echo "Warning: No venv found at $VENV_PATH" + echo "Create one with: python3 -m venv $VENV_PATH" + echo "Then install: pip install mlx-lm python-dotenv requests json5 tqdm qwen-agent" + exit 1 +fi + +# Load environment variables +if [ ! -f "$ENV_FILE" ]; then + echo "Error: .env file not found at $ENV_FILE" + echo "Please copy .env.example to .env and configure your settings" + exit 1 +fi + +echo "Loading environment variables..." +set -a +source "$ENV_FILE" +set +a + +# MLX-specific configuration +MLX_MODEL="${MLX_MODEL:-abalogh/Tongyi-DeepResearch-30B-A3B-4bit}" + +# Default inference parameters +TEMPERATURE="${TEMPERATURE:-0.85}" +MAX_TOKENS="${MAX_TOKENS:-8192}" +TOP_P="${TOP_P:-0.95}" + +echo "============================================" +echo "DeepResearch MLX Inference (Apple Silicon)" +echo "============================================" +echo "Model: $MLX_MODEL" +echo "Temperature: $TEMPERATURE" +echo "Top-P: $TOP_P" +echo "Max Tokens: $MAX_TOKENS" +echo "============================================" + +# Check if mlx-lm is installed +python -c "import mlx_lm" 2>/dev/null || { + echo "Error: mlx-lm not installed. Install with: pip install mlx-lm" + exit 1 +} + +# Disable tokenizer parallelism warning +export TOKENIZERS_PARALLELISM=false + +# Run inference using native MLX API (no server needed) +cd "$SCRIPT_DIR" + +python -u run_mlx_react.py \ + --dataset "${DATASET:-$PROJECT_ROOT/eval_data/sample_questions.jsonl}" \ + --output "${OUTPUT_PATH:-./outputs}" \ + --model "$MLX_MODEL" \ + --temperature "$TEMPERATURE" \ + --top_p "$TOP_P" \ + --max_tokens "$MAX_TOKENS" \ + --roll_out_count "${ROLLOUT_COUNT:-1}" + +echo "Inference complete!" diff --git a/inference/run_mlx_react.py b/inference/run_mlx_react.py new file mode 100644 index 0000000..c5bbbac --- /dev/null +++ b/inference/run_mlx_react.py @@ -0,0 +1,636 @@ +""" +MLX React Agent Runner for Apple Silicon + +This script runs DeepResearch using Apple's MLX framework instead of vLLM/CUDA. 
+Uses native MLX Python API with proper chat template handling. + +Requirements: + pip install mlx-lm python-dotenv requests json5 tqdm qwen-agent + +Usage: + python run_mlx_react.py --dataset eval_data/test.jsonl --output ./outputs +""" + +import argparse +import json +import os +import signal +import sys +import time +import threading +from datetime import datetime +from typing import Any, Dict, List, Optional + +# Load environment variables before other imports +from dotenv import load_dotenv +load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env")) + +from tqdm import tqdm +import json5 +from mlx_lm import load, generate +from mlx_lm.sample_utils import make_sampler + +from prompt import SYSTEM_PROMPT + +# Disable tokenizer parallelism warning +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# Tool registry - import tools with fallbacks +TOOL_MAP: Dict[str, Any] = {} + +try: + from tool_search import Search + TOOL_MAP['search'] = Search() +except ImportError as e: + print(f"Warning: Could not import Search tool: {e}") + +try: + from tool_visit import Visit + TOOL_MAP['visit'] = Visit() +except ImportError as e: + print(f"Warning: Could not import Visit tool: {e}") + +try: + from tool_scholar import Scholar + TOOL_MAP['google_scholar'] = Scholar() +except ImportError as e: + print(f"Warning: Could not import Scholar tool: {e}") + +try: + from tool_file import FileParser + TOOL_MAP['parse_file'] = FileParser() +except ImportError as e: + print(f"Warning: Could not import FileParser tool: {e}") + +try: + from tool_python import PythonInterpreter + TOOL_MAP['PythonInterpreter'] = PythonInterpreter() +except (ImportError, Exception) as e: + print(f"Warning: Could not import PythonInterpreter tool: {e}") + +print(f"Loaded tools: {list(TOOL_MAP.keys())}") + +MAX_LLM_CALL_PER_RUN = int(os.getenv('MAX_LLM_CALL_PER_RUN', 100)) + +# Graceful shutdown flag +shutdown_requested = False + + +def signal_handler(signum, frame): + """Handle interrupt signals gracefully.""" + global shutdown_requested + if shutdown_requested: + print("\nForce quit...") + sys.exit(1) + shutdown_requested = True + print("\nShutdown requested. Finishing current task...") + + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + + +def today_date() -> str: + return datetime.now().strftime("%Y-%m-%d") + + +class MLXReactAgent: + """ + React agent using native MLX Python API for inference on Apple Silicon. + + Uses the model's built-in chat template for proper formatting. + """ + + def __init__(self, model_path: str, temperature: float = 0.85, + top_p: float = 0.95, max_tokens: int = 8192): + self.model_path = model_path + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + + print(f"Loading model: {model_path}") + self.model, self.tokenizer = load(model_path) + print(f"Model loaded successfully (memory: {self._get_memory_usage():.1f} GB)") + + def _get_memory_usage(self) -> float: + """Get current GPU memory usage in GB.""" + try: + import mlx.core as mx + # Use new API (mlx >= 0.24) or fall back to deprecated + if hasattr(mx, 'get_active_memory'): + return mx.get_active_memory() / (1024**3) + return mx.metal.get_active_memory() / (1024**3) + except Exception: + return 0.0 + + def build_prompt(self, messages: List[Dict[str, str]]) -> str: + """ + Build prompt using tokenizer's chat template. + Falls back to manual Qwen format if template unavailable. 
+ """ + # Try using tokenizer's built-in chat template + if hasattr(self.tokenizer, 'apply_chat_template'): + try: + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + return prompt + except Exception as e: + print(f"Warning: apply_chat_template failed, using manual format: {e}") + + # Fallback: Manual Qwen/ChatML format + prompt_parts = [] + for msg in messages: + role = msg["role"] + content = msg["content"] + prompt_parts.append(f"<|im_start|>{role}\n{content}<|im_end|>") + prompt_parts.append("<|im_start|>assistant\n") + return "\n".join(prompt_parts) + + def count_tokens(self, messages: List[Dict[str, str]]) -> int: + """Count tokens using the actual tokenizer.""" + prompt = self.build_prompt(messages) + tokens = self.tokenizer.encode(prompt) + return len(tokens) + + def generate_response(self, messages: List[Dict[str, str]], max_tokens: Optional[int] = None) -> str: + """Generate response using native MLX API.""" + prompt = self.build_prompt(messages) + tokens = max_tokens or self.max_tokens + + sampler = make_sampler(temp=self.temperature, top_p=self.top_p) + + response = generate( + self.model, + self.tokenizer, + prompt=prompt, + max_tokens=tokens, + sampler=sampler, + verbose=False, + ) + + # Clean up response - remove trailing tokens + if "<|im_end|>" in response: + response = response.split("<|im_end|>")[0] + if "" in response: + response = response.split("")[0] + + return response.strip() + + def execute_tool(self, tool_name: str, tool_args: Dict[str, Any], timeout: int = 120) -> str: + """Execute a tool with timeout protection.""" + if tool_name not in TOOL_MAP: + return f"Error: Tool '{tool_name}' not found. Available: {list(TOOL_MAP.keys())}" + + # Copy args to avoid mutation + args = dict(tool_args) + result = "" + error = None + + def run_tool(): + nonlocal result, error + try: + if "python" in tool_name.lower(): + result = str(TOOL_MAP['PythonInterpreter'].call(args)) + elif tool_name == "parse_file": + import asyncio + params = {"files": args.get("files", [])} + r = asyncio.run(TOOL_MAP[tool_name].call(params, file_root_path="./eval_data/file_corpus")) + result = str(r) if not isinstance(r, str) else r + else: + result = str(TOOL_MAP[tool_name].call(args)) + except Exception as e: + error = str(e) + + thread = threading.Thread(target=run_tool) + thread.start() + thread.join(timeout=timeout) + + if thread.is_alive(): + return f"Error: Tool '{tool_name}' timed out after {timeout}s" + + if error: + return f"Error executing tool '{tool_name}': {error}" + + return result + + def run(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Run the react agent loop for a single question. 
+
+        Args:
+            data: Dict with 'item' containing 'question' and optionally 'answer'
+
+        Returns:
+            Dict with question, answer, messages, prediction, and termination status
+        """
+        global shutdown_requested
+
+        # Extract question
+        item = data['item']
+        question = item.get('question', '')
+        if not question:
+            try:
+                raw_msg = item['messages'][1]["content"]
+                question = raw_msg.split("User:")[1].strip() if "User:" in raw_msg else raw_msg
+            except Exception as e:
+                print(f"Failed to extract question: {e}")
+                return {"question": "", "error": "Could not extract question"}
+
+        answer = item.get('answer', '')
+        start_time = time.time()
+
+        # Build initial messages
+        system_prompt = SYSTEM_PROMPT + str(today_date())
+        messages: List[Dict[str, str]] = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": question}
+        ]
+
+        num_calls_remaining = MAX_LLM_CALL_PER_RUN
+        round_num = 0
+        max_context_tokens = 100 * 1024  # 100K tokens (conservative for 128K model)
+        timeout_minutes = 120  # 2 hours
+        consecutive_errors = 0
+        last_tool_call = ""  # For loop detection
+
+        while num_calls_remaining > 0:
+            # Check for shutdown
+            if shutdown_requested:
+                return {
+                    "question": question,
+                    "answer": answer,
+                    "messages": messages,
+                    "prediction": "Interrupted by user",
+                    "termination": "interrupted"
+                }
+
+            # Check timeout
+            elapsed = time.time() - start_time
+            if elapsed > timeout_minutes * 60:
+                return {
+                    "question": question,
+                    "answer": answer,
+                    "messages": messages,
+                    "prediction": "No answer found after timeout",
+                    "termination": "timeout"
+                }
+
+            round_num += 1
+            num_calls_remaining -= 1
+
+            print(f"--- Round {round_num} (calls left: {num_calls_remaining}) ---")
+
+            # Inject reminder at round 5 to encourage conclusion
+            if round_num == 5:
+                messages.append({
+                    "role": "user",
+                    "content": "REMINDER: You have made several searches. If you have enough information to answer the question, please provide your final answer now using <answer></answer> tags. Only continue searching if absolutely necessary."
+                })
+
+            # Generate response
+            content = self.generate_response(messages)
+
+            preview = content[:200] + "..." if len(content) > 200 else content
+            print(f"Response: {preview}")
+
+            messages.append({"role": "assistant", "content": content})
+
+            # Check for tool calls
+            if '<tool_call>' in content and '</tool_call>' in content:
+                tool_call_str = content.split('<tool_call>')[1].split('</tool_call>')[0]
+
+                # Loop detection: check if same tool call as last time
+                if tool_call_str.strip() == last_tool_call:
+                    print("Warning: Detected repeated tool call, forcing answer...")
+                    messages.append({
+                        "role": "user",
+                        "content": "You are repeating the same action. Stop and provide your final answer NOW based on available information.\n<answer>your answer</answer>"
+                    })
+                    content = self.generate_response(messages, max_tokens=2048)
+                    messages.append({"role": "assistant", "content": content})
+                    if '<answer>' in content and '</answer>' in content:
+                        prediction = content.split('<answer>')[1].split('</answer>')[0]
+                    else:
+                        prediction = content
+                    return {
+                        "question": question,
+                        "answer": answer,
+                        "messages": messages,
+                        "prediction": prediction.strip(),
+                        "termination": "loop_detected"
+                    }
+
+                last_tool_call = tool_call_str.strip()
+
+                try:
+                    # Handle Python interpreter specially
+                    if "python" in tool_call_str.lower() and "<code>" in content:
+                        code = content.split('<code>')[1].split('</code>')[0].strip()
+                        result = self.execute_tool('PythonInterpreter', {"code": code})
+                    else:
+                        tool_call = json5.loads(tool_call_str.strip())
+                        tool_name = tool_call.get('name', '')
+                        tool_args = tool_call.get('arguments', {})
+                        print(f"Tool: {tool_name} | Args: {json.dumps(tool_args)[:100]}...")
+                        result = self.execute_tool(tool_name, tool_args)
+                except json.JSONDecodeError as e:
+                    result = f'Error: Invalid JSON in tool call. {e}'
+                except Exception as e:
+                    result = f'Error: Tool call failed. {e}'
+
+                # Track consecutive errors
+                if result.startswith('Error:'):
+                    consecutive_errors += 1
+                    if consecutive_errors >= 3:
+                        print(f"Warning: {consecutive_errors} consecutive errors, forcing answer...")
+                        messages.append({
+                            "role": "user",
+                            "content": f"Multiple tool errors occurred. Please provide your best answer based on the information you have gathered so far.\n<answer>your answer</answer>"
+                        })
+                        content = self.generate_response(messages, max_tokens=2048)
+                        messages.append({"role": "assistant", "content": content})
+                        if '<answer>' in content and '</answer>' in content:
+                            prediction = content.split('<answer>')[1].split('</answer>')[0]
+                        else:
+                            prediction = content
+                        return {
+                            "question": question,
+                            "answer": answer,
+                            "messages": messages,
+                            "prediction": prediction.strip(),
+                            "termination": "consecutive_errors"
+                        }
+                else:
+                    consecutive_errors = 0  # Reset on success
+
+                result_preview = result[:200] + "..." if len(result) > 200 else result
+                print(f"Result: {result_preview}")
+
+                tool_response = f"<tool_response>\n{result}\n</tool_response>"
+                messages.append({"role": "user", "content": tool_response})
+
+            # Check for final answer
+            if '<answer>' in content and '</answer>' in content:
+                prediction = content.split('<answer>')[1].split('</answer>')[0]
+                elapsed_mins = (time.time() - start_time) / 60
+                print(f"Answer found in {elapsed_mins:.1f} minutes")
+                return {
+                    "question": question,
+                    "answer": answer,
+                    "messages": messages,
+                    "prediction": prediction.strip(),
+                    "termination": "answer"
+                }
+
+            # Check token limit
+            token_count = self.count_tokens(messages)
+            print(f"Tokens: {token_count:,}")
+
+            if token_count > max_context_tokens:
+                print(f"Token limit exceeded: {token_count:,} > {max_context_tokens:,}")
+
+                # Force final answer
+                messages.append({
+                    "role": "user",
+                    "content": "IMPORTANT: You have reached the maximum context length. "
+                               "Stop making tool calls. Provide your final answer NOW based on all information above.\n"
+                               "Format: <think>final reasoning</think>\n<answer>your answer</answer>"
+                })
+
+                content = self.generate_response(messages, max_tokens=2048)
+                messages.append({"role": "assistant", "content": content})
+
+                if '<answer>' in content and '</answer>' in content:
+                    prediction = content.split('<answer>')[1].split('</answer>')[0]
+                    termination = "token_limit_answer"
+                else:
+                    prediction = content
+                    termination = "token_limit_no_answer"
+
+                return {
+                    "question": question,
+                    "answer": answer,
+                    "messages": messages,
+                    "prediction": prediction.strip(),
+                    "termination": termination
+                }
+
+        # Max calls reached - try to get final answer
+        print("Max LLM calls reached, requesting final answer...")
+        messages.append({
+            "role": "user",
+            "content": "Maximum iterations reached. Provide your final answer NOW.\n"
+                       "<answer>your answer</answer>"
+        })
+
+        content = self.generate_response(messages, max_tokens=2048)
+        messages.append({"role": "assistant", "content": content})
+
+        if '<answer>' in content and '</answer>' in content:
+            prediction = content.split('<answer>')[1].split('</answer>')[0]
+            termination = "max_calls_answer"
+        else:
+            prediction = content if content else "No answer found."
+            termination = "max_calls_no_answer"
+
+        return {
+            "question": question,
+            "answer": answer,
+            "messages": messages,
+            "prediction": prediction.strip(),
+            "termination": termination
+        }
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Run DeepResearch with MLX on Apple Silicon",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("--model", type=str, default="abalogh/Tongyi-DeepResearch-30B-A3B-4bit",
+                        help="Model path or HuggingFace model ID")
+    parser.add_argument("--dataset", type=str, required=True,
+                        help="Path to input dataset (JSON or JSONL)")
+    parser.add_argument("--output", type=str, default="./outputs",
+                        help="Output directory")
+    parser.add_argument("--temperature", type=float, default=0.85,
+                        help="Sampling temperature (0.0-2.0)")
+    parser.add_argument("--top_p", type=float, default=0.95,
+                        help="Top-p (nucleus) sampling (0.0-1.0)")
+    parser.add_argument("--max_tokens", type=int, default=8192,
+                        help="Maximum tokens per generation")
+    parser.add_argument("--roll_out_count", type=int, default=1,
+                        help="Number of rollouts per question")
+    args = parser.parse_args()
+
+    # Validate args
+    if not 0.0 <= args.temperature <= 2.0:
+        print("Warning: temperature should be between 0.0 and 2.0")
+    if not 0.0 <= args.top_p <= 1.0:
+        print("Warning: top_p should be between 0.0 and 1.0")
+
+    # Setup output directory
+    model_name = os.path.basename(args.model.rstrip('/'))
+    model_dir = os.path.join(args.output, f"{model_name}_mlx")
+    dataset_name = os.path.splitext(os.path.basename(args.dataset))[0]
+    output_dir = os.path.join(model_dir, dataset_name)
+    os.makedirs(output_dir, exist_ok=True)
+
+    print("=" * 60)
+    print("DeepResearch MLX Inference (Apple Silicon)")
+    print("=" * 60)
+    print(f"Model: {args.model}")
+    print(f"Dataset: {args.dataset}")
+    print(f"Output: {output_dir}")
+    print(f"Temperature: {args.temperature}")
+    print(f"Top-P: {args.top_p}")
+    print(f"Max Tokens: {args.max_tokens}")
+    print(f"Rollouts: {args.roll_out_count}")
+    print("=" * 60)
+
+    # Load dataset
+    try:
+        if args.dataset.endswith(".json"):
+            with open(args.dataset, "r", encoding="utf-8") as f:
+                items = json.load(f)
+            if isinstance(items, dict):
+                items = [items]
+        elif args.dataset.endswith(".jsonl"):
+            with open(args.dataset, "r", encoding="utf-8") as f:
+                items = [json.loads(line) for line in f if line.strip()]
+        else:
+            print("Error: 
Dataset must be .json or .jsonl") + return 1 + except FileNotFoundError: + print(f"Error: Dataset not found at {args.dataset}") + return 1 + except json.JSONDecodeError as e: + print(f"Error: Invalid JSON in dataset: {e}") + return 1 + + print(f"Loaded {len(items)} items from dataset") + + if not items: + print("Error: No items in dataset") + return 1 + + # Initialize agent + try: + agent = MLXReactAgent( + model_path=args.model, + temperature=args.temperature, + top_p=args.top_p, + max_tokens=args.max_tokens, + ) + except Exception as e: + print(f"Error loading model: {e}") + return 1 + + # Setup output files per rollout + output_files = { + i: os.path.join(output_dir, f"iter{i}.jsonl") + for i in range(1, args.roll_out_count + 1) + } + + # Load already processed questions + processed_per_rollout: Dict[int, set] = {} + for rollout_idx in range(1, args.roll_out_count + 1): + processed: set = set() + output_file = output_files[rollout_idx] + if os.path.exists(output_file): + with open(output_file, "r", encoding="utf-8") as f: + for line in f: + try: + data = json.loads(line) + if "question" in data and "error" not in data: + processed.add(data["question"].strip()) + except json.JSONDecodeError: + pass + processed_per_rollout[rollout_idx] = processed + if processed: + print(f"Rollout {rollout_idx}: {len(processed)} already processed") + + # Build task list + tasks = [] + for rollout_idx in range(1, args.roll_out_count + 1): + processed = processed_per_rollout[rollout_idx] + for item in items: + question = item.get("question", "").strip() + if not question: + try: + user_msg = item["messages"][1]["content"] + question = user_msg.split("User:")[1].strip() if "User:" in user_msg else user_msg + item["question"] = question + except Exception: + continue + + if question and question not in processed: + tasks.append({ + "item": item.copy(), + "rollout_idx": rollout_idx, + }) + + print(f"Tasks to run: {len(tasks)}") + + if not tasks: + print("All tasks already completed!") + return 0 + + # Run tasks + write_lock = threading.Lock() + completed = 0 + failed = 0 + + for task in tqdm(tasks, desc="Processing", disable=shutdown_requested): + if shutdown_requested: + print(f"\nStopped early. 
Completed: {completed}, Failed: {failed}") + break + + rollout_idx = task["rollout_idx"] + output_file = output_files[rollout_idx] + + try: + result = agent.run(task) + result["rollout_idx"] = rollout_idx + result["elapsed_time"] = time.time() + + with write_lock: + with open(output_file, "a", encoding="utf-8") as f: + f.write(json.dumps(result, ensure_ascii=False) + "\n") + + completed += 1 + + except Exception as e: + failed += 1 + print(f"\nError: {e}") + import traceback + traceback.print_exc() + + error_result = { + "question": task["item"].get("question", ""), + "answer": task["item"].get("answer", ""), + "rollout_idx": rollout_idx, + "error": str(e), + "messages": [], + "prediction": "[Failed]" + } + with write_lock: + with open(output_file, "a", encoding="utf-8") as f: + f.write(json.dumps(error_result, ensure_ascii=False) + "\n") + + print("\n" + "=" * 60) + print("Inference Complete") + print("=" * 60) + print(f"Completed: {completed}") + print(f"Failed: {failed}") + print(f"Output: {output_dir}") + print("=" * 60) + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/inference/test_mlx_connection.py b/inference/test_mlx_connection.py new file mode 100644 index 0000000..3064a7d --- /dev/null +++ b/inference/test_mlx_connection.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 +""" +Quick test script to verify MLX server connection and basic inference. +Run this after starting the MLX server to verify everything works. + +Usage: + # Terminal 1: Start MLX server + mlx_lm.server --model abalogh/Tongyi-DeepResearch-30B-A3B-4bit --port 8080 + + # Terminal 2: Run this test + python test_mlx_connection.py +""" + +import sys +from openai import OpenAI + +MLX_HOST = "127.0.0.1" +MLX_PORT = 8080 + + +def test_connection(): + """Test basic connection to MLX server.""" + print(f"Testing connection to MLX server at {MLX_HOST}:{MLX_PORT}...") + + client = OpenAI( + api_key="mlx-local", + base_url=f"http://{MLX_HOST}:{MLX_PORT}/v1", + timeout=60.0, + ) + + # Test 1: List models + print("\n1. Listing available models...") + try: + models = client.models.list() + available = [m.id for m in models.data] + print(f" Available models: {available}") + except Exception as e: + print(f" FAILED: {e}") + return False + + # Test 2: Simple completion + print("\n2. Testing simple completion...") + try: + response = client.chat.completions.create( + model=available[0] if available else "default", + messages=[ + {"role": "user", "content": "What is 2+2? Answer with just the number."} + ], + max_tokens=10, + temperature=0.1, + ) + answer = response.choices[0].message.content + print(f" Response: {answer}") + except Exception as e: + print(f" FAILED: {e}") + return False + + # Test 3: Test with system prompt (like DeepResearch uses) + print("\n3. Testing with system prompt...") + try: + response = client.chat.completions.create( + model=available[0] if available else "default", + messages=[ + {"role": "system", "content": "You are a helpful research assistant. Think step by step."}, + {"role": "user", "content": "What is the capital of Japan?"} + ], + max_tokens=100, + temperature=0.7, + ) + answer = response.choices[0].message.content or "" + print(f" Response: {answer[:200]}..." if len(answer) > 200 else f" Response: {answer}") + except Exception as e: + print(f" FAILED: {e}") + return False + + print("\n" + "=" * 50) + print("All tests passed! 
MLX server is working correctly.") + print("You can now run: bash inference/run_mlx_infer.sh") + print("=" * 50) + return True + + +if __name__ == "__main__": + success = test_connection() + sys.exit(0 if success else 1) diff --git a/inference/test_mlx_tool_loop.py b/inference/test_mlx_tool_loop.py new file mode 100644 index 0000000..e430d8e --- /dev/null +++ b/inference/test_mlx_tool_loop.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +""" +Diagnostic test for MLX tool response injection. + +This script tests the complete tool call loop: +1. Send a question to MLX +2. Model generates ... +3. We parse and execute the tool +4. We inject ... +5. Model continues with the tool response + +Usage: + python test_mlx_tool_loop.py +""" + +import os +import sys +import json +from datetime import datetime +from typing import Optional, Tuple, Dict, Any, List + +from dotenv import load_dotenv +load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env")) + +from openai import OpenAI +import json5 + +# Import tools +sys.path.insert(0, os.path.dirname(__file__)) +from tool_search import Search +from prompt import SYSTEM_PROMPT + +TOOL_MAP: Dict[str, Any] = {"search": Search()} + + +def today_date() -> str: + return datetime.now().strftime("%Y-%m-%d") + + +def parse_tool_call(content: str) -> Tuple[Optional[str], Optional[Dict[str, Any]]]: + """Extract tool name and arguments from model output.""" + if "" not in content or "" not in content: + return None, None + + tool_call_str = content.split("")[1].split("")[0].strip() + + try: + tool_call = json5.loads(tool_call_str) + name = tool_call.get("name") if isinstance(tool_call, dict) else None + args = tool_call.get("arguments", {}) if isinstance(tool_call, dict) else {} + return name, args + except Exception as e: + print(f"Failed to parse tool call JSON: {e}") + print(f"Raw tool call: {tool_call_str}") + return None, None + + +def execute_tool(name: str, args: Dict[str, Any]) -> str: + """Execute a tool and return the result.""" + if name not in TOOL_MAP: + return f"Error: Tool '{name}' not found. Available: {list(TOOL_MAP.keys())}" + + args["params"] = args + return TOOL_MAP[name].call(args) + + +def test_tool_loop(): + """Test the complete tool call loop.""" + print("=" * 60) + print("MLX Tool Response Injection Diagnostic Test") + print("=" * 60) + + # Connect to MLX server + client = OpenAI( + api_key="mlx-local", + base_url="http://127.0.0.1:8080/v1", + timeout=300.0, + ) + + # Verify connection + try: + models = client.models.list() + print(f"Connected to MLX server. Model: {models.data[0].id}") + except Exception as e: + print(f"ERROR: Cannot connect to MLX server: {e}") + print("Make sure the MLX server is running:") + print(" mlx_lm.server --model abalogh/Tongyi-DeepResearch-30B-A3B-4bit --port 8080") + return + + # Build messages + system_prompt = SYSTEM_PROMPT + str(today_date()) + question = "What are the latest developments in quantum computing in 2024?" 
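+    # Illustrative sketch of the turn protocol this test exercises (not executed).
+    # Tag names here are an assumption based on the Qwen/DeepResearch convention of
+    # <tool_call>/<tool_response>/<answer> markers used by the prompt in this PR:
+    #   assistant: ...<tool_call>{"name": "search", "arguments": {"query": ["..."]}}</tool_call>
+    #   user:      <tool_response>...search results...</tool_response>
+    #   assistant: ...<answer>final synthesized answer</answer>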
+ + messages: List[Dict[str, str]] = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question} + ] + + print(f"\nQuestion: {question}") + print("-" * 60) + + max_rounds = 5 + for round_num in range(1, max_rounds + 1): + print(f"\n--- Round {round_num} ---") + print(f"Messages count: {len(messages)}") + print(f"Last message role: {messages[-1]['role']}") + + # Call MLX + response = client.chat.completions.create( + model=models.data[0].id, + messages=messages, # type: ignore + stop=["\n", ""], + temperature=0.85, + top_p=0.95, + max_tokens=8192, + ) + + raw_content = response.choices[0].message.content + finish_reason = response.choices[0].finish_reason + content = raw_content.strip() if raw_content else "" + + print(f"\nFinish reason: {finish_reason}") + + # Clean up if leaked + if "" in content: + content = content[:content.find("")] + + print(f"\nModel output ({len(content)} chars):") + print("-" * 40) + print(content[:1000] + "..." if len(content) > 1000 else content) + print("-" * 40) + + # Add assistant message + messages.append({"role": "assistant", "content": content}) + + # Check for final answer + if "" in content and "" in content: + answer = content.split("")[1].split("")[0] + print(f"\nFINAL ANSWER: {answer}") + print(f"Total rounds: {round_num}") + break + + # Check for tool call + tool_name, tool_args = parse_tool_call(content) + + if tool_name and tool_args is not None: + print(f"\nTool call detected: {tool_name}") + print(f"Arguments: {json.dumps(tool_args, indent=2)}") + + # Execute tool + result = execute_tool(tool_name, tool_args) + print(f"\nTool result ({len(result)} chars):") + print(result[:500] + "..." if len(result) > 500 else result) + + # Inject tool response + tool_response = f"\n{result}\n" + messages.append({"role": "user", "content": tool_response}) + print(f"\nInjected tool_response as user message") + else: + print("\nNo tool call detected in output") + if round_num < max_rounds: + print("Model may be stuck - no tool call and no answer") + + # Print final message history + print("\n" + "=" * 60) + print("FULL MESSAGE HISTORY") + print("=" * 60) + for i, msg in enumerate(messages): + role = msg["role"] + content = msg["content"] + preview = content[:200] + "..." if len(content) > 200 else content + print(f"\n[{i}] {role.upper()}:") + print(preview) + + +if __name__ == "__main__": + test_tool_loop() diff --git a/inference/tool_search.py b/inference/tool_search.py index 1a3f7b5..cb3b364 100644 --- a/inference/tool_search.py +++ b/inference/tool_search.py @@ -1,131 +1,223 @@ +""" +Exa.ai Search Tool for DeepResearch +AI-native semantic search with neural embeddings for superior research results. + +Exa.ai advantages: +- Neural/semantic search (understands meaning, not just keywords) +- Can retrieve full page contents directly +- Better for research and complex queries +- Built-in query optimization (autoprompt) +- Supports date filtering and domain restrictions +- Category filtering (research papers, news, company info, etc.) 
+- AI-generated highlights for quick comprehension +""" + import json -from concurrent.futures import ThreadPoolExecutor -from typing import List, Union +import os +from typing import Any, Dict, Optional, Union import requests from qwen_agent.tools.base import BaseTool, register_tool -import asyncio -from typing import Dict, List, Optional, Union -import uuid -import http.client -import json - -import os +EXA_API_KEY = os.environ.get('EXA_API_KEY') +EXA_BASE_URL = "https://api.exa.ai" -SERPER_KEY=os.environ.get('SERPER_KEY_ID') +# Valid Exa categories for filtering results +VALID_CATEGORIES = [ + "company", "research paper", "news", "pdf", + "github", "tweet", "personal site", "linkedin profile" +] @register_tool("search", allow_overwrite=True) class Search(BaseTool): name = "search" - description = "Performs batched web searches: supply an array 'query'; the tool retrieves the top 10 results for each query in one call." + description = "Performs semantic web searches using Exa.ai: supply an array 'query'; retrieves top results with AI-powered understanding. Supports category filtering for research papers, news, etc." parameters = { "type": "object", "properties": { "query": { "type": "array", - "items": { - "type": "string" - }, - "description": "Array of query strings. Include multiple complementary search queries in a single call." + "items": {"type": "string"}, + "description": "Array of query strings. Exa understands natural language queries well." + }, + "num_results": { + "type": "integer", + "description": "Number of results per query (default: 10, max: 100)", + "default": 10 }, + "include_contents": { + "type": "boolean", + "description": "Whether to include page text content and highlights", + "default": False + }, + "category": { + "type": "string", + "description": "Filter by category: 'research paper', 'news', 'company', 'pdf', 'github', 'tweet', 'personal site', 'linkedin profile'", + "enum": ["company", "research paper", "news", "pdf", "github", "tweet", "personal site", "linkedin profile"] + } }, "required": ["query"], } def __init__(self, cfg: Optional[dict] = None): super().__init__(cfg) - def google_search_with_serp(self, query: str): - def contains_chinese_basic(text: str) -> bool: - return any('\u4E00' <= char <= '\u9FFF' for char in text) - conn = http.client.HTTPSConnection("google.serper.dev") - if contains_chinese_basic(query): - payload = json.dumps({ - "q": query, - "location": "China", - "gl": "cn", - "hl": "zh-cn" - }) - - else: - payload = json.dumps({ - "q": query, - "location": "United States", - "gl": "us", - "hl": "en" - }) + self.api_key = EXA_API_KEY + if not self.api_key: + raise ValueError("EXA_API_KEY environment variable not set. Get your key from https://exa.ai/") + + def exa_search( + self, + query: str, + num_results: int = 10, + include_contents: bool = False, + category: Optional[str] = None + ) -> str: + """ + Perform a search using Exa.ai API. 
+ + Exa supports multiple search types: + - "auto": Intelligently combines neural and other methods (default) + - "neural": AI-powered semantic search + - "deep": Comprehensive search with query expansion + + Categories available: + - "research paper": Academic papers and publications + - "news": News articles + - "company": Company websites and info + - "pdf": PDF documents + - "github": GitHub repositories + - "tweet": Twitter/X posts + - "personal site": Personal websites/blogs + - "linkedin profile": LinkedIn profiles + """ headers = { - 'X-API-KEY': SERPER_KEY, - 'Content-Type': 'application/json' - } + "Content-Type": "application/json", + "x-api-key": self.api_key + } + + payload: Dict[str, Any] = { + "query": query, + "numResults": num_results, + "type": "auto", + "useAutoprompt": True, + } + + # Add category filter if specified + if category and category in VALID_CATEGORIES: + payload["category"] = category + if include_contents: + payload["contents"] = { + "text": {"maxCharacters": 2000}, + "highlights": True + } - for i in range(5): + response = None + for attempt in range(3): try: - conn.request("POST", "/search", payload, headers) - res = conn.getresponse() + response = requests.post( + f"{EXA_BASE_URL}/search", + headers=headers, + json=payload, + timeout=30 + ) + response.raise_for_status() break - except Exception as e: - print(e) - if i == 4: - return f"Google search Timeout, return None, Please try again later." + except requests.exceptions.HTTPError as e: + if response is not None and response.status_code == 429: + return f"Exa search rate limited. Please wait and try again." + if response is not None and response.status_code == 401: + return f"Exa API key invalid. Check your EXA_API_KEY environment variable." + if attempt == 2: + return f"Exa search failed after 3 attempts: {str(e)}" + except requests.exceptions.RequestException as e: + if attempt == 2: + return f"Exa search failed after 3 attempts: {str(e)}" continue - - data = res.read() - results = json.loads(data.decode("utf-8")) - - try: - if "organic" not in results: - raise Exception(f"No results found for query: '{query}'. Use a less specific query.") - - web_snippets = list() - idx = 0 - if "organic" in results: - for page in results["organic"]: - idx += 1 - date_published = "" - if "date" in page: - date_published = "\nDate published: " + page["date"] - - source = "" - if "source" in page: - source = "\nSource: " + page["source"] - - snippet = "" - if "snippet" in page: - snippet = "\n" + page["snippet"] - - redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}" - redacted_version = redacted_version.replace("Your browser can't play this video.", "") - web_snippets.append(redacted_version) - - content = f"A Google search for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n" + "\n\n".join(web_snippets) - return content - except: - return f"No results found for '{query}'. Try with a more general query." 
- - - - def search_with_serp(self, query: str): - result = self.google_search_with_serp(query) - return result - - def call(self, params: Union[str, dict], **kwargs) -> str: - try: - query = params["query"] - except: - return "[Search] Invalid request format: Input must be a JSON object containing 'query' field" - if isinstance(query, str): - # 单个查询 - response = self.search_with_serp(query) + if response is None: + return "Exa search failed: no response received" + + results = response.json() + + if "results" not in results or not results["results"]: + return f"No results found for '{query}'. Try a different query." + + web_snippets = [] + for idx, result in enumerate(results["results"], 1): + title = result.get("title", "No title") + url = result.get("url", "") + published_date = result.get("publishedDate", "") + author = result.get("author", "") + + snippet_parts = [f"{idx}. [{title}]({url})"] + + if author: + snippet_parts.append(f"Author: {author}") + if published_date: + snippet_parts.append(f"Date: {published_date[:10]}") + + # Prefer highlights (AI-generated key points), then text, then snippet + if include_contents: + highlights = result.get("highlights", []) + if highlights: + snippet_parts.append("\nKey points:") + for h in highlights[:3]: + snippet_parts.append(f" • {h}") + elif "text" in result: + text = result["text"][:500] + snippet_parts.append(f"\n{text}...") + elif "snippet" in result: + snippet_parts.append(f"\n{result['snippet']}") + + web_snippets.append("\n".join(snippet_parts)) + + search_type = results.get("resolvedSearchType", "neural") + category_info = f" (category: {category})" if category else "" + content = f"Exa {search_type} search{category_info} for '{query}' found {len(web_snippets)} results:\n\n## Web Results\n\n" + "\n\n".join(web_snippets) + return content + + def call(self, params: Union[str, dict], **kwargs: Any) -> str: + params_dict: Dict[str, Any] + if isinstance(params, str): + try: + params_dict = json.loads(params) + except json.JSONDecodeError: + return "[Search] Invalid JSON input" else: - # 多个查询 - assert isinstance(query, List) + params_dict = dict(params) + + query = params_dict.get("query") + if not query: + return "[Search] Invalid request: 'query' field is required" + + raw_num = params_dict.get("num_results", 10) + num_results = int(raw_num) if raw_num is not None else 10 + include_contents = bool(params_dict.get("include_contents", False)) + category = params_dict.get("category") + + # Validate category if provided + if category and category not in VALID_CATEGORIES: + category = None + + if isinstance(query, str): + return self.exa_search(query, num_results, include_contents, category) + + if isinstance(query, list): responses = [] for q in query: - responses.append(self.search_with_serp(q)) - response = "\n=======\n".join(responses) - - return response + responses.append(self.exa_search(q, num_results, include_contents, category)) + return "\n=======\n".join(responses) + + return "[Search] Invalid query format: must be string or array of strings" + +if __name__ == "__main__": + from dotenv import load_dotenv + + env_path = os.path.join(os.path.dirname(__file__), "..", ".env") + load_dotenv(env_path) + + searcher = Search() + result = searcher.call({"query": ["What is retrieval augmented generation?"]}) + print(result) diff --git a/inference/tool_visit.py b/inference/tool_visit.py index 92e4e3a..7892244 100644 --- a/inference/tool_visit.py +++ b/inference/tool_visit.py @@ -1,256 +1,276 @@ import json import os -import signal -import 
threading -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import List, Union +import re +from typing import List, Union, Optional import requests from qwen_agent.tools.base import BaseTool, register_tool from prompt import EXTRACTOR_PROMPT from openai import OpenAI -import random -from urllib.parse import urlparse, unquote +from urllib.parse import urlparse import time -from transformers import AutoTokenizer import tiktoken VISIT_SERVER_TIMEOUT = int(os.getenv("VISIT_SERVER_TIMEOUT", 200)) WEBCONTENT_MAXLENGTH = int(os.getenv("WEBCONTENT_MAXLENGTH", 150000)) - JINA_API_KEYS = os.getenv("JINA_API_KEYS", "") +# Maximum content length to return when summarization fails +RAW_CONTENT_MAX_CHARS = int(os.getenv("RAW_CONTENT_MAX_CHARS", 8000)) + -@staticmethod def truncate_to_tokens(text: str, max_tokens: int = 95000) -> str: - encoding = tiktoken.get_encoding("cl100k_base") + """Truncate text to a maximum number of tokens.""" + try: + encoding = tiktoken.get_encoding("cl100k_base") + tokens = encoding.encode(text) + if len(tokens) <= max_tokens: + return text + return encoding.decode(tokens[:max_tokens]) + except Exception: + # Fallback: rough char estimate (4 chars per token) + max_chars = max_tokens * 4 + return text[:max_chars] if len(text) > max_chars else text + + +def extract_main_content(html_text: str, max_chars: int = 8000) -> str: + """ + Extract main content from HTML/markdown text. + Removes boilerplate like navigation, footers, etc. + """ + lines = html_text.split('\n') + content_lines = [] + total_chars = 0 - tokens = encoding.encode(text) - if len(tokens) <= max_tokens: - return text + skip_patterns = [ + r'^#{1,2}\s*(navigation|menu|footer|sidebar|cookie|privacy|terms)', + r'^\s*(\||---)', # Table separators + r'^\s*\[.*\]\(.*\)\s*$', # Standalone links + r'^(copyright|©|\d{4}\s*[-–]\s*\d{4})', + ] - truncated_tokens = tokens[:max_tokens] - return encoding.decode(truncated_tokens) - -OSS_JSON_FORMAT = """# Response Formats -## visit_content -{"properties":{"rational":{"type":"string","description":"Locate the **specific sections/data** directly related to the user's goal within the webpage content"},"evidence":{"type":"string","description":"Identify and extract the **most relevant information** from the content, never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs.","summary":{"type":"string","description":"Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal."}}}}""" + for line in lines: + line_lower = line.lower().strip() + + # Skip empty lines at start + if not content_lines and not line.strip(): + continue + + # Skip navigation/boilerplate patterns + skip = False + for pattern in skip_patterns: + if re.match(pattern, line_lower): + skip = True + break + + if skip: + continue + + content_lines.append(line) + total_chars += len(line) + 1 + + if total_chars >= max_chars: + break + + return '\n'.join(content_lines) @register_tool('visit', allow_overwrite=True) class Visit(BaseTool): - # The `description` tells the agent the functionality of this tool. name = 'visit' description = 'Visit webpage(s) and return the summary of the content.' - # The `parameters` tell the agent what input parameters the tool has. 
parameters = { "type": "object", "properties": { "url": { "type": ["string", "array"], - "items": { - "type": "string" - }, + "items": {"type": "string"}, "minItems": 1, - "description": "The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs." - }, - "goal": { + "description": "The URL(s) of the webpage(s) to visit." + }, + "goal": { "type": "string", "description": "The goal of the visit for webpage(s)." - } + } }, "required": ["url", "goal"] } - # The `call` method is the main function of the tool. + + def _validate_url(self, url: str) -> bool: + """Check if URL is valid and has a proper scheme.""" + try: + parsed = urlparse(url) + return parsed.scheme in ('http', 'https') and bool(parsed.netloc) + except Exception: + return False + def call(self, params: Union[str, dict], **kwargs) -> str: try: url = params["url"] goal = params["goal"] - except: - return "[Visit] Invalid request format: Input must be a JSON object containing 'url' and 'goal' fields" - - start_time = time.time() - - # Create log folder if it doesn't exist - log_folder = "log" - os.makedirs(log_folder, exist_ok=True) + except Exception: + return "[Visit] Invalid request: need 'url' and 'goal' fields" if isinstance(url, str): - response = self.readpage_jina(url, goal) - else: - response = [] - assert isinstance(url, List) - start_time = time.time() - for u in url: - if time.time() - start_time > 900: - cur_response = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal) - cur_response += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n" - cur_response += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." + "\n\n" - else: - try: - cur_response = self.readpage_jina(u, goal) - except Exception as e: - cur_response = f"Error fetching {u}: {str(e)}" - response.append(cur_response) - response = "\n=======\n".join(response) - - print(f'Summary Length {len(response)}; Summary Content {response}') - return response.strip() + if not self._validate_url(url): + return f"[Visit] Invalid URL: {url}. URL must start with http:// or https://" + return self.readpage(url, goal) - def call_server(self, msgs, max_retries=2): - api_key = os.environ.get("API_KEY") - url_llm = os.environ.get("API_BASE") - model_name = os.environ.get("SUMMARY_MODEL_NAME", "") - client = OpenAI( - api_key=api_key, - base_url=url_llm, - ) - for attempt in range(max_retries): + # Multiple URLs + responses = [] + start = time.time() + for u in url: + if time.time() - start > 300: # 5 min timeout for batch + responses.append(f"[Timeout] Skipped: {u}") + continue + if not self._validate_url(u): + responses.append(f"[Visit] Invalid URL: {u}") + continue try: - chat_response = client.chat.completions.create( - model=model_name, - messages=msgs, - temperature=0.7 - ) - content = chat_response.choices[0].message.content - if content: - try: - json.loads(content) - except: - # extract json from string - left = content.find('{') - right = content.rfind('}') - if left != -1 and right != -1 and left <= right: - content = content[left:right+1] - return content + responses.append(self.readpage(u, goal)) except Exception as e: - # print(e) - if attempt == (max_retries - 1): - return "" - continue - - - def jina_readpage(self, url: str) -> str: - """ - Read webpage content using Jina service. 
+                responses.append(f"[Error] {u}: {e}")
-        Args:
-            url: The URL to read
-            goal: The goal/purpose of reading the page
-
-        Returns:
-            str: The webpage content or error message
-        """
-        max_retries = 3
-        timeout = 50
+        return "\n\n---\n\n".join(responses)
+
+    def jina_fetch(self, url: str, timeout: int = 30) -> Optional[str]:
+        """Fetch webpage content using Jina Reader API."""
+        headers = {}
+        if JINA_API_KEYS:
+            headers["Authorization"] = f"Bearer {JINA_API_KEYS}"
-        for attempt in range(max_retries):
-            headers = {
-                "Authorization": f"Bearer {JINA_API_KEYS}",
-            }
+        for attempt in range(3):
             try:
-                response = requests.get(
+                resp = requests.get(
                     f"https://r.jina.ai/{url}",
                     headers=headers,
                     timeout=timeout
                 )
-                if response.status_code == 200:
-                    webpage_content = response.text
-                    return webpage_content
-                else:
-                    print(response.text)
-                    raise ValueError("jina readpage error")
-            except Exception as e:
-                time.sleep(0.5)
-                if attempt == max_retries - 1:
-                    return "[visit] Failed to read page."
-
-        return "[visit] Failed to read page."
+                if resp.status_code == 200 and len(resp.text) > 100:
+                    return resp.text
+            except requests.RequestException:
+                pass
+            time.sleep(0.5)
+
+        return None
-    def html_readpage_jina(self, url: str) -> str:
-        max_attempts = 8
-        for attempt in range(max_attempts):
-            content = self.jina_readpage(url)
-            service = "jina"
-            print(service)
-            if content and not content.startswith("[visit] Failed to read page.") and content != "[visit] Empty content." and not content.startswith("[document_parser]"):
-                return content
-        return "[visit] Failed to read page."
+    def direct_fetch(self, url: str, timeout: int = 20) -> Optional[str]:
+        """Fallback: fetch directly with requests."""
+        headers = {
+            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36"
+        }
+        try:
+            resp = requests.get(url, headers=headers, timeout=timeout)
+            resp.raise_for_status()
+
+            # Basic HTML to text conversion
+            text = resp.text
+            # Remove script/style tags
+            text = re.sub(r'<script[^>]*>.*?</script>', '', text, flags=re.DOTALL | re.IGNORECASE)
+            text = re.sub(r'<style[^>]*>.*?</style>', '', text, flags=re.DOTALL | re.IGNORECASE)
+            # Remove HTML tags
+            text = re.sub(r'<[^>]+>', ' ', text)
+            # Clean whitespace
+            text = re.sub(r'\s+', ' ', text).strip()
+
+            return text if len(text) > 100 else None
+        except Exception:
+            return None
-    def readpage_jina(self, url: str, goal: str) -> str:
-        """
-        Attempt to read webpage content by alternating between jina and aidata services.
+    def summarize_content(self, content: str, goal: str) -> Optional[dict]:
+        """Use LLM API to summarize content. Returns None if unavailable."""
+        api_key = os.environ.get("API_KEY")
+        api_base = os.environ.get("API_BASE")
+        model = os.environ.get("SUMMARY_MODEL_NAME", "")
-        Args:
-            url: The URL to read
-            goal: The goal/purpose of reading the page
+        if not api_key or not api_base:
+            return None
+
+        try:
+            client = OpenAI(api_key=api_key, base_url=api_base)
+
+            # Truncate content for summarization
+            content = truncate_to_tokens(content, max_tokens=30000)
+
+            prompt = EXTRACTOR_PROMPT.format(webpage_content=content, goal=goal)
+
+            resp = client.chat.completions.create(
+                model=model,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=0.7,
+                max_tokens=2000
+            )
+
+            result = resp.choices[0].message.content
+            if not result:
+                return None
+
+            # Parse JSON response
+            result = result.replace("```json", "").replace("```", "").strip()
+
+            # Try to extract JSON
+            left = result.find('{')
+            right = result.rfind('}')
+            if left != -1 and right > left:
+                result = result[left:right+1]
-        Returns:
-            str: The webpage content or error message
+            return json.loads(result)
+        except Exception as e:
+            print(f"[visit] Summarization failed: {e}")
+            return None
+
+    def readpage(self, url: str, goal: str) -> str:
+        """
+        Read and process a webpage.
+
+        Strategy:
+        1. Try Jina Reader first (best for complex pages)
+        2. Fallback to direct fetch if Jina fails
+        3. Try LLM summarization if API available
+        4. Return extracted raw content if summarization unavailable
         """
-
-        summary_page_func = self.call_server
-        max_retries = int(os.getenv('VISIT_SERVER_MAX_RETRIES', 1))
+        # Step 1: Fetch content
+        content = self.jina_fetch(url)
+
+        if not content:
+            content = self.direct_fetch(url)
+
+        if not content:
+            return self._format_error(url, goal, "Failed to fetch webpage content")
+
+        # Step 2: Try summarization
+        summary = self.summarize_content(content, goal)
+
+        if summary and summary.get("evidence") and summary.get("summary"):
+            return self._format_success(url, goal, summary["evidence"], summary["summary"])
+
+        # Step 3: Fallback - return extracted raw content
+        extracted = extract_main_content(content, RAW_CONTENT_MAX_CHARS)
+
+        if len(extracted) < 100:
+            return self._format_error(url, goal, "Page content too short or empty")
+
+        return self._format_raw(url, goal, extracted)
-        content = self.html_readpage_jina(url)
+    def _format_success(self, url: str, goal: str, evidence: str, summary: str) -> str:
+        return f"""Content from {url} for goal: {goal}
-        if content and not content.startswith("[visit] Failed to read page.") and content != "[visit] Empty content." and not content.startswith("[document_parser]"):
-            content = truncate_to_tokens(content, max_tokens=95000)
-            messages = [{"role":"user","content": EXTRACTOR_PROMPT.format(webpage_content=content, goal=goal)}]
-            parse_retry_times = 0
-            raw = summary_page_func(messages, max_retries=max_retries)
-            summary_retries = 3
-            while len(raw) < 10 and summary_retries >= 0:
-                truncate_length = int(0.7 * len(content)) if summary_retries > 0 else 25000
-                status_msg = (
-                    f"[visit] Summary url[{url}] "
-                    f"attempt {3 - summary_retries + 1}/3, "
-                    f"content length: {len(content)}, "
-                    f"truncating to {truncate_length} chars"
-                ) if summary_retries > 0 else (
-                    f"[visit] Summary url[{url}] failed after 3 attempts, "
-                    f"final truncation to 25000 chars"
-                )
-                print(status_msg)
-                content = content[:truncate_length]
-                extraction_prompt = EXTRACTOR_PROMPT.format(
-                    webpage_content=content,
-                    goal=goal
-                )
-                messages = [{"role": "user", "content": extraction_prompt}]
-                raw = summary_page_func(messages, max_retries=max_retries)
-                summary_retries -= 1
+**Evidence:**
+{evidence}
-            parse_retry_times = 2
-            if isinstance(raw, str):
-                raw = raw.replace("```json", "").replace("```", "").strip()
-            while parse_retry_times < 3:
-                try:
-                    raw = json.loads(raw)
-                    break
-                except:
-                    raw = summary_page_func(messages, max_retries=max_retries)
-                    parse_retry_times += 1
-
-            if parse_retry_times >= 3:
-                useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal)
-                useful_information += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n"
-                useful_information += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." + "\n\n"
-            else:
-                useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal)
-                useful_information += "Evidence in page: \n" + str(raw["evidence"]) + "\n\n"
-                useful_information += "Summary: \n" + str(raw["summary"]) + "\n\n"
+**Summary:**
+{summary}"""
-
-            if len(useful_information) < 10 and summary_retries < 0:
-                print("[visit] Could not generate valid summary after maximum retries")
-                useful_information = "[visit] Failed to read page"
-
-            return useful_information
+    def _format_raw(self, url: str, goal: str, content: str) -> str:
+        return f"""Content from {url} for goal: {goal}
+
+**Raw Content (summarization unavailable):**
+{content}
+
+Note: Please extract the relevant information for your goal from the content above."""
-        # If no valid content was obtained after all retries
-        else:
-            useful_information = "The useful information in {url} for user goal {goal} as follows: \n\n".format(url=url, goal=goal)
-            useful_information += "Evidence in page: \n" + "The provided webpage content could not be accessed. Please check the URL or file format." + "\n\n"
-            useful_information += "Summary: \n" + "The webpage content could not be processed, and therefore, no information is available." + "\n\n"
-            return useful_information
+    def _format_error(self, url: str, goal: str, reason: str) -> str:
+        return f"""Could not retrieve content from {url}
+Goal: {goal}
+Reason: {reason}
-    
\ No newline at end of file
+Please try a different source or search query."""
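+
+
+# Illustrative usage sketch: the tool is registered with qwen_agent via
+# @register_tool and is normally invoked by the agent loop, but it can be
+# exercised directly for a quick smoke test. Assumes JINA_API_KEYS (and,
+# for the summarization step, API_KEY / API_BASE / SUMMARY_MODEL_NAME) are
+# set in .env; without the summary API the raw-content fallback is returned.
+#
+#   if __name__ == "__main__":
+#       tool = Visit()
+#       print(tool.call({
+#           "url": "https://example.com",
+#           "goal": "What is this page about?",
+#       }))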