diff --git a/.gitignore b/.gitignore
index 47d718f..9ea8892 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,7 @@ doc
 .rust*
 debug/
 storage/
-.roo/
\ No newline at end of file
+.roo/
+
+# claude
+.mcp.json
\ No newline at end of file
diff --git a/README.md b/README.md
index c1533b7..b17456e 100644
--- a/README.md
+++ b/README.md
@@ -111,26 +111,31 @@ necessary for crates that require specific features to be enabled for
 `cargo doc` to succeed (e.g., crates requiring a runtime feature like
 `async-stripe`).
 
-```bash
-# Set the API key (replace with your actual key)
-export OPENAI_API_KEY="sk-..."
-
-# Example: Run server for the latest 1.x version of serde
-rustdocs_mcp_server "serde@^1.0"
-# Example: Run server for a specific version of reqwest
-rustdocs_mcp_server "reqwest@0.12.0"
+#### Multi-Crate Mode
 
-# Example: Run server for the latest version of tokio
-rustdocs_mcp_server tokio
+You can also run a single server instance that provides documentation for multiple crates by listing them as arguments:
 
-# Example: Run server for async-stripe, enabling a required runtime feature
-rustdocs_mcp_server "async-stripe@0.40" -F runtime-tokio-hyper-rustls
+```bash
+# Set the API key (replace with your actual key)
+export OPENAI_API_KEY="sk-..."
 
-# Example: Run server for another crate with multiple features
-rustdocs_mcp_server "some-crate@1.2" --features feat1,feat2
+# Example: Run server for multiple crates
+rust-mcp-docs rmcp serde chrono:serde tokio openai tera dotenvy clap:derive:env
+
+# This creates tools for each crate:
+# - query_rmcp_docs
+# - query_serde_docs
+# - query_chrono_docs
+# - query_tokio_docs
+# - query_openai_docs
+# - query_tera_docs
+# - query_dotenvy_docs
+# - query_clap_docs
 ```
 
+In multi-crate mode, crate specifications can include features after a colon (e.g., `chrono:serde`, `clap:derive:env`). Each crate will have its own documentation tool named `query_{crate_name}_docs`.
+
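+For example, a hypothetical invocation that pins versions and enables features in one command (crate names, versions, and features below are illustrative):
+
+```bash
+rust-mcp-docs "serde@^1.0:derive" "reqwest@0.12:json"
+# Exposes query_serde_docs and query_reqwest_docs
+```
+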
 On the first run for a specific crate version _and feature set_, the server will:
 
diff --git a/src/doc_loader.rs b/src/doc_loader.rs
index 9ab5ede..2f8ba44 100644
--- a/src/doc_loader.rs
+++ b/src/doc_loader.rs
@@ -126,8 +126,7 @@ edition = "2021"
     // Iterate through subdirectories in `target/doc` and find the one containing `index.html`.
     let base_doc_path = temp_dir_path.join("doc");
-    let mut target_docs_path: Option<PathBuf> = None;
-    let mut found_count = 0;
+    let mut found_paths: Vec<PathBuf> = Vec::new();
 
     if base_doc_path.is_dir() {
         for entry_result in fs::read_dir(&base_doc_path)? {
@@ -136,31 +135,58 @@ edition = "2021"
             let dir_path = entry.path();
             let index_html_path = dir_path.join("index.html");
             if index_html_path.is_file() {
-                if target_docs_path.is_none() {
-                    target_docs_path = Some(dir_path);
-                }
-                found_count += 1;
-            } else {
+                found_paths.push(dir_path);
             }
         }
+        eprintln!("[DEBUG] Found {} directories with index.html files", found_paths.len());
+        for (i, path) in found_paths.iter().enumerate() {
+            eprintln!("[DEBUG] Directory {}: {}", i + 1, path.display());
+        }
     }
 
-    let docs_path = match (found_count, target_docs_path) {
-        (1, Some(path)) => {
-            path
-        },
-        (0, _) => {
+    let docs_path = match found_paths.len() {
+        0 => {
             return Err(DocLoaderError::CargoLib(anyhow::anyhow!(
                 "Could not find any subdirectory containing index.html within '{}'. Cargo doc might have failed or produced unexpected output.",
                 base_doc_path.display()
             )));
         },
-        (count, _) => {
-            return Err(DocLoaderError::CargoLib(anyhow::anyhow!(
-                "Expected exactly one subdirectory containing index.html within '{}', but found {}. Cannot determine the correct documentation path.",
-                base_doc_path.display(), count
-            )));
+        1 => {
+            found_paths.into_iter().next().unwrap()
+        },
+        _ => {
+            // Multiple directories found - look specifically for the crate name directory
+            let crate_name_normalized = crate_name.replace('-', "_");
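+            // e.g. a crate published as "async-stripe" is documented under the
+            // directory "async_stripe", since rustdoc replaces hyphens with underscores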
+            eprintln!("[DEBUG] Multiple directories found, looking for crate directory: {}", crate_name_normalized);
+
+            // Find the directory that matches the crate name
+            let matching_dir = found_paths.iter().find(|path| {
+                if let Some(dir_name) = path.file_name() {
+                    if let Some(dir_name_str) = dir_name.to_str() {
+                        return dir_name_str == crate_name_normalized;
+                    }
+                }
+                false
+            });
+
+            match matching_dir {
+                Some(crate_dir) => {
+                    eprintln!("[DEBUG] Found crate-specific directory: {}", crate_dir.display());
+                    crate_dir.clone()
+                },
+                None => {
+                    eprintln!("[DEBUG] Crate-specific directory '{}' not found in available directories:", crate_name_normalized);
+                    for path in &found_paths {
+                        if let Some(dir_name) = path.file_name() {
+                            eprintln!("[DEBUG]   - {}", dir_name.to_string_lossy());
+                        }
+                    }
+                    // Fallback to the first directory found
+                    eprintln!("[DEBUG] Using first available directory: {}", found_paths[0].display());
+                    found_paths.into_iter().next().unwrap()
+                }
+            }
+        }
     };
     // --- End finding documentation directory ---
diff --git a/src/main.rs b/src/main.rs
index cfd2cf1..b22178f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -9,7 +9,7 @@ use crate::{
     doc_loader::Document,
     embeddings::{generate_embeddings, CachedDocumentEmbedding, OPENAI_CLIENT},
     error::ServerError,
-    server::RustDocsServer, // Import the updated RustDocsServer
+    server::{RustDocsServer, CrateData}, // Import the updated RustDocsServer and CrateData
 };
 use async_openai::{Client as OpenAIClient, config::OpenAIConfig};
 use bincode::config;
@@ -22,7 +22,7 @@ use rmcp::{
     ServiceExt, // Import the ServiceExt trait for .serve() and .waiting()
 };
 use std::{
-    collections::hash_map::DefaultHasher,
+    collections::{hash_map::DefaultHasher, HashMap},
     env,
     fs::{self, File},
     hash::{Hash, Hasher}, // Import hashing utilities
@@ -37,13 +37,52 @@ use xdg::BaseDirectories;
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Cli {
-    /// The package ID specification (e.g., "serde@^1.0", "tokio").
-    #[arg()] // Positional argument
-    package_spec: String,
+    /// The package ID specifications with optional features (e.g., "serde@^1.0:feature1:feature2", "tokio", "reqwest@0.12:json").
+    #[arg(required = true)] // Positional arguments, at least one required
+    package_specs: Vec<String>,
+}
 
-    /// Optional features to enable for the crate when generating documentation.
-    #[arg(short = 'F', long, value_delimiter = ',', num_args = 0..)] // Allow multiple comma-separated values
+#[derive(Debug, Clone)]
+struct CrateSpec {
+    name: String,
+    version_req: String,
     features: Option<Vec<String>>,
+    original_spec: String,
+}
+
+// Helper function to parse crate specification with features
+fn parse_crate_spec(spec: &str) -> Result<CrateSpec, String> {
+    // Split by ':' to separate crate spec from features
+    let parts: Vec<&str> = spec.split(':').collect();
+
+    if parts.is_empty() {
+        return Err("Empty crate specification".to_string());
+    }
+
+    let crate_part = parts[0];
+    let features = if parts.len() > 1 {
+        Some(parts[1..].iter().map(|s| s.to_string()).collect())
+    } else {
+        None
+    };
+
+    // Parse the crate part using PackageIdSpec
+    let package_spec = PackageIdSpec::parse(crate_part).map_err(|e| {
+        format!("Failed to parse package ID spec '{}': {}", crate_part, e)
+    })?;
+
+    let name = package_spec.name().to_string();
+    let version_req = package_spec
+        .version()
+        .map(|v| v.to_string())
+        .unwrap_or_else(|| "*".to_string());
+
+    Ok(CrateSpec {
+        name,
+        version_req,
+        features,
+        original_spec: spec.to_string(),
+    })
 }
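+
+// Illustrative parses:
+//   "tokio"             -> name: "tokio", version_req: "*",    features: None
+//   "serde@^1.0:derive" -> name: "serde", version_req: "^1.0", features: Some(["derive"])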
 
 // Helper function to create a stable hash from features
@@ -67,119 +106,23 @@ async fn main() -> Result<(), ServerError> {
     // --- Parse CLI Arguments ---
     let cli = Cli::parse();
-    let specid_str = cli.package_spec.trim().to_string(); // Trim whitespace
-    let features = cli.features.map(|f| {
-        f.into_iter().map(|s| s.trim().to_string()).collect() // Trim each feature
-    });
-
-    // Parse the specid string
-    let spec = PackageIdSpec::parse(&specid_str).map_err(|e| {
-        ServerError::Config(format!(
-            "Failed to parse package ID spec '{}': {}",
-            specid_str, e
-        ))
-    })?;
-
-    let crate_name = spec.name().to_string();
-    let crate_version_req = spec
-        .version()
-        .map(|v| v.to_string())
-        .unwrap_or_else(|| "*".to_string());
-
-    eprintln!(
-        "Target Spec: {}, Parsed Name: {}, Version Req: {}, Features: {:?}",
-        specid_str, crate_name, crate_version_req, features
-    );
-
-    // --- Determine Paths (incorporating features) ---
-
-    // Sanitize the version requirement string
-    let sanitized_version_req = crate_version_req
-        .replace(|c: char| !c.is_alphanumeric() && c != '.' && c != '-', "_");
-
-    // Generate a stable hash for the features to use in the path
-    let features_hash = hash_features(&features);
-
-    // Construct the relative path component including features hash
-    let embeddings_relative_path = PathBuf::from(&crate_name)
-        .join(&sanitized_version_req)
-        .join(&features_hash) // Add features hash as a directory level
-        .join("embeddings.bin");
-
-    #[cfg(not(target_os = "windows"))]
-    let embeddings_file_path = {
-        let xdg_dirs = BaseDirectories::with_prefix("rustdocs-mcp-server")
-            .map_err(|e| ServerError::Xdg(format!("Failed to get XDG directories: {}", e)))?;
-        xdg_dirs
-            .place_data_file(embeddings_relative_path)
-            .map_err(ServerError::Io)?
-    };
-
-    #[cfg(target_os = "windows")]
-    let embeddings_file_path = {
-        let cache_dir = dirs::cache_dir().ok_or_else(|| {
-            ServerError::Config("Could not determine cache directory on Windows".to_string())
+
+    // Parse all package specs with features
+    let mut parsed_crates = Vec::new();
+    for spec_str in &cli.package_specs {
+        let crate_spec = parse_crate_spec(spec_str.trim()).map_err(|e| {
+            ServerError::Config(e)
         })?;
-        let app_cache_dir = cache_dir.join("rustdocs-mcp-server");
-        // Ensure the base app cache directory exists
-        fs::create_dir_all(&app_cache_dir).map_err(ServerError::Io)?;
-        app_cache_dir.join(embeddings_relative_path)
-    };
-
-    eprintln!("Cache file path: {:?}", embeddings_file_path);
-
-    // --- Try Loading Embeddings and Documents from Cache ---
-    let mut loaded_from_cache = false;
-    let mut loaded_embeddings: Option<Vec<(String, Array1<f32>)>> = None;
-    let mut loaded_documents_from_cache: Option<Vec<Document>> = None;
-
-    if embeddings_file_path.exists() {
-        eprintln!(
-            "Attempting to load cached data from: {:?}",
-            embeddings_file_path
-        );
-        match File::open(&embeddings_file_path) {
-            Ok(file) => {
-                let reader = BufReader::new(file);
-                match bincode::decode_from_reader::<Vec<CachedDocumentEmbedding>, _, _>(
-                    reader,
-                    config::standard(),
-                ) {
-                    Ok(cached_data) => {
-                        eprintln!(
-                            "Successfully loaded {} items from cache. Separating data...",
-                            cached_data.len()
-                        );
-                        let mut embeddings = Vec::with_capacity(cached_data.len());
-                        let mut documents = Vec::with_capacity(cached_data.len());
-                        for item in cached_data {
-                            embeddings.push((item.path.clone(), Array1::from(item.vector)));
-                            documents.push(Document {
-                                path: item.path,
-                                content: item.content,
-                            });
-                        }
-                        loaded_embeddings = Some(embeddings);
-                        loaded_documents_from_cache = Some(documents);
-                        loaded_from_cache = true;
-                    }
-                    Err(e) => {
-                        eprintln!("Failed to decode cache file: {}. Will regenerate.", e);
-                    }
-                }
-            }
-            Err(e) => {
-                eprintln!("Failed to open cache file: {}. Will regenerate.", e);
-            }
-        }
-    } else {
-        eprintln!("Cache file not found. Will generate.");
+        parsed_crates.push(crate_spec);
     }
 
-    // --- Generate or Use Loaded Embeddings ---
-    let mut generated_tokens: Option<usize> = None;
-    let mut generation_cost: Option<f64> = None;
-    let mut documents_for_server: Vec<Document> = loaded_documents_from_cache.unwrap_or_default();
+    eprintln!("Target Crates:");
+    for crate_spec in &parsed_crates {
+        let features_str = crate_spec.features.as_ref()
+            .map(|f| format!(" [features: {}]", f.join(", ")))
+            .unwrap_or_default();
+        eprintln!(" - {}@{}{}", crate_spec.name, crate_spec.version_req, features_str);
+    }
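+    // Sample startup output (illustrative):
+    //   Target Crates:
+    //    - serde@^1.0 [features: derive]
+    //    - tokio@*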
&& c != '-', "_"); + + // Construct the relative path component including features hash + let embeddings_relative_path = PathBuf::from(crate_name) + .join(&sanitized_version_req) + .join(&features_hash) // Add features hash as a directory level + .join("embeddings.bin"); + + #[cfg(not(target_os = "windows"))] + let embeddings_file_path = { + let xdg_dirs = BaseDirectories::with_prefix("rustdocs-mcp-server") + .map_err(|e| ServerError::Xdg(format!("Failed to get XDG directories: {}", e)))?; + xdg_dirs + .place_data_file(embeddings_relative_path) + .map_err(ServerError::Io)? + }; + + #[cfg(target_os = "windows")] + let embeddings_file_path = { + let cache_dir = dirs::cache_dir().ok_or_else(|| { + ServerError::Config("Could not determine cache directory on Windows".to_string()) + })?; + let app_cache_dir = cache_dir.join("rustdocs-mcp-server"); + // Ensure the base app cache directory exists + fs::create_dir_all(&app_cache_dir).map_err(ServerError::Io)?; + app_cache_dir.join(embeddings_relative_path) + }; + + eprintln!("Cache file path for {}: {:?}", crate_name, embeddings_file_path); + + // --- Try Loading Embeddings and Documents from Cache --- + let mut loaded_from_cache = false; + let mut loaded_embeddings: Option)>> = None; + let mut loaded_documents_from_cache: Option> = None; + + if embeddings_file_path.exists() { eprintln!( - "Embedding generation cost for {} tokens: ${:.6}", - total_tokens, estimated_cost - ); - generated_tokens = Some(total_tokens); - generation_cost = Some(estimated_cost); - - eprintln!( - "Saving generated documents and embeddings to: {:?}", + "Attempting to load cached data from: {:?}", embeddings_file_path ); - - let mut combined_cache_data: Vec = Vec::new(); - let embedding_map: std::collections::HashMap> = - generated_embeddings.clone().into_iter().collect(); - - for doc in &loaded_documents { - if let Some(embedding_array) = embedding_map.get(&doc.path) { - combined_cache_data.push(CachedDocumentEmbedding { - path: doc.path.clone(), - content: doc.content.clone(), - vector: embedding_array.to_vec(), - }); - } else { - eprintln!( - "Warning: Embedding not found for document path: {}. Skipping from cache.", - doc.path - ); - } - } - - match bincode::encode_to_vec(&combined_cache_data, config::standard()) { - Ok(encoded_bytes) => { - if let Some(parent_dir) = embeddings_file_path.parent() { - if !parent_dir.exists() { - if let Err(e) = fs::create_dir_all(parent_dir) { - eprintln!( - "Warning: Failed to create cache directory {}: {}", - parent_dir.display(), - e - ); + match File::open(&embeddings_file_path) { + Ok(file) => { + let reader = BufReader::new(file); + match bincode::decode_from_reader::, _, _>( + reader, + config::standard(), + ) { + Ok(cached_data) => { + eprintln!( + "Successfully loaded {} items from cache. Separating data...", + cached_data.len() + ); + let mut embeddings = Vec::with_capacity(cached_data.len()); + let mut documents = Vec::with_capacity(cached_data.len()); + for item in cached_data { + embeddings.push((item.path.clone(), Array1::from(item.vector))); + documents.push(Document { + path: item.path, + content: item.content, + }); } + loaded_embeddings = Some(embeddings); + loaded_documents_from_cache = Some(documents); + loaded_from_cache = true; + } + Err(e) => { + eprintln!("Failed to decode cache file: {}. 
Will regenerate.", e); } - } - if let Err(e) = fs::write(&embeddings_file_path, encoded_bytes) { - eprintln!("Warning: Failed to write cache file: {}", e); - } else { - eprintln!( - "Cache saved successfully ({} items).", - combined_cache_data.len() - ); } } Err(e) => { - eprintln!("Warning: Failed to encode data for cache: {}", e); + eprintln!("Failed to open cache file: {}. Will regenerate.", e); } } - generated_embeddings + } else { + eprintln!("Cache file not found. Will generate."); } - }; - // --- Initialize and Start Server --- - eprintln!( - "Initializing server for crate: {} (Version Req: {}, Features: {:?})", - crate_name, crate_version_req, features - ); + // --- Generate or Use Loaded Embeddings --- + let mut generated_tokens: Option = None; + let mut generation_cost: Option = None; + let mut documents_for_server: Vec = loaded_documents_from_cache.unwrap_or_default(); - let features_str = features - .as_ref() - .map(|f| format!(" Features: {:?}", f)) - .unwrap_or_default(); + let final_embeddings = match loaded_embeddings { + Some(embeddings) => { + eprintln!("Using embeddings and documents loaded from cache for {}.", crate_name); + all_loaded_from_cache.push(crate_name.clone()); + embeddings + } + None => { + eprintln!("Proceeding with documentation loading and embedding generation for {}.", crate_name); + + eprintln!( + "Loading documents for crate: {} (Version Req: {}, Features: {:?})", + crate_name, crate_version_req, features + ); + // Pass features to load_documents + let loaded_documents = + doc_loader::load_documents(crate_name, crate_version_req, features.as_ref())?; + eprintln!("Loaded {} documents for {}.", loaded_documents.len(), crate_name); + documents_for_server = loaded_documents.clone(); + + eprintln!("Generating embeddings for {}...", crate_name); + let embedding_model: String = env::var("EMBEDDING_MODEL") + .unwrap_or_else(|_| "text-embedding-3-small".to_string()); + let (generated_embeddings, total_tokens) = + generate_embeddings(&openai_client, &loaded_documents, &embedding_model).await?; + + let cost_per_million = 0.02; + let estimated_cost = (total_tokens as f64 / 1_000_000.0) * cost_per_million; + eprintln!( + "Embedding generation cost for {} ({} tokens): ${:.6}", + crate_name, total_tokens, estimated_cost + ); + generated_tokens = Some(total_tokens); + generation_cost = Some(estimated_cost); + total_generated_tokens += total_tokens; + total_generation_cost += estimated_cost; + + eprintln!( + "Saving generated documents and embeddings to: {:?}", + embeddings_file_path + ); + + let mut combined_cache_data: Vec = Vec::new(); + let embedding_map: std::collections::HashMap> = + generated_embeddings.clone().into_iter().collect(); + + for doc in &loaded_documents { + if let Some(embedding_array) = embedding_map.get(&doc.path) { + combined_cache_data.push(CachedDocumentEmbedding { + path: doc.path.clone(), + content: doc.content.clone(), + vector: embedding_array.to_vec(), + }); + } else { + eprintln!( + "Warning: Embedding not found for document path: {}. 
+                eprintln!(
+                    "Embedding generation cost for {} ({} tokens): ${:.6}",
+                    crate_name, total_tokens, estimated_cost
+                );
+                generated_tokens = Some(total_tokens);
+                generation_cost = Some(estimated_cost);
+                total_generated_tokens += total_tokens;
+                total_generation_cost += estimated_cost;
+
+                eprintln!(
+                    "Saving generated documents and embeddings to: {:?}",
+                    embeddings_file_path
+                );
+
+                let mut combined_cache_data: Vec<CachedDocumentEmbedding> = Vec::new();
+                let embedding_map: std::collections::HashMap<String, Array1<f32>> =
+                    generated_embeddings.clone().into_iter().collect();
+
+                for doc in &loaded_documents {
+                    if let Some(embedding_array) = embedding_map.get(&doc.path) {
+                        combined_cache_data.push(CachedDocumentEmbedding {
+                            path: doc.path.clone(),
+                            content: doc.content.clone(),
+                            vector: embedding_array.to_vec(),
+                        });
+                    } else {
+                        eprintln!(
+                            "Warning: Embedding not found for document path: {}. Skipping from cache.",
+                            doc.path
+                        );
+                    }
+                }
+
+                match bincode::encode_to_vec(&combined_cache_data, config::standard()) {
+                    Ok(encoded_bytes) => {
+                        if let Some(parent_dir) = embeddings_file_path.parent() {
+                            if !parent_dir.exists() {
+                                if let Err(e) = fs::create_dir_all(parent_dir) {
+                                    eprintln!(
+                                        "Warning: Failed to create cache directory {}: {}",
+                                        parent_dir.display(),
+                                        e
+                                    );
+                                }
+                            }
+                        }
+                        if let Err(e) = fs::write(&embeddings_file_path, encoded_bytes) {
+                            eprintln!("Warning: Failed to write cache file: {}", e);
+                        } else {
+                            eprintln!(
+                                "Cache saved successfully ({} items).",
+                                combined_cache_data.len()
+                            );
+                        }
+                    }
+                    Err(e) => {
+                        eprintln!("Warning: Failed to encode data for cache: {}", e);
+                    }
+                }
+                generated_embeddings
+            }
+        };
+
+        // Create metadata string for this crate
+        let metadata = if loaded_from_cache {
+            format!("Version: {}, Features: {:?}, Loaded from cache", crate_version_req, features)
+        } else {
+            let tokens = generated_tokens.unwrap_or(0);
+            let cost = generation_cost.unwrap_or(0.0);
+            format!("Version: {}, Features: {:?}, Generated {} embeddings for {} tokens (Cost: ${:.6})",
+                crate_version_req, features, final_embeddings.len(), tokens, cost)
+        };
+
+        // Store crate data
+        crates_data.insert(crate_name.clone(), CrateData {
+            documents: documents_for_server,
+            embeddings: final_embeddings,
+            metadata,
+        });
+
+        eprintln!("Completed processing crate: {}", crate_name);
+    }
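+
+    // crates_data now maps each crate name to its documents and embeddings;
+    // RustDocsServer::new takes ownership of the whole map below.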
 
-    let startup_message = if loaded_from_cache {
+    // --- Create startup message for all crates ---
+    let startup_message = if all_loaded_from_cache.len() == parsed_crates.len() {
         format!(
-            "Server for crate '{}' (Version Req: '{}'{}) initialized. Loaded {} embeddings from cache.",
-            crate_name, crate_version_req, features_str, final_embeddings.len()
+            "Server initialized with {} crates: {}. All loaded from cache.",
+            parsed_crates.len(),
+            parsed_crates.iter().map(|spec| spec.name.as_str()).collect::<Vec<_>>().join(", ")
         )
     } else {
-        let tokens = generated_tokens.unwrap_or(0);
-        let cost = generation_cost.unwrap_or(0.0);
         format!(
-            "Server for crate '{}' (Version Req: '{}'{}) initialized. Generated {} embeddings for {} tokens (Est. Cost: ${:.6}).",
-            crate_name,
-            crate_version_req,
-            features_str,
-            final_embeddings.len(),
-            tokens,
-            cost
+            "Server initialized with {} crates: {}. Generated {} total tokens (Est. Cost: ${:.6}).",
+            parsed_crates.len(),
+            parsed_crates.iter().map(|spec| spec.name.as_str()).collect::<Vec<_>>().join(", "),
+            total_generated_tokens,
+            total_generation_cost
         )
     };
 
     // Create the service instance using the updated ::new()
     let service = RustDocsServer::new(
-        crate_name.clone(), // Pass crate_name directly
-        documents_for_server,
-        final_embeddings,
+        crates_data,
         startup_message,
     )?;
 
@@ -329,7 +383,7 @@ async fn main() -> Result<(), ServerError> {
         ServerError::McpRuntime(e.to_string()) // Use the new McpRuntime variant
     })?;
 
-    eprintln!("{} Docs MCP server running...", &crate_name);
+    eprintln!("Multi-crate Docs MCP server running...");
 
     // Wait for the server to complete (e.g., stdin closed)
     server_handle.waiting().await.map_err(|e| {
diff --git a/src/server.rs b/src/server.rs
index 9e886ca..13a8222 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -48,9 +48,18 @@ use rmcp::{
 use schemars::JsonSchema; // Import JsonSchema
 use serde::Deserialize; // Import Deserialize
 use serde_json::json;
-use std::{/* borrow::Cow, */ env, sync::Arc}; // Removed borrow::Cow
+use std::{/* borrow::Cow, */ env, sync::Arc, collections::HashMap}; // Removed borrow::Cow
 use tokio::sync::Mutex;
 
+// --- Structs for Multi-Crate Support ---
+
+#[derive(Debug, Clone)]
+pub struct CrateData {
+    pub documents: Vec<Document>,
+    pub embeddings: Vec<(String, Array1<f32>)>,
+    pub metadata: String, // Version, features info for logging
+}
+
 // --- Argument Struct for the Tool ---
 
 #[derive(Debug, Deserialize, JsonSchema)]
@@ -65,9 +74,7 @@ struct QueryRustDocsArgs {
 // No longer needs ServerState, holds data directly
 #[derive(Clone)] // Add Clone for tool macro requirements
 pub struct RustDocsServer {
-    crate_name: Arc<String>, // Use Arc for cheap cloning
-    documents: Arc<Vec<Document>>,
-    embeddings: Arc<Vec<(String, Array1<f32>)>>,
+    crates: Arc<HashMap<String, CrateData>>, // Map of crate name to crate data
     peer: Arc<Mutex<Option<Peer<RoleServer>>>>, // Uses tokio::sync::Mutex
     startup_message: Arc<Mutex<Option<String>>>, // Keep the message itself
     startup_message_sent: Arc<Mutex<bool>>, // Flag to track if sent (using tokio::sync::Mutex)
@@ -77,16 +84,12 @@ impl RustDocsServer {
     // Updated constructor
     pub fn new(
-        crate_name: String,
-        documents: Vec<Document>,
-        embeddings: Vec<(String, Array1<f32>)>,
+        crates: HashMap<String, CrateData>,
         startup_message: String,
     ) -> Result<Self, ServerError> {
         // Keep ServerError for potential future init errors
         Ok(Self {
-            crate_name: Arc::new(crate_name),
-            documents: Arc::new(documents),
-            embeddings: Arc::new(embeddings),
+            crates: Arc::new(crates),
             peer: Arc::new(Mutex::new(None)), // Uses tokio::sync::Mutex
             startup_message: Arc::new(Mutex::new(Some(startup_message))), // Initialize message
             startup_message_sent: Arc::new(Mutex::new(false)), // Initialize flag to false
@@ -123,6 +126,23 @@ impl RustDocsServer {
     fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
         RawResource::new(uri, name.to_string()).no_annotation()
     }
+
+    // Helper function to extract crate name from tool name
+    fn extract_crate_name_from_tool(&self, tool_name: &str) -> Option<&str> {
+        if tool_name.starts_with("query_") && tool_name.ends_with("_docs") {
+            let crate_part = &tool_name[6..tool_name.len() - 5]; // Remove "query_" and "_docs"
+            let hyphenated = crate_part.replace('_', "-"); // Convert underscores back to hyphens
+            // Check both forms against the crates map: tool names always use
+            // underscores, but the real crate name may contain underscores
+            // ("serde_json") or hyphens ("async-stripe"). Returning the key stored
+            // in the map yields a reference that lives as long as &self.
+            self.crates
+                .keys()
+                .find(|k| k.as_str() == crate_part || k.as_str() == hyphenated)
+                .map(|s| s.as_str())
+        } else {
+            None
+        }
+    }
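+
+    // Round-trip sketch (illustrative): list_tools advertises "query_serde_json_docs"
+    // for crate "serde_json"; call_tool maps that tool name back through this helper
+    // before running the query.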
 }
 
 // --- Tool Implementation ---
 
@@ -132,13 +152,11 @@ impl RustDocsServer {
     // Define the tool using the tool macro
     // Name removed; will be handled dynamically by overriding list_tools/get_tool
-    #[tool(
-        description = "Query documentation for a specific Rust crate using semantic search and LLM summarization."
-    )]
-    async fn query_rust_docs(
+    // Generic query method that can work with any crate
+    async fn query_crate_docs(
         &self,
-        #[tool(aggr)] // Aggregate arguments into the struct
-        args: QueryRustDocsArgs,
+        crate_name: &str,
+        question: &str,
     ) -> Result<CallToolResult, McpError> {
         // --- Send Startup Message (if not already sent) ---
         let mut sent_guard = self.startup_message_sent.lock().await;
@@ -157,16 +175,17 @@ impl RustDocsServer {
             drop(sent_guard);
         }
 
-        // Argument validation for crate_name removed
-
-        let question = &args.question;
+        // Get the crate data
+        let crate_data = self.crates.get(crate_name).ok_or_else(|| {
+            McpError::invalid_params(format!("Crate '{}' not found", crate_name), None)
+        })?;
 
         // Log received query via MCP
         self.send_log(
             LoggingLevel::Info,
             format!(
                 "Received query for crate '{}': {}",
-                self.crate_name, question
+                crate_name, question
             ),
         );
 
@@ -199,7 +218,7 @@ impl RustDocsServer {
         // --- Find Best Matching Document ---
         let mut best_match: Option<(&str, f32)> = None;
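+        // A linear scan over this crate's embeddings; per-crate document counts are
+        // assumed small enough that a dedicated vector index would be overkill.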
-        for (path, doc_embedding) in self.embeddings.iter() {
+        for (path, doc_embedding) in crate_data.embeddings.iter() {
             let score = cosine_similarity(question_vector.view(), doc_embedding.view());
             if best_match.is_none() || score > best_match.unwrap().1 {
                 best_match = Some((path, score));
@@ -210,7 +229,7 @@ impl RustDocsServer {
         let response_text = match best_match {
             Some((best_path, _score)) => {
                 eprintln!("Best match found: {}", best_path);
-                let context_doc = self.documents.iter().find(|doc| doc.path == best_path);
+                let context_doc = crate_data.documents.iter().find(|doc| doc.path == best_path);
 
                 if let Some(doc) = context_doc {
                     let system_prompt = format!(
@@ -218,7 +237,7 @@ impl RustDocsServer {
                         Answer the user's question based *only* on the provided context. \
                         If the context does not contain the answer, say so. \
                         Do not make up information. Be clear, concise, and comprehensive providing example usage code when possible.",
-                        self.crate_name
+                        crate_name
                     );
                     let user_prompt = format!(
                         "Context:\n---\n{}\n---\n\nQuestion: {}",
@@ -278,14 +297,25 @@ impl RustDocsServer {
         // --- Format and Return Result ---
         Ok(CallToolResult::success(vec![Content::text(format!(
             "From {} docs: {}",
-            self.crate_name, response_text
+            crate_name, response_text
         ))]))
     }
+
+    #[tool(
+        description = "Query documentation for a specific Rust crate using semantic search and LLM summarization."
+    )]
+    async fn query_rust_docs(
+        &self,
+        #[tool(aggr)] // Aggregate arguments into the struct
+        _args: QueryRustDocsArgs,
+    ) -> Result<CallToolResult, McpError> {
+        // This method is now just a placeholder - actual routing happens in call_tool
+        Err(McpError::invalid_params("Tool routing should happen in call_tool".to_string(), None))
+    }
 }
 
 // --- ServerHandler Implementation ---
 
-#[tool(tool_box)] // Use imported tool macro directly
 impl ServerHandler for RustDocsServer {
     fn get_info(&self) -> ServerInfo {
         // Define capabilities using the builder
@@ -295,6 +325,8 @@ impl ServerHandler for RustDocsServer {
             // Add other capabilities like resources, prompts if needed later
             .build();
 
+        let crate_list: Vec<String> = self.crates.keys().cloned().collect();
+
         ServerInfo {
             protocol_version: ProtocolVersion::V_2024_11_05, // Use latest known version
             capabilities,
@@ -302,12 +334,12 @@ impl ServerHandler for RustDocsServer {
                 name: "rust-docs-mcp-server".to_string(),
                 version: env!("CARGO_PKG_VERSION").to_string(),
             },
-            // Provide instructions based on the specific crate
+            // Provide instructions based on all loaded crates
             instructions: Some(format!(
-                "This server provides tools to query documentation for the '{}' crate. \
-                 Use the 'query_rust_docs' tool with a specific question to get information \
-                 about its API, usage, and examples, derived from its official documentation.",
-                self.crate_name
+                "This server provides tools to query documentation for the following Rust crates: {}. \
+                 Use the appropriate tool (e.g., 'query_serde_docs', 'query_tokio_docs') with a specific question to get information \
+                 about each crate's API, usage, and examples, derived from their official documentation.",
+                crate_list.join(", ")
             )),
         }
    }
@@ -320,11 +352,13 @@ impl ServerHandler for RustDocsServer {
         _request: PaginatedRequestParam,
         _context: RequestContext<RoleServer>,
     ) -> Result<ListResourcesResult, McpError> {
-        // Example: Return the crate name as a resource
+        // Return resources for all crates
+        let resources: Vec<Resource> = self.crates.keys()
+            .map(|crate_name| self._create_resource_text(&format!("crate://{}", crate_name), crate_name))
+            .collect();
+
         Ok(ListResourcesResult {
-            resources: vec![
-                self._create_resource_text(&format!("crate://{}", self.crate_name), "crate_name"),
-            ],
+            resources,
             next_cursor: None,
         })
     }
@@ -334,14 +368,22 @@ impl ServerHandler for RustDocsServer {
         request: ReadResourceRequestParam,
         _context: RequestContext<RoleServer>,
     ) -> Result<ReadResourceResult, McpError> {
-        let expected_uri = format!("crate://{}", self.crate_name);
-        if request.uri == expected_uri {
-            Ok(ReadResourceResult {
-                contents: vec![ResourceContents::text(
-                    self.crate_name.as_str(), // Explicitly get &str from Arc
-                    &request.uri,
-                )],
-            })
+        // Check if the URI matches any of our crates
+        if request.uri.starts_with("crate://") {
+            let crate_name = &request.uri[8..]; // Remove "crate://" prefix
+            if self.crates.contains_key(crate_name) {
+                Ok(ReadResourceResult {
+                    contents: vec![ResourceContents::text(
+                        crate_name,
+                        &request.uri,
+                    )],
+                })
+            } else {
+                Err(McpError::resource_not_found(
+                    format!("Crate '{}' not found", crate_name),
+                    Some(json!({ "uri": request.uri })),
+                ))
+            }
         } else {
             Err(McpError::resource_not_found(
                 format!("Resource URI not found: {}", request.uri),
@@ -383,4 +425,61 @@ impl ServerHandler for RustDocsServer {
             resource_templates: Vec::new(), // No templates defined yet
         })
     }
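+
+    // call_tool and list_tools are overridden so that one tool per crate can be
+    // routed and advertised at runtime, replacing the static #[tool(tool_box)] router.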
+    async fn call_tool(
+        &self,
+        request: rmcp::model::CallToolRequestParam,
+        _context: RequestContext<RoleServer>,
+    ) -> Result<CallToolResult, McpError> {
+        // Extract crate name from tool name
+        if let Some(crate_name) = self.extract_crate_name_from_tool(&request.name) {
+            // Parse the arguments for our tool
+            let args: QueryRustDocsArgs = serde_json::from_value(request.arguments.into())
+                .map_err(|e| McpError::invalid_params(format!("Invalid arguments: {}", e), None))?;
+
+            // Call the query method with the specific crate
+            self.query_crate_docs(crate_name, &args.question).await
+        } else {
+            Err(McpError::invalid_params(
+                format!("Tool '{}' not found", request.name),
+                None,
+            ))
+        }
+    }
+
+    async fn list_tools(
+        &self,
+        _request: rmcp::model::PaginatedRequestParam,
+        _context: RequestContext<RoleServer>,
+    ) -> Result<rmcp::model::ListToolsResult, McpError> {
+        let mut generator = schemars::r#gen::SchemaGenerator::default();
+        let schema = QueryRustDocsArgs::json_schema(&mut generator);
+
+        // Create a tool for each crate
+        let mut tools = Vec::new();
+        for crate_name in self.crates.keys() {
+            let dynamic_tool_name = format!("query_{}_docs", crate_name.replace('-', "_"));
+
+            let tool = rmcp::model::Tool {
+                name: dynamic_tool_name.into(),
+                description: format!(
+                    "Query documentation for the '{}' crate using semantic search and LLM summarization.",
+                    crate_name
+                ).into(),
+                input_schema: serde_json::to_value(&schema)
+                    .map_err(|e| McpError::internal_error(format!("Failed to generate schema: {}", e), None))?
+                    .as_object()
+                    .cloned()
+                    .map(Arc::new)
+                    .unwrap_or_else(|| Arc::new(serde_json::Map::new())),
+            };
+
+            tools.push(tool);
+        }
+
+        Ok(rmcp::model::ListToolsResult {
+            tools,
+            next_cursor: None,
+        })
+    }
 }