diff --git a/src/apps/desktop/src/api/app_state.rs b/src/apps/desktop/src/api/app_state.rs index c00d3c45..d685697a 100644 --- a/src/apps/desktop/src/api/app_state.rs +++ b/src/apps/desktop/src/api/app_state.rs @@ -123,7 +123,9 @@ impl AppState { let mcp_service = match mcp::MCPService::new(config_service.clone()) { Ok(service) => { log::info!("MCP service initialized successfully"); - Some(Arc::new(service)) + let service = Arc::new(service); + mcp::set_global_mcp_service(service.clone()); + Some(service) } Err(e) => { log::warn!("Failed to initialize MCP service: {}", e); @@ -437,4 +439,4 @@ impl AppState { pub async fn is_remote_workspace(&self) -> bool { self.remote_workspace.read().await.is_some() } -} \ No newline at end of file +} diff --git a/src/apps/desktop/src/api/mcp_api.rs b/src/apps/desktop/src/api/mcp_api.rs index f4728e2f..d825a02d 100644 --- a/src/apps/desktop/src/api/mcp_api.rs +++ b/src/apps/desktop/src/api/mcp_api.rs @@ -1,9 +1,17 @@ //! MCP API use crate::api::app_state::AppState; +use bitfun_core::service::mcp::auth::{ + has_stored_oauth_credentials, MCPRemoteOAuthSessionSnapshot, +}; +use bitfun_core::service::mcp::config::MCPConfigService; +use bitfun_core::service::mcp::protocol::{ + MCPPrompt, MCPResource, PromptsGetResult, ResourcesReadResult, +}; use bitfun_core::service::mcp::MCPServerType; use bitfun_core::service::runtime::{RuntimeManager, RuntimeSource}; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use tauri::State; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -12,10 +20,23 @@ pub struct MCPServerInfo { pub id: String, pub name: String, pub status: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub status_message: Option, pub server_type: String, + pub transport: String, pub enabled: bool, pub auto_start: bool, #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub auth_configured: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub auth_source: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth_enabled: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub xaa_enabled: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub command: Option, #[serde(skip_serializing_if = "Option::is_none")] pub command_available: Option, @@ -23,6 +44,79 @@ pub struct MCPServerInfo { pub command_source: Option, #[serde(skip_serializing_if = "Option::is_none")] pub command_resolved_path: Option, + pub start_supported: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub start_disabled_reason: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListMCPResourcesRequest { + pub server_id: String, + #[serde(default)] + pub refresh: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ReadMCPResourceRequest { + pub server_id: String, + pub resource_uri: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListMCPPromptsRequest { + pub server_id: String, + #[serde(default)] + pub refresh: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetMCPPromptRequest { + pub server_id: String, + pub prompt_name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, +} + +async fn load_mcp_resources( + mcp_service: &bitfun_core::service::mcp::MCPService, + server_id: &str, + refresh: bool, +) -> Result, String> { + let manager = mcp_service.server_manager(); + let mut resources = manager.get_cached_resources(server_id).await; + + if refresh || resources.is_empty() { + manager + .refresh_server_resource_catalog(server_id) + .await + .map_err(|e| e.to_string())?; + resources = manager.get_cached_resources(server_id).await; + } + + Ok(resources) +} + +async fn load_mcp_prompts( + mcp_service: &bitfun_core::service::mcp::MCPService, + server_id: &str, + refresh: bool, +) -> Result, String> { + let manager = mcp_service.server_manager(); + let mut prompts = manager.get_cached_prompts(server_id).await; + + if refresh || prompts.is_empty() { + manager + .refresh_server_prompt_catalog(server_id) + .await + .map_err(|e| e.to_string())?; + prompts = manager.get_cached_prompts(server_id).await; + } + + Ok(prompts) } #[tauri::command] @@ -76,46 +170,79 @@ pub async fn get_mcp_servers(state: State<'_, AppState>) -> Result "system".to_string(), - RuntimeSource::Managed => "managed".to_string(), - }) - }); - let resolved_path = runtime_manager - .as_ref() - .and_then(|manager| manager.resolve_command(&command)) - .and_then(|resolved| resolved.resolved_path); - (Some(command), available, source, resolved_path) + let transport = config.resolved_transport(); + let static_auth_configured = if matches!(config.server_type, MCPServerType::Remote) { + MCPConfigService::has_remote_authorization(&config) + } else { + false + }; + let oauth_enabled = if matches!(config.server_type, MCPServerType::Remote) { + true + } else { + false + }; + let oauth_auth_configured = if oauth_enabled { + has_stored_oauth_credentials(&config.id) + .await + .unwrap_or(false) + } else { + false + }; + + let (command, command_available, command_source, command_resolved_path) = + if transport == bitfun_core::service::mcp::MCPServerTransport::Stdio { + if let Some(command) = config.command.clone() { + let capability = runtime_manager + .as_ref() + .map(|manager| manager.get_command_capability(&command)); + let available = capability.as_ref().map(|c| c.available); + let source = capability.and_then(|c| { + c.source.map(|source| match source { + RuntimeSource::System => "system".to_string(), + RuntimeSource::Managed => "managed".to_string(), + }) + }); + let resolved_path = runtime_manager + .as_ref() + .and_then(|manager| manager.resolve_command(&command)) + .and_then(|resolved| resolved.resolved_path); + (Some(command), available, source, resolved_path) + } else { + (None, None, None, None) + } } else { (None, None, None, None) - } - } else { - (None, None, None, None) + }; + + let (start_supported, start_disabled_reason) = match config.server_type { + MCPServerType::Remote if transport.as_str() == "sse" => ( + false, + Some("Remote MCP SSE transport is not yet supported".to_string()), + ), + _ => (true, None), }; - let status = match mcp_service + let (status, status_message) = match mcp_service .server_manager() .get_server_status(&config.id) .await { - Ok(s) => format!("{:?}", s), + Ok(s) => { + let status_message = mcp_service + .server_manager() + .get_server_status_message(&config.id) + .await + .ok() + .flatten(); + (format!("{:?}", s), status_message) + } Err(_) => { if !config.enabled { - "Stopped".to_string() + ("Stopped".to_string(), None) } else if config.auto_start { - "Starting".to_string() + ("Starting".to_string(), None) } else { - "Uninitialized".to_string() + ("Uninitialized".to_string(), None) } } }; @@ -124,19 +251,121 @@ pub async fn get_mcp_servers(state: State<'_, AppState>) -> Result, + request: ListMCPResourcesRequest, +) -> Result, String> { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + load_mcp_resources(mcp_service.as_ref(), &request.server_id, request.refresh).await +} + +#[tauri::command] +pub async fn read_mcp_resource( + state: State<'_, AppState>, + request: ReadMCPResourceRequest, +) -> Result { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + let connection = mcp_service + .server_manager() + .get_connection(&request.server_id) + .await + .ok_or_else(|| format!("MCP server not connected: {}", request.server_id))?; + + connection + .read_resource(&request.resource_uri) + .await + .map_err(|e| e.to_string()) +} + +#[tauri::command] +pub async fn list_mcp_prompts( + state: State<'_, AppState>, + request: ListMCPPromptsRequest, +) -> Result, String> { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + load_mcp_prompts(mcp_service.as_ref(), &request.server_id, request.refresh).await +} + +#[tauri::command] +pub async fn get_mcp_prompt( + state: State<'_, AppState>, + request: GetMCPPromptRequest, +) -> Result { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + let connection = mcp_service + .server_manager() + .get_connection(&request.server_id) + .await + .ok_or_else(|| format!("MCP server not connected: {}", request.server_id))?; + + connection + .get_prompt(&request.prompt_name, request.arguments) + .await + .map_err(|e| e.to_string()) +} + #[tauri::command] pub async fn start_mcp_server(state: State<'_, AppState>, server_id: String) -> Result<(), String> { let mcp_service = state @@ -278,7 +507,7 @@ pub struct McpUiResourcePermissions { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct FetchMCPAppResourceRequest { - /// MCP server ID (e.g. from tool name mcp_{server_id}_{tool_name}) + /// MCP server ID (e.g. from tool name mcp__{server_id}__{tool_name}) pub server_id: String, /// Full resource URI, e.g. "ui://my-server/widget" pub resource_uri: String, @@ -316,7 +545,7 @@ pub async fn get_mcp_tool_ui_uri( _state: State<'_, AppState>, tool_name: String, ) -> Result, String> { - if !tool_name.starts_with("mcp_") { + if !tool_name.starts_with("mcp__") { return Ok(None); } let registry = bitfun_core::agentic::tools::registry::get_global_tool_registry(); @@ -412,6 +641,65 @@ pub struct SendMCPAppMessageResponse { pub response: serde_json::Value, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubmitMCPInteractionError { + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubmitMCPInteractionResponseRequest { + pub interaction_id: String, + pub approve: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UpdateMCPRemoteAuthRequest { + pub server_id: String, + pub authorization_value: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ClearMCPRemoteAuthRequest { + pub server_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct DeleteMCPServerRequest { + pub server_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct StartMCPRemoteOAuthRequest { + pub server_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetMCPRemoteOAuthSessionRequest { + pub server_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CancelMCPRemoteOAuthRequest { + pub server_id: String, +} + #[tauri::command] pub async fn send_mcp_app_message( state: State<'_, AppState>, @@ -486,3 +774,140 @@ pub async fn send_mcp_app_message( }); Ok(SendMCPAppMessageResponse { response }) } + +#[tauri::command] +pub async fn submit_mcp_interaction_response( + state: State<'_, AppState>, + request: SubmitMCPInteractionResponseRequest, +) -> Result<(), String> { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + let error_message = request.error.as_ref().and_then(|e| e.message.clone()); + let error_code = request.error.as_ref().and_then(|e| e.code); + let error_data = request.error.as_ref().and_then(|e| e.data.clone()); + + mcp_service + .server_manager() + .submit_interaction_response( + &request.interaction_id, + request.approve, + request.result, + error_message, + error_code, + error_data, + ) + .await + .map_err(|e| e.to_string())?; + + Ok(()) +} + +#[tauri::command] +pub async fn update_mcp_remote_auth( + state: State<'_, AppState>, + request: UpdateMCPRemoteAuthRequest, +) -> Result<(), String> { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + mcp_service + .server_manager() + .reauthenticate_remote_server(&request.server_id, &request.authorization_value) + .await + .map_err(|e| e.to_string())?; + + Ok(()) +} + +#[tauri::command] +pub async fn clear_mcp_remote_auth( + state: State<'_, AppState>, + request: ClearMCPRemoteAuthRequest, +) -> Result<(), String> { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + mcp_service + .server_manager() + .clear_remote_server_auth(&request.server_id) + .await + .map_err(|e| e.to_string())?; + + Ok(()) +} + +#[tauri::command] +pub async fn delete_mcp_server( + state: State<'_, AppState>, + request: DeleteMCPServerRequest, +) -> Result<(), String> { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + mcp_service + .server_manager() + .remove_server(&request.server_id) + .await + .map_err(|e| e.to_string())?; + + Ok(()) +} + +#[tauri::command] +pub async fn start_mcp_remote_oauth( + state: State<'_, AppState>, + request: StartMCPRemoteOAuthRequest, +) -> Result { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + mcp_service + .server_manager() + .start_remote_oauth_authorization(&request.server_id) + .await + .map_err(|e| e.to_string()) +} + +#[tauri::command] +pub async fn get_mcp_remote_oauth_session( + state: State<'_, AppState>, + request: GetMCPRemoteOAuthSessionRequest, +) -> Result, String> { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + Ok(mcp_service + .server_manager() + .get_remote_oauth_session(&request.server_id) + .await) +} + +#[tauri::command] +pub async fn cancel_mcp_remote_oauth( + state: State<'_, AppState>, + request: CancelMCPRemoteOAuthRequest, +) -> Result<(), String> { + let mcp_service = state + .mcp_service + .as_ref() + .ok_or_else(|| "MCP service not initialized".to_string())?; + + mcp_service + .server_manager() + .cancel_remote_oauth_authorization(&request.server_id) + .await + .map_err(|e| e.to_string()) +} diff --git a/src/apps/desktop/src/lib.rs b/src/apps/desktop/src/lib.rs index 6a669a73..c536ae79 100644 --- a/src/apps/desktop/src/lib.rs +++ b/src/apps/desktop/src/lib.rs @@ -527,6 +527,10 @@ pub async fn run() { initialize_mcp_servers, api::mcp_api::initialize_mcp_servers_non_destructive, get_mcp_servers, + api::mcp_api::list_mcp_resources, + api::mcp_api::read_mcp_resource, + api::mcp_api::list_mcp_prompts, + api::mcp_api::get_mcp_prompt, start_mcp_server, stop_mcp_server, restart_mcp_server, @@ -536,6 +540,13 @@ pub async fn run() { get_mcp_tool_ui_uri, fetch_mcp_app_resource, send_mcp_app_message, + submit_mcp_interaction_response, + update_mcp_remote_auth, + clear_mcp_remote_auth, + api::mcp_api::delete_mcp_server, + api::mcp_api::start_mcp_remote_oauth, + api::mcp_api::get_mcp_remote_oauth_session, + api::mcp_api::cancel_mcp_remote_oauth, lsp_initialize, lsp_start_server_for_file, lsp_stop_server, diff --git a/src/crates/core/Cargo.toml b/src/crates/core/Cargo.toml index c0834aa4..116f8f97 100644 --- a/src/crates/core/Cargo.toml +++ b/src/crates/core/Cargo.toml @@ -70,6 +70,7 @@ eventsource-stream = { workspace = true } # MCP Streamable HTTP client (official rust-sdk used by Codex) rmcp = { version = "0.12.0", default-features = false, features = [ + "auth", "base64", "client", "macros", diff --git a/src/crates/core/src/agentic/agents/registry.rs b/src/crates/core/src/agentic/agents/registry.rs index ff81df0b..d8466a3c 100644 --- a/src/crates/core/src/agentic/agents/registry.rs +++ b/src/crates/core/src/agentic/agents/registry.rs @@ -155,6 +155,28 @@ async fn get_subagent_configs() -> HashMap { } } +fn merge_dynamic_mcp_tools( + mut configured_tools: Vec, + registered_tool_names: &[String], +) -> Vec { + for tool_name in registered_tool_names { + if !tool_name.starts_with("mcp__") { + continue; + } + + if configured_tools + .iter() + .any(|existing| existing == tool_name) + { + continue; + } + + configured_tools.push(tool_name.clone()); + } + + configured_tools +} + /// Agent category #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum AgentCategory { @@ -386,13 +408,16 @@ impl AgentRegistry { match entry.category { AgentCategory::Mode => { let mode_configs = get_mode_configs().await; + let registered_tool_names = get_all_registered_tool_names().await; let valid_tools: HashSet = - get_all_registered_tool_names().await.into_iter().collect(); - resolve_effective_tools( + registered_tool_names.iter().cloned().collect(); + let resolved_tools = resolve_effective_tools( &entry.agent.default_tools(), mode_configs.get(agent_type), &valid_tools, - ) + ); + + merge_dynamic_mcp_tools(resolved_tools, ®istered_tool_names) } AgentCategory::SubAgent | AgentCategory::Hidden => entry.agent.default_tools(), } @@ -1038,7 +1063,7 @@ pub fn get_agent_registry() -> Arc { #[cfg(test)] mod tests { - use super::default_model_id_for_builtin_agent; + use super::{default_model_id_for_builtin_agent, merge_dynamic_mcp_tools}; #[test] fn top_level_modes_default_to_auto() { @@ -1052,4 +1077,27 @@ mod tests { assert_eq!(default_model_id_for_builtin_agent("Explore"), "primary"); assert_eq!(default_model_id_for_builtin_agent("CodeReview"), "primary"); } + + #[test] + fn merge_dynamic_mcp_tools_appends_registered_mcp_tools_once() { + let configured_tools = vec!["Read".to_string(), "Bash".to_string()]; + let registered_tool_names = vec![ + "Read".to_string(), + "mcp__notion__notion-search".to_string(), + "mcp__github__list_issues".to_string(), + "mcp__notion__notion-search".to_string(), + ]; + + let merged = merge_dynamic_mcp_tools(configured_tools, ®istered_tool_names); + + assert_eq!( + merged, + vec![ + "Read".to_string(), + "Bash".to_string(), + "mcp__notion__notion-search".to_string(), + "mcp__github__list_issues".to_string(), + ] + ); + } } diff --git a/src/crates/core/src/agentic/execution/execution_engine.rs b/src/crates/core/src/agentic/execution/execution_engine.rs index 9e4e48b4..3acced74 100644 --- a/src/crates/core/src/agentic/execution/execution_engine.rs +++ b/src/crates/core/src/agentic/execution/execution_engine.rs @@ -16,10 +16,10 @@ use crate::agentic::tools::framework::ToolOptions; use crate::agentic::tools::{get_all_registered_tools, SubagentParentInfo}; use crate::agentic::util::build_remote_workspace_layout_preview; use crate::agentic::{WorkspaceBackend, WorkspaceBinding}; -use crate::service::remote_ssh::workspace_state::get_remote_workspace_manager; use crate::infrastructure::ai::get_global_ai_client_factory; use crate::service::config::get_global_config_service; use crate::service::config::types::{ModelCapability, ModelCategory}; +use crate::service::remote_ssh::workspace_state::get_remote_workspace_manager; use crate::util::errors::{BitFunError, BitFunResult}; use crate::util::token_counter::TokenCounter; use crate::util::types::Message as AIMessage; @@ -664,13 +664,7 @@ impl ExecutionEngine { match self .context_compressor - .compress_turns( - session_id, - context_window, - turns.len(), - turns, - tail_policy, - ) + .compress_turns(session_id, context_window, turns.len(), turns, tail_policy) .await { Ok(compression_result) => { @@ -947,30 +941,21 @@ impl ExecutionEngine { if let Some(mgr) = get_remote_workspace_manager() { let ssh_opt = mgr.get_ssh_manager().await; let fs_opt = mgr.get_file_service().await; - let (kernel_name, hostname) = - if let Some(ref ssh) = ssh_opt { - if let Some(info) = ssh.get_server_info(cid).await { - (info.os_type, info.hostname) - } else { - ( - "Linux".to_string(), - "remote".to_string(), - ) - } + let (kernel_name, hostname) = if let Some(ref ssh) = ssh_opt { + if let Some(info) = ssh.get_server_info(cid).await { + (info.os_type, info.hostname) } else { - ( - "Linux".to_string(), - "remote".to_string(), - ) - }; - let connection_display_name = - match &ws.backend { - WorkspaceBackend::Remote { - connection_name, - .. - } => connection_name.clone(), - _ => cid.to_string(), - }; + ("Linux".to_string(), "remote".to_string()) + } + } else { + ("Linux".to_string(), "remote".to_string()) + }; + let connection_display_name = match &ws.backend { + WorkspaceBackend::Remote { + connection_name, .. + } => connection_name.clone(), + _ => cid.to_string(), + }; let remote_layout = if let Some(ref fs) = fs_opt { match build_remote_workspace_layout_preview( fs, @@ -1473,7 +1458,7 @@ impl ExecutionEngine { .await } - /// Get available tool names and definitions: 1. Tool itself is enabled 2. Allowed in mode or is MCP tool + /// Get available tool names and definitions: 1. Tool itself is enabled 2. Explicitly allowed in mode config async fn get_available_tools_and_definitions( &self, mode_allowed_tools: &[String], @@ -1529,8 +1514,7 @@ impl ExecutionEngine { } let tool_name = tool.name().to_string(); - // MCP tools are automatically allowed (all tools starting with mcp_) - if mode_allowed_tools.contains(&tool_name) || tool_name.starts_with("mcp_") { + if mode_allowed_tools.contains(&tool_name) { let description = tool .description_with_context(Some(&description_context)) .await diff --git a/src/crates/core/src/agentic/tools/implementations/mcp_tools.rs b/src/crates/core/src/agentic/tools/implementations/mcp_tools.rs new file mode 100644 index 00000000..32d7d207 --- /dev/null +++ b/src/crates/core/src/agentic/tools/implementations/mcp_tools.rs @@ -0,0 +1,683 @@ +//! Built-in MCP resource/prompt tools. + +use crate::agentic::tools::framework::{ + Tool, ToolRenderOptions, ToolResult, ToolUseContext, ValidationResult, +}; +use crate::service::mcp::adapter::PromptAdapter; +use crate::service::mcp::get_global_mcp_service; +use crate::service::mcp::protocol::{MCPPrompt, MCPResource, MCPResourceContent}; +use crate::service::mcp::MCPServerManager; +use crate::util::errors::{BitFunError, BitFunResult}; +use async_trait::async_trait; +use serde_json::{json, Value}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +const DEFAULT_RENDER_CHAR_LIMIT: usize = 32_000; + +fn tool_error(message: impl Into) -> BitFunError { + BitFunError::tool(message.into()) +} + +fn truncate_text(text: &str, max_chars: usize) -> (String, bool) { + let truncated = text.chars().count() > max_chars; + let rendered = if truncated { + text.chars().take(max_chars).collect() + } else { + text.to_string() + }; + (rendered, truncated) +} + +async fn get_mcp_server_manager() -> BitFunResult> { + get_global_mcp_service() + .map(|service| service.server_manager()) + .ok_or_else(|| tool_error("MCP service is not initialized")) +} + +async fn list_resources_for_server( + manager: &Arc, + server_id: &str, + refresh: bool, +) -> BitFunResult> { + let mut resources = manager.get_cached_resources(server_id).await; + if refresh || resources.is_empty() { + manager.refresh_server_resource_catalog(server_id).await?; + resources = manager.get_cached_resources(server_id).await; + } + Ok(resources) +} + +async fn list_prompts_for_server( + manager: &Arc, + server_id: &str, + refresh: bool, +) -> BitFunResult> { + let mut prompts = manager.get_cached_prompts(server_id).await; + if refresh || prompts.is_empty() { + manager.refresh_server_prompt_catalog(server_id).await?; + prompts = manager.get_cached_prompts(server_id).await; + } + Ok(prompts) +} + +fn validate_required_string(input: &Value, field_name: &str) -> ValidationResult { + match input.get(field_name).and_then(|value| value.as_str()) { + Some(value) if !value.trim().is_empty() => ValidationResult::default(), + Some(_) => ValidationResult { + result: false, + message: Some(format!("{} cannot be empty", field_name)), + error_code: Some(400), + meta: None, + }, + None => ValidationResult { + result: false, + message: Some(format!("{} is required", field_name)), + error_code: Some(400), + meta: None, + }, + } +} + +fn render_resource_catalog(resources: &[MCPResource]) -> String { + if resources.is_empty() { + return "No MCP resources available.".to_string(); + } + + resources + .iter() + .map(|resource| { + let mut lines = vec![format!( + "- {} ({})", + resource.title.as_deref().unwrap_or(&resource.name), + resource.uri + )]; + if resource.title.as_deref() != Some(resource.name.as_str()) { + lines.push(format!(" Name: {}", resource.name)); + } + if let Some(description) = &resource.description { + lines.push(format!(" Description: {}", description)); + } + if let Some(mime_type) = &resource.mime_type { + lines.push(format!(" MIME type: {}", mime_type)); + } + if let Some(size) = resource.size { + lines.push(format!(" Size: {} bytes", size)); + } + lines.join("\n") + }) + .collect::>() + .join("\n\n") +} + +fn render_resource_contents(contents: &[MCPResourceContent], max_chars: usize) -> String { + let mut rendered = String::new(); + let mut remaining = max_chars; + let mut truncated_any = false; + + for (index, content) in contents.iter().enumerate() { + if index > 0 { + rendered.push_str("\n\n---\n\n"); + } + + rendered.push_str(&format!("Resource URI: {}", content.uri)); + if let Some(mime_type) = &content.mime_type { + rendered.push_str(&format!("\nMIME type: {}", mime_type)); + } + + if let Some(text) = &content.content { + let slice_limit = remaining.max(1); + let (text_chunk, truncated) = truncate_text(text, slice_limit); + rendered.push_str("\n\n"); + rendered.push_str(&text_chunk); + truncated_any |= truncated; + remaining = remaining.saturating_sub(text_chunk.chars().count()); + } else if content.blob.is_some() { + rendered.push_str("\n\n[Binary resource content omitted]"); + } else { + rendered.push_str("\n\n[Empty resource content]"); + } + + if remaining == 0 { + truncated_any = true; + break; + } + } + + if truncated_any { + rendered + .push_str("\n\n[Output truncated after reaching the MCP resource tool size limit.]"); + } + + rendered +} + +fn render_prompt_catalog(prompts: &[MCPPrompt]) -> String { + if prompts.is_empty() { + return "No MCP prompts available.".to_string(); + } + + prompts + .iter() + .map(|prompt| { + let mut lines = vec![format!( + "- {}", + prompt.title.as_deref().unwrap_or(&prompt.name) + )]; + if prompt.title.as_deref() != Some(prompt.name.as_str()) { + lines.push(format!(" Name: {}", prompt.name)); + } + if let Some(description) = &prompt.description { + lines.push(format!(" Description: {}", description)); + } + if let Some(arguments) = &prompt.arguments { + if !arguments.is_empty() { + let args = arguments + .iter() + .map(|argument| { + let required = if argument.required { + "required" + } else { + "optional" + }; + match &argument.description { + Some(description) => { + format!("{} ({}, {})", argument.name, required, description) + } + None => format!("{} ({})", argument.name, required), + } + }) + .collect::>() + .join(", "); + lines.push(format!(" Arguments: {}", args)); + } + } + lines.join("\n") + }) + .collect::>() + .join("\n\n") +} + +pub struct ListMCPResourcesTool; + +impl ListMCPResourcesTool { + pub fn new() -> Self { + Self + } +} + +#[async_trait] +impl Tool for ListMCPResourcesTool { + fn name(&self) -> &str { + "ListMCPResources" + } + + async fn description(&self) -> BitFunResult { + Ok("Lists MCP resources exposed by a connected MCP server. Use this before ReadMCPResource when you need to inspect available MCP-hosted files, docs, or structured context.".to_string()) + } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "server_id": { + "type": "string", + "description": "The MCP server ID to inspect." + }, + "refresh": { + "type": "boolean", + "description": "When true, refresh the server catalog before returning resources.", + "default": false + } + }, + "required": ["server_id"], + "additionalProperties": false + }) + } + + fn is_readonly(&self) -> bool { + true + } + + fn is_concurrency_safe(&self, _input: Option<&Value>) -> bool { + true + } + + fn needs_permissions(&self, _input: Option<&Value>) -> bool { + false + } + + async fn validate_input( + &self, + input: &Value, + _context: Option<&ToolUseContext>, + ) -> ValidationResult { + validate_required_string(input, "server_id") + } + + fn render_tool_use_message(&self, input: &Value, options: &ToolRenderOptions) -> String { + let server_id = input + .get("server_id") + .and_then(|value| value.as_str()) + .unwrap_or("unknown"); + if options.verbose { + format!("Listing MCP resources from server: {}", server_id) + } else { + format!("List MCP resources from {}", server_id) + } + } + + async fn call_impl( + &self, + input: &Value, + _context: &ToolUseContext, + ) -> BitFunResult> { + let server_id = input + .get("server_id") + .and_then(|value| value.as_str()) + .ok_or_else(|| tool_error("server_id is required"))?; + let refresh = input + .get("refresh") + .and_then(|value| value.as_bool()) + .unwrap_or(false); + + let manager = get_mcp_server_manager().await?; + let resources = list_resources_for_server(&manager, server_id, refresh).await?; + let count = resources.len(); + let rendered = render_resource_catalog(&resources); + + Ok(vec![ToolResult::ok( + json!({ + "server_id": server_id, + "resources": resources, + "count": count, + }), + Some(rendered), + )]) + } +} + +pub struct ReadMCPResourceTool { + max_render_chars: usize, +} + +impl ReadMCPResourceTool { + pub fn new() -> Self { + Self { + max_render_chars: DEFAULT_RENDER_CHAR_LIMIT, + } + } +} + +#[async_trait] +impl Tool for ReadMCPResourceTool { + fn name(&self) -> &str { + "ReadMCPResource" + } + + async fn description(&self) -> BitFunResult { + Ok("Reads a specific MCP resource by URI from a connected MCP server. Use ListMCPResources first if you do not already know the resource URI.".to_string()) + } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "server_id": { + "type": "string", + "description": "The MCP server ID that owns the resource." + }, + "uri": { + "type": "string", + "description": "The full MCP resource URI to read." + } + }, + "required": ["server_id", "uri"], + "additionalProperties": false + }) + } + + fn is_readonly(&self) -> bool { + true + } + + fn is_concurrency_safe(&self, _input: Option<&Value>) -> bool { + true + } + + fn needs_permissions(&self, _input: Option<&Value>) -> bool { + false + } + + async fn validate_input( + &self, + input: &Value, + _context: Option<&ToolUseContext>, + ) -> ValidationResult { + let server_validation = validate_required_string(input, "server_id"); + if !server_validation.result { + return server_validation; + } + validate_required_string(input, "uri") + } + + fn render_tool_use_message(&self, input: &Value, options: &ToolRenderOptions) -> String { + let uri = input + .get("uri") + .and_then(|value| value.as_str()) + .unwrap_or("unknown"); + if options.verbose { + format!("Reading MCP resource: {}", uri) + } else { + format!("Read MCP resource {}", uri) + } + } + + async fn call_impl( + &self, + input: &Value, + _context: &ToolUseContext, + ) -> BitFunResult> { + let server_id = input + .get("server_id") + .and_then(|value| value.as_str()) + .ok_or_else(|| tool_error("server_id is required"))?; + let uri = input + .get("uri") + .and_then(|value| value.as_str()) + .ok_or_else(|| tool_error("uri is required"))?; + + let manager = get_mcp_server_manager().await?; + let connection = manager + .get_connection(server_id) + .await + .ok_or_else(|| tool_error(format!("MCP server not connected: {}", server_id)))?; + let result = connection.read_resource(uri).await?; + let content_count = result.contents.len(); + let rendered = render_resource_contents(&result.contents, self.max_render_chars); + + Ok(vec![ToolResult::ok( + json!({ + "server_id": server_id, + "uri": uri, + "contents": result.contents, + "content_count": content_count, + }), + Some(rendered), + )]) + } +} + +pub struct ListMCPPromptsTool; + +impl ListMCPPromptsTool { + pub fn new() -> Self { + Self + } +} + +#[async_trait] +impl Tool for ListMCPPromptsTool { + fn name(&self) -> &str { + "ListMCPPrompts" + } + + async fn description(&self) -> BitFunResult { + Ok("Lists MCP prompts exposed by a connected MCP server. Use this before GetMCPPrompt when you need reusable server-provided prompt templates.".to_string()) + } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "server_id": { + "type": "string", + "description": "The MCP server ID to inspect." + }, + "refresh": { + "type": "boolean", + "description": "When true, refresh the server catalog before returning prompts.", + "default": false + } + }, + "required": ["server_id"], + "additionalProperties": false + }) + } + + fn is_readonly(&self) -> bool { + true + } + + fn is_concurrency_safe(&self, _input: Option<&Value>) -> bool { + true + } + + fn needs_permissions(&self, _input: Option<&Value>) -> bool { + false + } + + async fn validate_input( + &self, + input: &Value, + _context: Option<&ToolUseContext>, + ) -> ValidationResult { + validate_required_string(input, "server_id") + } + + fn render_tool_use_message(&self, input: &Value, options: &ToolRenderOptions) -> String { + let server_id = input + .get("server_id") + .and_then(|value| value.as_str()) + .unwrap_or("unknown"); + if options.verbose { + format!("Listing MCP prompts from server: {}", server_id) + } else { + format!("List MCP prompts from {}", server_id) + } + } + + async fn call_impl( + &self, + input: &Value, + _context: &ToolUseContext, + ) -> BitFunResult> { + let server_id = input + .get("server_id") + .and_then(|value| value.as_str()) + .ok_or_else(|| tool_error("server_id is required"))?; + let refresh = input + .get("refresh") + .and_then(|value| value.as_bool()) + .unwrap_or(false); + + let manager = get_mcp_server_manager().await?; + let prompts = list_prompts_for_server(&manager, server_id, refresh).await?; + let count = prompts.len(); + let rendered = render_prompt_catalog(&prompts); + + Ok(vec![ToolResult::ok( + json!({ + "server_id": server_id, + "prompts": prompts, + "count": count, + }), + Some(rendered), + )]) + } +} + +pub struct GetMCPPromptTool { + max_render_chars: usize, +} + +impl GetMCPPromptTool { + pub fn new() -> Self { + Self { + max_render_chars: DEFAULT_RENDER_CHAR_LIMIT, + } + } +} + +#[async_trait] +impl Tool for GetMCPPromptTool { + fn name(&self) -> &str { + "GetMCPPrompt" + } + + async fn description(&self) -> BitFunResult { + Ok("Fetches a named MCP prompt template from a connected MCP server and renders it into plain text for the model. Pass prompt arguments when the server requires them.".to_string()) + } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "server_id": { + "type": "string", + "description": "The MCP server ID that owns the prompt." + }, + "name": { + "type": "string", + "description": "The MCP prompt name." + }, + "arguments": { + "type": "object", + "description": "Optional string arguments for the prompt template.", + "additionalProperties": { + "type": "string" + } + } + }, + "required": ["server_id", "name"], + "additionalProperties": false + }) + } + + fn is_readonly(&self) -> bool { + true + } + + fn is_concurrency_safe(&self, _input: Option<&Value>) -> bool { + true + } + + fn needs_permissions(&self, _input: Option<&Value>) -> bool { + false + } + + async fn validate_input( + &self, + input: &Value, + _context: Option<&ToolUseContext>, + ) -> ValidationResult { + let server_validation = validate_required_string(input, "server_id"); + if !server_validation.result { + return server_validation; + } + + let name_validation = validate_required_string(input, "name"); + if !name_validation.result { + return name_validation; + } + + if let Some(arguments) = input.get("arguments") { + let Some(object) = arguments.as_object() else { + return ValidationResult { + result: false, + message: Some("arguments must be an object".to_string()), + error_code: Some(400), + meta: None, + }; + }; + + let invalid_keys = object + .iter() + .filter_map(|(key, value)| (!value.is_string()).then_some(key.clone())) + .collect::>(); + if !invalid_keys.is_empty() { + return ValidationResult { + result: false, + message: Some(format!( + "arguments values must be strings: {}", + invalid_keys.into_iter().collect::>().join(", ") + )), + error_code: Some(400), + meta: None, + }; + } + } + + ValidationResult::default() + } + + fn render_tool_use_message(&self, input: &Value, options: &ToolRenderOptions) -> String { + let name = input + .get("name") + .and_then(|value| value.as_str()) + .unwrap_or("unknown"); + if options.verbose { + format!("Fetching MCP prompt: {}", name) + } else { + format!("Get MCP prompt {}", name) + } + } + + async fn call_impl( + &self, + input: &Value, + _context: &ToolUseContext, + ) -> BitFunResult> { + let server_id = input + .get("server_id") + .and_then(|value| value.as_str()) + .ok_or_else(|| tool_error("server_id is required"))?; + let name = input + .get("name") + .and_then(|value| value.as_str()) + .ok_or_else(|| tool_error("name is required"))?; + + let arguments = input.get("arguments").and_then(|value| { + value.as_object().map(|object| { + object + .iter() + .filter_map(|(key, value)| { + value + .as_str() + .map(|string| (key.clone(), string.to_string())) + }) + .collect::>() + }) + }); + + let manager = get_mcp_server_manager().await?; + let connection = manager + .get_connection(server_id) + .await + .ok_or_else(|| tool_error(format!("MCP server not connected: {}", server_id)))?; + let result = connection.get_prompt(name, arguments.clone()).await?; + let prompt_text = + PromptAdapter::to_system_prompt(&crate::service::mcp::protocol::MCPPromptContent { + name: name.to_string(), + messages: result.messages.clone(), + }); + let (rendered_text, truncated) = truncate_text(&prompt_text, self.max_render_chars); + let mut rendered = rendered_text; + if truncated { + rendered + .push_str("\n\n[Output truncated after reaching the MCP prompt tool size limit.]"); + } + + Ok(vec![ToolResult::ok( + json!({ + "server_id": server_id, + "name": name, + "arguments": arguments, + "description": result.description, + "messages": result.messages, + "prompt_text": prompt_text, + }), + Some(rendered), + )]) + } +} diff --git a/src/crates/core/src/agentic/tools/implementations/mod.rs b/src/crates/core/src/agentic/tools/implementations/mod.rs index 91b1ea28..0435df36 100644 --- a/src/crates/core/src/agentic/tools/implementations/mod.rs +++ b/src/crates/core/src/agentic/tools/implementations/mod.rs @@ -23,6 +23,7 @@ pub mod grep_tool; pub mod log_tool; pub mod ls_tool; pub mod mermaid_interactive_tool; +pub mod mcp_tools; pub mod miniapp_init_tool; pub mod session_control_tool; pub mod session_message_tool; @@ -55,6 +56,9 @@ pub use grep_tool::GrepTool; pub use log_tool::LogTool; pub use ls_tool::LSTool; pub use mermaid_interactive_tool::MermaidInteractiveTool; +pub use mcp_tools::{ + GetMCPPromptTool, ListMCPPromptsTool, ListMCPResourcesTool, ReadMCPResourceTool, +}; pub use miniapp_init_tool::InitMiniAppTool; pub use session_control_tool::SessionControlTool; pub use session_message_tool::SessionMessageTool; diff --git a/src/crates/core/src/agentic/tools/registry.rs b/src/crates/core/src/agentic/tools/registry.rs index a0915c7c..d4799989 100644 --- a/src/crates/core/src/agentic/tools/registry.rs +++ b/src/crates/core/src/agentic/tools/registry.rs @@ -64,7 +64,7 @@ impl ToolRegistry { /// Remove all tools from the MCP server pub fn unregister_mcp_server_tools(&mut self, server_id: &str) { - let prefix = format!("mcp_{}_", server_id); + let prefix = format!("mcp__{}__", server_id); let to_remove: Vec = self .tools .keys() @@ -112,6 +112,10 @@ impl ToolRegistry { // Web tool self.register_tool(Arc::new(WebSearchTool::new())); self.register_tool(Arc::new(WebFetchTool::new())); + self.register_tool(Arc::new(ListMCPResourcesTool::new())); + self.register_tool(Arc::new(ReadMCPResourceTool::new())); + self.register_tool(Arc::new(ListMCPPromptsTool::new())); + self.register_tool(Arc::new(GetMCPPromptTool::new())); // Mermaid interactive chart tool self.register_tool(Arc::new(MermaidInteractiveTool::new())); diff --git a/src/crates/core/src/service/mcp/adapter/context.rs b/src/crates/core/src/service/mcp/adapter/context.rs index 7fd4738d..eb286bf2 100644 --- a/src/crates/core/src/service/mcp/adapter/context.rs +++ b/src/crates/core/src/service/mcp/adapter/context.rs @@ -184,12 +184,23 @@ impl MCPContextProvider { BitFunError::NotFound(format!("MCP server connection not found: {}", server_id)) })?; - let result = connection.list_resources(None).await?; + let mut resources = manager.get_cached_resources(server_id).await; + if resources.is_empty() { + if let Err(e) = manager.refresh_server_resource_catalog(server_id).await { + debug!( + "Failed to refresh resources catalog cache; falling back to direct list: server_id={} error={}", + server_id, e + ); + } + resources = manager.get_cached_resources(server_id).await; + } + + if resources.is_empty() { + resources = connection.list_resources(None).await?.resources; + } let relevant = ResourceAdapter::filter_and_rank( - result.resources, - query, - 0.1, // Lower threshold; we do additional filtering later + resources, query, 0.1, // Lower threshold; we do additional filtering later 50, // Up to 50 per server ); @@ -252,21 +263,33 @@ impl MCPContextProvider { for server_id in server_ids { if let Some(connection) = self.server_manager.get_connection(&server_id).await { - if let Ok(result) = connection.list_prompts(None).await { - for prompt in result.prompts { - if prompt_names.contains(&prompt.name) { - if let Ok(content) = connection - .get_prompt(&prompt.name, Some(arguments.clone())) - .await - { - let text = super::prompt::PromptAdapter::to_system_prompt( - &crate::service::mcp::protocol::MCPPromptContent { - name: prompt.name.clone(), - messages: content.messages, - }, - ); - enhancements.push(text); - } + let mut prompts = self.server_manager.get_cached_prompts(&server_id).await; + if prompts.is_empty() { + let _ = self + .server_manager + .refresh_server_prompt_catalog(&server_id) + .await; + prompts = self.server_manager.get_cached_prompts(&server_id).await; + } + if prompts.is_empty() { + if let Ok(result) = connection.list_prompts(None).await { + prompts = result.prompts; + } + } + + for prompt in prompts { + if prompt_names.contains(&prompt.name) { + if let Ok(content) = connection + .get_prompt(&prompt.name, Some(arguments.clone())) + .await + { + let text = super::prompt::PromptAdapter::to_system_prompt( + &crate::service::mcp::protocol::MCPPromptContent { + name: prompt.name.clone(), + messages: content.messages, + }, + ); + enhancements.push(text); } } } diff --git a/src/crates/core/src/service/mcp/adapter/mod.rs b/src/crates/core/src/service/mcp/adapter/mod.rs index 40f7725d..e6ad72ab 100644 --- a/src/crates/core/src/service/mcp/adapter/mod.rs +++ b/src/crates/core/src/service/mcp/adapter/mod.rs @@ -2,10 +2,10 @@ //! //! Adapts MCP resources, prompts, and tools to BitFun's agentic system. -pub mod context; -pub mod prompt; -pub mod resource; -pub mod tool; +mod context; +mod prompt; +mod resource; +mod tool; pub use context::{ContextEnhancer, MCPContextProvider}; pub use prompt::PromptAdapter; diff --git a/src/crates/core/src/service/mcp/adapter/tool.rs b/src/crates/core/src/service/mcp/adapter/tool.rs index 6ce1f2d1..11a93196 100644 --- a/src/crates/core/src/service/mcp/adapter/tool.rs +++ b/src/crates/core/src/service/mcp/adapter/tool.rs @@ -6,7 +6,7 @@ use crate::agentic::tools::framework::{ Tool, ToolRenderOptions, ToolResult, ToolUseContext, ValidationResult, }; use crate::service::mcp::protocol::{MCPTool, MCPToolResult}; -use crate::service::mcp::server::connection::MCPConnection; +use crate::service::mcp::server::MCPConnection; use crate::util::errors::BitFunResult; use async_trait::async_trait; use log::{debug, error, info, warn}; @@ -22,6 +22,8 @@ pub struct MCPToolWrapper { } impl MCPToolWrapper { + const MAX_RESULT_TEXT_CHARS: usize = 12_000; + /// Creates a new MCP tool wrapper. pub fn new( mcp_tool: MCPTool, @@ -29,7 +31,7 @@ impl MCPToolWrapper { server_id: String, server_name: String, ) -> Self { - let full_name = format!("mcp_{}_{}", server_id, mcp_tool.name); + let full_name = format!("mcp__{}__{}", server_id, mcp_tool.name); Self { mcp_tool, connection, @@ -37,23 +39,73 @@ impl MCPToolWrapper { full_name, } } + + fn annotations(&self) -> crate::service::mcp::protocol::MCPToolAnnotations { + self.mcp_tool.annotations.clone().unwrap_or_default() + } + + fn tool_title(&self) -> String { + self.mcp_tool + .annotations + .as_ref() + .and_then(|annotations| annotations.title.clone()) + .or_else(|| self.mcp_tool.title.clone()) + .unwrap_or_else(|| self.mcp_tool.name.clone()) + } + + fn behavior_hints(&self) -> Vec<&'static str> { + let annotations = self.annotations(); + let mut hints = Vec::new(); + if annotations.read_only_hint.unwrap_or(false) { + hints.push("read-only"); + } + if annotations.destructive_hint.unwrap_or(false) { + hints.push("destructive"); + } + if annotations.open_world_hint.unwrap_or(false) { + hints.push("open-world"); + } + hints + } + + fn truncate_for_assistant(text: String) -> String { + let char_count = text.chars().count(); + if char_count <= Self::MAX_RESULT_TEXT_CHARS { + return text; + } + + let truncated: String = text.chars().take(Self::MAX_RESULT_TEXT_CHARS).collect(); + format!( + "{}\n[Result truncated: {} of {} characters shown]", + truncated, + Self::MAX_RESULT_TEXT_CHARS, + char_count + ) + } } #[async_trait] impl Tool for MCPToolWrapper { fn name(&self) -> &str { // Use server_id as a prefix to avoid naming conflicts. - // Example: mcp_github_search_repos + // Example: mcp__github__search_repos &self.full_name } async fn description(&self) -> BitFunResult { - Ok(format!( + let mut description = format!( "Tool '{}' from MCP server '{}': {}", - self.mcp_tool.name, + self.tool_title(), self.server_name, self.mcp_tool.description.as_deref().unwrap_or("") - )) + ); + + let hints = self.behavior_hints(); + if !hints.is_empty() { + description.push_str(&format!(" [Hints: {}]", hints.join(", "))); + } + + Ok(description) } fn input_schema(&self) -> Value { @@ -69,7 +121,7 @@ impl Tool for MCPToolWrapper { } fn user_facing_name(&self) -> String { - format!("{} ({})", self.mcp_tool.name, self.server_name) + format!("{} ({})", self.tool_title(), self.server_name) } async fn is_enabled(&self) -> bool { @@ -77,17 +129,15 @@ impl Tool for MCPToolWrapper { } fn is_readonly(&self) -> bool { - // MCP tools are non-readonly by default (requires permission confirmation). - false + self.annotations().read_only_hint.unwrap_or(false) } fn is_concurrency_safe(&self, _input: Option<&Value>) -> bool { - false + self.is_readonly() } fn needs_permissions(&self, _input: Option<&Value>) -> bool { - // MCP tools require permissions by default. - true + !self.is_readonly() } async fn validate_input( @@ -119,7 +169,7 @@ impl Tool for MCPToolWrapper { } if let Some(contents) = result.content { - return contents + let rendered = contents .iter() .map(|c| match c { crate::service::mcp::protocol::MCPToolResultContent::Text { text } => { @@ -149,6 +199,11 @@ impl Tool for MCPToolWrapper { }) .collect::>() .join("\n"); + return Self::truncate_for_assistant(rendered); + } + + if let Some(structured_content) = result.structured_content { + return Self::truncate_for_assistant(structured_content.to_string()); } } @@ -158,21 +213,24 @@ impl Tool for MCPToolWrapper { fn render_tool_use_message(&self, input: &Value, _options: &ToolRenderOptions) -> String { format!( "Using MCP tool '{}' from '{}' with input: {}", - self.mcp_tool.name, self.server_name, input + self.tool_title(), + self.server_name, + input ) } fn render_tool_use_rejected_message(&self) -> String { format!( "MCP tool '{}' from '{}' was rejected by user", - self.mcp_tool.name, self.server_name + self.tool_title(), + self.server_name ) } fn render_tool_result_message(&self, output: &Value) -> String { format!( "MCP tool '{}' completed. Result: {}", - self.mcp_tool.name, + self.tool_title(), self.render_result_for_assistant(output) ) } @@ -184,7 +242,8 @@ impl Tool for MCPToolWrapper { ) -> BitFunResult> { info!( "Calling MCP tool: {} from server: {}", - self.mcp_tool.name, self.server_name + self.tool_title(), + self.server_name ); debug!( "Input: {}", diff --git a/src/crates/core/src/service/mcp/auth.rs b/src/crates/core/src/service/mcp/auth.rs new file mode 100644 index 00000000..f105ec01 --- /dev/null +++ b/src/crates/core/src/service/mcp/auth.rs @@ -0,0 +1,431 @@ +//! OAuth support for remote MCP servers. + +use aes_gcm::aead::{Aead, KeyInit}; +use aes_gcm::{Aes256Gcm, Nonce}; +use anyhow::{Context, Result}; +use async_trait::async_trait; +use base64::{Engine, engine::general_purpose::STANDARD as B64}; +use rand::RngCore; +use rmcp::transport::auth::{ + AuthorizationManager, CredentialStore, OAuthState, StoredCredentials, +}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::PathBuf; +use tokio::net::TcpListener; +use tokio::sync::Mutex; + +use crate::infrastructure::filesystem::path_manager::try_get_path_manager_arc; +use crate::service::mcp::server::{MCPServerConfig, MCPServerOAuthConfig}; +use crate::util::errors::{BitFunError, BitFunResult}; + +const NONCE_LEN: usize = 12; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "camelCase")] +pub enum MCPRemoteOAuthStatus { + AwaitingBrowser, + AwaitingCallback, + ExchangingToken, + Authorized, + Failed, + Cancelled, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct MCPRemoteOAuthSessionSnapshot { + pub server_id: String, + pub status: MCPRemoteOAuthStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub authorization_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub redirect_uri: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +impl MCPRemoteOAuthSessionSnapshot { + pub fn new( + server_id: impl Into, + status: MCPRemoteOAuthStatus, + authorization_url: Option, + redirect_uri: Option, + message: Option, + ) -> Self { + Self { + server_id: server_id.into(), + status, + authorization_url, + redirect_uri, + message, + } + } +} + +pub struct PreparedMCPRemoteOAuthAuthorization { + pub state: OAuthState, + pub listener: TcpListener, + pub authorization_url: String, + pub redirect_uri: String, +} + +#[derive(Serialize, Deserialize, Default)] +struct VaultFile { + entries: HashMap, +} + +pub struct MCPRemoteOAuthCredentialVault { + key_path: PathBuf, + vault_path: PathBuf, + lock: Mutex<()>, +} + +impl MCPRemoteOAuthCredentialVault { + pub fn new() -> BitFunResult { + let data_dir = try_get_path_manager_arc()?.user_data_dir(); + Ok(Self { + key_path: data_dir.join(".mcp_oauth_vault.key"), + vault_path: data_dir.join("mcp_oauth_vault.json"), + lock: Mutex::new(()), + }) + } + + async fn ensure_key(&self) -> Result<[u8; 32]> { + if self.key_path.exists() { + let bytes = tokio::fs::read(&self.key_path) + .await + .context("read MCP OAuth vault key")?; + if bytes.len() != 32 { + anyhow::bail!("invalid MCP OAuth vault key length"); + } + let mut key = [0u8; 32]; + key.copy_from_slice(&bytes); + return Ok(key); + } + + if let Some(parent) = self.key_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + + let mut key = [0u8; 32]; + rand::rngs::OsRng.fill_bytes(&mut key); + tokio::fs::write(&self.key_path, key.as_slice()) + .await + .context("write MCP OAuth vault key")?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let _ = std::fs::set_permissions( + &self.key_path, + std::fs::Permissions::from_mode(0o600), + ); + } + + Ok(key) + } + + fn encrypt_value(key: &[u8; 32], plaintext: &str) -> Result { + let cipher = Aes256Gcm::new_from_slice(key).map_err(|e| anyhow::anyhow!("{}", e))?; + let mut nonce = [0u8; NONCE_LEN]; + rand::rngs::OsRng.fill_bytes(&mut nonce); + let ciphertext = cipher + .encrypt(Nonce::from_slice(&nonce), plaintext.as_bytes()) + .map_err(|e| anyhow::anyhow!("encrypt: {}", e))?; + + let mut blob = Vec::with_capacity(NONCE_LEN + ciphertext.len()); + blob.extend_from_slice(&nonce); + blob.extend_from_slice(&ciphertext); + Ok(B64.encode(blob)) + } + + fn decrypt_value(key: &[u8; 32], blob_b64: &str) -> Result { + let blob = B64 + .decode(blob_b64) + .context("base64 decode MCP OAuth vault entry")?; + if blob.len() <= NONCE_LEN { + anyhow::bail!("MCP OAuth vault entry too short"); + } + + let (nonce, ciphertext) = blob.split_at(NONCE_LEN); + let cipher = Aes256Gcm::new_from_slice(key).map_err(|e| anyhow::anyhow!("{}", e))?; + let plaintext = cipher + .decrypt(Nonce::from_slice(nonce), ciphertext) + .map_err(|e| anyhow::anyhow!("decrypt: {}", e))?; + String::from_utf8(plaintext).context("utf8 decode MCP OAuth vault entry") + } + + pub async fn load(&self, server_id: &str) -> Result> { + let _guard = self.lock.lock().await; + if !self.key_path.exists() || !self.vault_path.exists() { + return Ok(None); + } + + let bytes = tokio::fs::read(&self.key_path) + .await + .context("read MCP OAuth vault key")?; + if bytes.len() != 32 { + anyhow::bail!("invalid MCP OAuth vault key length"); + } + let mut key = [0u8; 32]; + key.copy_from_slice(&bytes); + + let body = tokio::fs::read_to_string(&self.vault_path) + .await + .unwrap_or_default(); + let file: VaultFile = serde_json::from_str(&body).unwrap_or_default(); + let Some(entry) = file.entries.get(server_id) else { + return Ok(None); + }; + + let plaintext = match Self::decrypt_value(&key, entry) { + Ok(plaintext) => plaintext, + Err(error) => { + log::warn!( + "Failed to decrypt MCP OAuth credentials for server {}: {}", + server_id, + error + ); + return Ok(None); + } + }; + + Ok(Some(serde_json::from_str(&plaintext)?)) + } + + pub async fn store(&self, server_id: &str, credentials: &StoredCredentials) -> Result<()> { + let _guard = self.lock.lock().await; + let key = self.ensure_key().await?; + + let mut file: VaultFile = if self.vault_path.exists() { + let body = tokio::fs::read_to_string(&self.vault_path) + .await + .unwrap_or_default(); + serde_json::from_str(&body).unwrap_or_default() + } else { + VaultFile::default() + }; + + let plaintext = serde_json::to_string(credentials)?; + let encrypted = Self::encrypt_value(&key, &plaintext)?; + file.entries.insert(server_id.to_string(), encrypted); + + if let Some(parent) = self.vault_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + + tokio::fs::write(&self.vault_path, serde_json::to_string_pretty(&file)?) + .await + .context("write MCP OAuth vault")?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let _ = std::fs::set_permissions( + &self.vault_path, + std::fs::Permissions::from_mode(0o600), + ); + } + + Ok(()) + } + + pub async fn clear(&self, server_id: &str) -> Result<()> { + let _guard = self.lock.lock().await; + if !self.vault_path.exists() { + return Ok(()); + } + + let body = tokio::fs::read_to_string(&self.vault_path) + .await + .unwrap_or_default(); + let mut file: VaultFile = serde_json::from_str(&body).unwrap_or_default(); + file.entries.remove(server_id); + + if file.entries.is_empty() { + let _ = tokio::fs::remove_file(&self.vault_path).await; + } else { + tokio::fs::write(&self.vault_path, serde_json::to_string_pretty(&file)?).await?; + } + + Ok(()) + } +} + +#[derive(Clone)] +pub struct MCPRemoteOAuthCredentialStore { + server_id: String, +} + +impl MCPRemoteOAuthCredentialStore { + pub fn new(server_id: impl Into) -> Self { + Self { + server_id: server_id.into(), + } + } +} + +#[async_trait] +impl CredentialStore for MCPRemoteOAuthCredentialStore { + async fn load(&self) -> Result, rmcp::transport::auth::AuthError> { + MCPRemoteOAuthCredentialVault::new() + .map_err(|error| rmcp::transport::auth::AuthError::InternalError(error.to_string()))? + .load(&self.server_id) + .await + .map_err(|error| rmcp::transport::auth::AuthError::InternalError(error.to_string())) + } + + async fn save( + &self, + credentials: StoredCredentials, + ) -> Result<(), rmcp::transport::auth::AuthError> { + MCPRemoteOAuthCredentialVault::new() + .map_err(|error| rmcp::transport::auth::AuthError::InternalError(error.to_string()))? + .store(&self.server_id, &credentials) + .await + .map_err(|error| rmcp::transport::auth::AuthError::InternalError(error.to_string())) + } + + async fn clear(&self) -> Result<(), rmcp::transport::auth::AuthError> { + MCPRemoteOAuthCredentialVault::new() + .map_err(|error| rmcp::transport::auth::AuthError::InternalError(error.to_string()))? + .clear(&self.server_id) + .await + .map_err(|error| rmcp::transport::auth::AuthError::InternalError(error.to_string())) + } +} + +pub fn map_auth_error(error: impl ToString) -> BitFunError { + BitFunError::MCPError(format!("OAuth error: {}", error.to_string())) +} + +pub async fn has_stored_oauth_credentials(server_id: &str) -> BitFunResult { + let store = MCPRemoteOAuthCredentialStore::new(server_id.to_string()); + let credentials = store.load().await.map_err(map_auth_error)?; + Ok(credentials.and_then(|entry| entry.token_response).is_some()) +} + +pub async fn clear_stored_oauth_credentials(server_id: &str) -> BitFunResult<()> { + MCPRemoteOAuthCredentialStore::new(server_id.to_string()) + .clear() + .await + .map_err(map_auth_error) +} + +pub async fn build_authorization_manager( + server_id: &str, + server_url: &str, +) -> BitFunResult<(AuthorizationManager, bool)> { + let mut manager = AuthorizationManager::new(server_url) + .await + .map_err(map_auth_error)?; + manager.set_credential_store(MCPRemoteOAuthCredentialStore::new(server_id.to_string())); + let initialized = manager + .initialize_from_store() + .await + .map_err(map_auth_error)?; + Ok((manager, initialized)) +} + +fn normalize_callback_host(config: &MCPServerOAuthConfig) -> String { + config + .callback_host + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .unwrap_or("127.0.0.1") + .to_string() +} + +fn normalize_callback_path(config: &MCPServerOAuthConfig) -> String { + let path = config + .callback_path + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .unwrap_or("/oauth/callback"); + + if path.starts_with('/') { + path.to_string() + } else { + format!("/{}", path) + } +} + +fn effective_oauth_config(config: &MCPServerConfig) -> MCPServerOAuthConfig { + let mut oauth = config.oauth.clone().unwrap_or_default(); + if oauth.client_name.is_none() { + oauth.client_name = Some(format!("BitFun MCP Client ({})", config.name)); + } + oauth +} + +pub async fn prepare_remote_oauth_authorization( + config: &MCPServerConfig, +) -> BitFunResult { + let oauth = effective_oauth_config(config); + let server_url = config.url.as_deref().ok_or_else(|| { + BitFunError::Configuration(format!( + "Remote MCP server '{}' must have a URL for OAuth", + config.id + )) + })?; + + let host = normalize_callback_host(&oauth); + let listener = TcpListener::bind((host.as_str(), oauth.callback_port.unwrap_or(0))) + .await + .map_err(|error| { + BitFunError::MCPError(format!( + "Failed to bind OAuth callback listener for server '{}': {}", + config.id, error + )) + })?; + let port = listener + .local_addr() + .map_err(|error| { + BitFunError::MCPError(format!( + "Failed to resolve OAuth callback listener for server '{}': {}", + config.id, error + )) + })? + .port(); + let redirect_uri = format!("http://{}:{}{}", host, port, normalize_callback_path(&oauth)); + + let scopes = oauth.scopes.iter().map(String::as_str).collect::>(); + let mut state = OAuthState::new(server_url, None) + .await + .map_err(map_auth_error)?; + if let OAuthState::Unauthorized(manager) = &mut state { + manager.set_credential_store(MCPRemoteOAuthCredentialStore::new(config.id.clone())); + } + + match oauth.client_metadata_url.as_deref() { + Some(client_metadata_url) => { + state + .start_authorization_with_metadata_url( + &scopes, + &redirect_uri, + oauth.client_name.as_deref(), + Some(client_metadata_url), + ) + .await + .map_err(map_auth_error)?; + } + None => { + state + .start_authorization(&scopes, &redirect_uri, oauth.client_name.as_deref()) + .await + .map_err(map_auth_error)?; + } + } + + let authorization_url = state.get_authorization_url().await.map_err(map_auth_error)?; + + Ok(PreparedMCPRemoteOAuthAuthorization { + state, + listener, + authorization_url, + redirect_uri, + }) +} diff --git a/src/crates/core/src/service/mcp/config/cursor_format.rs b/src/crates/core/src/service/mcp/config/cursor_format.rs index 959979cd..955ea22d 100644 --- a/src/crates/core/src/service/mcp/config/cursor_format.rs +++ b/src/crates/core/src/service/mcp/config/cursor_format.rs @@ -1,16 +1,54 @@ use log::warn; -use crate::service::mcp::server::{MCPServerConfig, MCPServerType}; +use crate::service::mcp::server::{MCPServerConfig, MCPServerTransport, MCPServerType}; use crate::util::errors::BitFunResult; use super::ConfigLocation; +fn parse_source(value: &str) -> Option { + match value.trim() { + "local" => Some(MCPServerType::Local), + "remote" => Some(MCPServerType::Remote), + _ => None, + } +} + +fn parse_transport(value: &str) -> Option { + match value.trim() { + "stdio" => Some(MCPServerTransport::Stdio), + "sse" => Some(MCPServerTransport::Sse), + "http" | "streamable_http" | "streamable-http" | "streamablehttp" => { + Some(MCPServerTransport::StreamableHttp) + } + _ => None, + } +} + +fn parse_legacy_type(value: &str) -> Option<(Option, Option)> { + match value.trim() { + "stdio" => Some((None, Some(MCPServerTransport::Stdio))), + "local" => Some((Some(MCPServerType::Local), Some(MCPServerTransport::Stdio))), + "sse" => Some((Some(MCPServerType::Remote), Some(MCPServerTransport::Sse))), + "remote" => Some(( + Some(MCPServerType::Remote), + Some(MCPServerTransport::StreamableHttp), + )), + "http" | "streamable_http" | "streamable-http" | "streamablehttp" => Some(( + Some(MCPServerType::Remote), + Some(MCPServerTransport::StreamableHttp), + )), + _ => None, + } +} + pub(super) fn config_to_cursor_format(config: &MCPServerConfig) -> serde_json::Value { let mut cursor_config = serde_json::Map::new(); - let type_str = match config.server_type { - MCPServerType::Local | MCPServerType::Container => "stdio", - MCPServerType::Remote => "streamable-http", + let type_str = match (config.server_type, config.resolved_transport()) { + (MCPServerType::Local, _) => "stdio", + (MCPServerType::Remote, MCPServerTransport::Sse) => "sse", + (MCPServerType::Remote, MCPServerTransport::StreamableHttp) => "streamable-http", + (MCPServerType::Remote, MCPServerTransport::Stdio) => "streamable-http", }; cursor_config.insert("type".to_string(), serde_json::json!(type_str)); @@ -44,6 +82,14 @@ pub(super) fn config_to_cursor_format(config: &MCPServerConfig) -> serde_json::V cursor_config.insert("url".to_string(), serde_json::json!(url)); } + if let Some(oauth) = &config.oauth { + cursor_config.insert("oauth".to_string(), serde_json::json!(oauth)); + } + + if let Some(xaa) = &config.xaa { + cursor_config.insert("xaa".to_string(), serde_json::json!(xaa)); + } + serde_json::Value::Object(cursor_config) } @@ -55,25 +101,6 @@ pub(super) fn parse_cursor_format( if let Some(mcp_servers) = config.get("mcpServers").and_then(|v| v.as_object()) { for (server_id, server_config) in mcp_servers { if let Some(obj) = server_config.as_object() { - let server_type = match obj.get("type").and_then(|v| v.as_str()) { - Some("stdio") => MCPServerType::Local, - Some("sse") => MCPServerType::Remote, - Some("streamable-http") => MCPServerType::Remote, - Some("streamable_http") => MCPServerType::Remote, - Some("streamablehttp") => MCPServerType::Remote, - Some("remote") => MCPServerType::Remote, - Some("http") => MCPServerType::Remote, - Some("local") => MCPServerType::Local, - Some("container") => MCPServerType::Container, - _ => { - if obj.contains_key("url") { - MCPServerType::Remote - } else { - MCPServerType::Local - } - } - }; - let command = obj .get("command") .and_then(|v| v.as_str()) @@ -116,6 +143,65 @@ pub(super) fn parse_cursor_format( .and_then(|v| v.as_str()) .map(|s| s.to_string()); + let explicit_source_value = obj.get("source").and_then(|v| v.as_str()); + let explicit_source = match explicit_source_value { + Some(value) => match parse_source(value) { + Some(parsed) => Some(parsed), + None => { + warn!( + "Unsupported MCP source for server '{}': {}", + server_id, value + ); + continue; + } + }, + None => None, + }; + let explicit_transport_value = obj.get("transport").and_then(|v| v.as_str()); + let explicit_transport = match explicit_transport_value { + Some(value) => match parse_transport(value) { + Some(parsed) => Some(parsed), + None => { + warn!( + "Unsupported MCP transport for server '{}': {}", + server_id, value + ); + continue; + } + }, + None => None, + }; + let legacy_type_value = obj.get("type").and_then(|v| v.as_str()); + let legacy_type = match legacy_type_value { + Some(value) => match parse_legacy_type(value) { + Some(parsed) => Some(parsed), + None => { + warn!( + "Unsupported MCP type for server '{}': {}", + server_id, value + ); + continue; + } + }, + None => None, + }; + + let server_type = explicit_source + .or_else(|| legacy_type.and_then(|(source, _)| source)) + .unwrap_or_else(|| { + if url.is_some() { + MCPServerType::Remote + } else { + MCPServerType::Local + } + }); + let transport = explicit_transport + .or_else(|| legacy_type.and_then(|(_, transport)| transport)) + .unwrap_or(match server_type { + MCPServerType::Local => MCPServerTransport::Stdio, + MCPServerType::Remote => MCPServerTransport::StreamableHttp, + }); + let name = obj .get("name") .and_then(|v| v.as_str()) @@ -134,6 +220,7 @@ pub(super) fn parse_cursor_format( id: server_id.clone(), name, server_type, + transport: Some(transport), command, args, env, @@ -144,6 +231,14 @@ pub(super) fn parse_cursor_format( location: ConfigLocation::User, capabilities: Vec::new(), settings: Default::default(), + oauth: obj + .get("oauth") + .cloned() + .and_then(|value| serde_json::from_value(value).ok()), + xaa: obj + .get("xaa") + .cloned() + .and_then(|value| serde_json::from_value(value).ok()), }; servers.push(server_config); @@ -155,3 +250,94 @@ pub(super) fn parse_cursor_format( Ok(servers) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn config_to_cursor_format_emits_stdio_for_local_and_sse_for_remote() { + let local = MCPServerConfig { + id: "local".to_string(), + name: "local".to_string(), + server_type: MCPServerType::Local, + transport: Some(MCPServerTransport::Stdio), + command: Some("docker".to_string()), + args: Vec::new(), + env: Default::default(), + headers: Default::default(), + url: None, + auto_start: true, + enabled: true, + location: ConfigLocation::User, + capabilities: Vec::new(), + settings: Default::default(), + oauth: None, + xaa: None, + }; + let sse = MCPServerConfig { + id: "remote".to_string(), + name: "remote".to_string(), + server_type: MCPServerType::Remote, + transport: Some(MCPServerTransport::Sse), + command: None, + args: Vec::new(), + env: Default::default(), + headers: Default::default(), + url: Some("https://example.com/sse".to_string()), + auto_start: true, + enabled: true, + location: ConfigLocation::User, + capabilities: Vec::new(), + settings: Default::default(), + oauth: None, + xaa: None, + }; + + assert_eq!( + config_to_cursor_format(&local) + .get("type") + .and_then(|v| v.as_str()), + Some("stdio") + ); + assert_eq!( + config_to_cursor_format(&sse) + .get("type") + .and_then(|v| v.as_str()), + Some("sse") + ); + } + + #[test] + fn parse_cursor_format_preserves_remote_transport() { + let config = serde_json::json!({ + "mcpServers": { + "remote-sse": { + "type": "sse", + "url": "https://example.com/sse" + } + } + }); + + let parsed = parse_cursor_format(&config).expect("parse should succeed"); + assert_eq!(parsed.len(), 1); + assert_eq!(parsed[0].server_type, MCPServerType::Remote); + assert_eq!(parsed[0].transport, Some(MCPServerTransport::Sse)); + } + + #[test] + fn parse_cursor_format_rejects_container_type() { + let config = serde_json::json!({ + "mcpServers": { + "docker-server": { + "type": "container", + "command": "docker", + "args": ["run", "--rm", "-i", "example/server"] + } + } + }); + + let parsed = parse_cursor_format(&config).expect("parse should succeed"); + assert!(parsed.is_empty()); + } +} diff --git a/src/crates/core/src/service/mcp/config/json_config.rs b/src/crates/core/src/service/mcp/config/json_config.rs index 2d06f9b8..392395e9 100644 --- a/src/crates/core/src/service/mcp/config/json_config.rs +++ b/src/crates/core/src/service/mcp/config/json_config.rs @@ -5,6 +5,38 @@ use crate::util::errors::{BitFunError, BitFunResult}; use super::service::MCPConfigService; impl MCPConfigService { + fn normalize_source(value: &str) -> Option<&'static str> { + match value.trim() { + "local" => Some("local"), + "remote" => Some("remote"), + _ => None, + } + } + + fn normalize_transport(value: &str) -> Option<&'static str> { + match value.trim() { + "stdio" => Some("stdio"), + "sse" => Some("sse"), + "http" | "streamable_http" | "streamable-http" | "streamablehttp" => { + Some("streamable-http") + } + _ => None, + } + } + + fn normalize_legacy_type(value: &str) -> Option<(Option<&'static str>, Option<&'static str>)> { + match value.trim() { + "stdio" => Some((None, Some("stdio"))), + "local" => Some((Some("local"), Some("stdio"))), + "sse" => Some((Some("remote"), Some("sse"))), + "remote" => Some((Some("remote"), Some("streamable-http"))), + "http" | "streamable_http" | "streamable-http" | "streamablehttp" => { + Some((Some("remote"), Some("streamable-http"))) + } + _ => None, + } + } + /// Loads MCP JSON config (Cursor format). pub async fn load_mcp_json_config(&self) -> BitFunResult { match self @@ -75,6 +107,16 @@ impl MCPConfigService { .and_then(|v| v.as_str()) .map(|s| s.trim()) .filter(|s| !s.is_empty()); + let source_str = obj + .get("source") + .and_then(|v| v.as_str()) + .map(|s| s.trim()) + .filter(|s| !s.is_empty()); + let transport_str = obj + .get("transport") + .and_then(|v| v.as_str()) + .map(|s| s.trim()) + .filter(|s| !s.is_empty()); let command = obj .get("command") @@ -88,7 +130,7 @@ impl MCPConfigService { .map(|s| s.trim()) .filter(|s| !s.is_empty()); - let inferred_transport = match (command.is_some(), url.is_some()) { + match (command.is_some(), url.is_some()) { (true, true) => { let error_msg = format!( "Server '{}' must not set both 'command' and 'url' fields", @@ -97,8 +139,6 @@ impl MCPConfigService { error!("{}", error_msg); return Err(BitFunError::validation(error_msg)); } - (true, false) => "stdio", - (false, true) => "streamable-http", (false, false) => { let error_msg = format!( "Server '{}' must provide either 'command' (stdio) or 'url' (streamable-http)", @@ -107,34 +147,98 @@ impl MCPConfigService { error!("{}", error_msg); return Err(BitFunError::validation(error_msg)); } + _ => {} + } + + let legacy_type = match type_str { + Some(value) => Self::normalize_legacy_type(value).ok_or_else(|| { + BitFunError::validation(format!( + "Server '{}' has unsupported 'type' value: '{}'", + server_id, value + )) + })?, + None => (None, None), }; - if let Some(t) = type_str { - let normalized_transport = match t { - "stdio" | "local" | "container" => "stdio", - "sse" | "remote" | "http" | "streamable_http" | "streamable-http" - | "streamablehttp" => "streamable-http", - _ => { + let explicit_source = match source_str { + Some(value) => Some(Self::normalize_source(value).ok_or_else(|| { + BitFunError::validation(format!( + "Server '{}' has unsupported 'source' value: '{}'", + server_id, value + )) + })?), + None => legacy_type.0, + }; + let explicit_transport = match transport_str { + Some(value) => Some(Self::normalize_transport(value).ok_or_else(|| { + BitFunError::validation(format!( + "Server '{}' has unsupported 'transport' value: '{}'", + server_id, value + )) + })?), + None => legacy_type.1, + }; + + let effective_source = match (command.is_some(), url.is_some()) { + (true, false) => match explicit_source { + Some("remote") => { + let error_msg = format!( + "Server '{}' source='remote' conflicts with command-based configuration", + server_id + ); + error!("{}", error_msg); + return Err(BitFunError::validation(error_msg)); + } + Some(source) => source, + None => "local", + }, + (false, true) => match explicit_source { + Some("local") => { let error_msg = format!( - "Server '{}' has unsupported 'type' value: '{}'", - server_id, t + "Server '{}' source='{}' conflicts with url-based configuration", + server_id, + explicit_source.unwrap_or("unknown") ); error!("{}", error_msg); return Err(BitFunError::validation(error_msg)); } - }; + Some(source) => source, + None => "remote", + }, + _ => unreachable!(), + }; - if normalized_transport != inferred_transport { - let error_msg = format!( - "Server '{}' 'type' conflicts with provided fields (type='{}')", - server_id, t - ); - error!("{}", error_msg); - return Err(BitFunError::validation(error_msg)); + let effective_transport = match effective_source { + "local" => { + if let Some(transport) = explicit_transport { + if transport != "stdio" { + let error_msg = format!( + "Server '{}' source='{}' must use stdio transport", + server_id, effective_source + ); + error!("{}", error_msg); + return Err(BitFunError::validation(error_msg)); + } + } + "stdio" } - } + "remote" => match explicit_transport.unwrap_or("streamable-http") { + "streamable-http" | "sse" => { + explicit_transport.unwrap_or("streamable-http") + } + _ => { + let error_msg = format!( + "Server '{}' remote source must use 'streamable-http' or 'sse' transport", + server_id + ); + error!("{}", error_msg); + return Err(BitFunError::validation(error_msg)); + } + }, + _ => unreachable!(), + }; - if inferred_transport == "stdio" && command.is_none() { + if effective_transport == "stdio" && command.is_none() { let error_msg = format!( "Server '{}' (stdio) must provide 'command' field", server_id @@ -143,10 +247,12 @@ impl MCPConfigService { return Err(BitFunError::validation(error_msg)); } - if inferred_transport == "streamable-http" && url.is_none() { + if (effective_transport == "streamable-http" || effective_transport == "sse") + && url.is_none() + { let error_msg = format!( - "Server '{}' (streamable-http) must provide 'url' field", - server_id + "Server '{}' ({}) must provide 'url' field", + server_id, effective_transport ); error!("{}", error_msg); return Err(BitFunError::validation(error_msg)); @@ -169,6 +275,33 @@ impl MCPConfigService { return Err(BitFunError::validation(error_msg)); } } + + if let Some(headers) = obj.get("headers") { + if !headers.is_object() { + let error_msg = + format!("Server '{}' 'headers' field must be an object", server_id); + error!("{}", error_msg); + return Err(BitFunError::validation(error_msg)); + } + } + + if let Some(oauth) = obj.get("oauth") { + if !oauth.is_object() { + let error_msg = + format!("Server '{}' 'oauth' field must be an object", server_id); + error!("{}", error_msg); + return Err(BitFunError::validation(error_msg)); + } + } + + if let Some(xaa) = obj.get("xaa") { + if !xaa.is_object() { + let error_msg = + format!("Server '{}' 'xaa' field must be an object", server_id); + error!("{}", error_msg); + return Err(BitFunError::validation(error_msg)); + } + } } else { let error_msg = format!("Server '{}' config must be an object", server_id); error!("{}", error_msg); @@ -199,3 +332,17 @@ impl MCPConfigService { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::MCPConfigService; + + #[test] + fn normalize_legacy_type_rejects_container_and_preserves_sse() { + assert_eq!(MCPConfigService::normalize_legacy_type("container"), None); + assert_eq!( + MCPConfigService::normalize_legacy_type("sse"), + Some((Some("remote"), Some("sse"))) + ); + } +} diff --git a/src/crates/core/src/service/mcp/config/service.rs b/src/crates/core/src/service/mcp/config/service.rs index db8cf8ec..aaabc188 100644 --- a/src/crates/core/src/service/mcp/config/service.rs +++ b/src/crates/core/src/service/mcp/config/service.rs @@ -1,4 +1,5 @@ use log::{info, warn}; +use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; use crate::service::config::ConfigService; @@ -13,6 +14,163 @@ pub struct MCPConfigService { } impl MCPConfigService { + const AUTHORIZATION_KEYS: [&'static str; 3] = + ["Authorization", "authorization", "AUTHORIZATION"]; + + fn config_signature(config: &MCPServerConfig) -> String { + let env: BTreeMap<_, _> = config.env.clone().into_iter().collect(); + let headers: BTreeMap<_, _> = config.headers.clone().into_iter().collect(); + serde_json::json!({ + "serverType": config.server_type, + "transport": config.resolved_transport().as_str(), + "command": config.command, + "args": config.args, + "env": env, + "headers": headers, + "url": config.url, + "oauth": config.oauth, + "xaa": config.xaa, + }) + .to_string() + } + + fn precedence(location: ConfigLocation) -> u8 { + match location { + ConfigLocation::BuiltIn => 0, + ConfigLocation::User => 1, + ConfigLocation::Project => 2, + } + } + + fn merge_configs( + merged: &mut Vec, + source: Vec, + signature_index: &mut HashMap, + id_index: &mut HashMap, + ) { + for config in source { + let config_id = config.id.clone(); + let signature = Self::config_signature(&config); + + if let Some(existing_index) = id_index.get(&config_id).copied() { + let previous = &merged[existing_index]; + warn!( + "Overriding MCP config by id: id={} previous_location={:?} new_location={:?}", + config_id, previous.location, config.location + ); + + let previous_signature = Self::config_signature(previous); + merged[existing_index] = config; + signature_index.remove(&previous_signature); + signature_index.insert(signature, existing_index); + continue; + } + + if let Some(existing_index) = signature_index.get(&signature).copied() { + let previous = &merged[existing_index]; + if Self::precedence(previous.location) <= Self::precedence(config.location) { + warn!( + "Deduplicating MCP config by content signature: previous_id={} previous_location={:?} replacement_id={} replacement_location={:?}", + previous.id, previous.location, config_id, config.location + ); + + id_index.remove(&previous.id); + merged[existing_index] = config; + id_index.insert(config_id, existing_index); + signature_index.insert(signature, existing_index); + } + continue; + } + + let next_index = merged.len(); + signature_index.insert(signature, next_index); + id_index.insert(config_id, next_index); + merged.push(config); + } + } + + fn parse_config_array( + &self, + servers: &[serde_json::Value], + location: ConfigLocation, + ) -> Vec { + servers + .iter() + .filter_map( + |value| match serde_json::from_value::(value.clone()) { + Ok(mut config) => { + config.location = location; + Some(config) + } + Err(e) => { + warn!( + "Failed to parse MCP config item at {:?} scope: {}", + location, e + ); + None + } + }, + ) + .collect() + } + + fn normalize_authorization_value(value: &str) -> Option { + let trimmed = value.trim(); + if trimmed.is_empty() { + return None; + } + + if trimmed.to_ascii_lowercase().starts_with("bearer ") + || trimmed.contains(char::is_whitespace) + { + return Some(trimmed.to_string()); + } + + Some(format!("Bearer {}", trimmed)) + } + + fn config_authorization_from_map( + map: &std::collections::HashMap, + ) -> Option { + Self::AUTHORIZATION_KEYS + .iter() + .find_map(|key| map.get(*key).cloned()) + .filter(|value| !value.trim().is_empty()) + } + + fn remove_authorization_keys(map: &mut std::collections::HashMap) { + for key in Self::AUTHORIZATION_KEYS { + map.remove(key); + } + } + + pub fn get_remote_authorization_value(config: &MCPServerConfig) -> Option { + Self::config_authorization_from_map(&config.headers) + .or_else(|| Self::config_authorization_from_map(&config.env)) + } + + pub fn get_remote_authorization_source(config: &MCPServerConfig) -> Option<&'static str> { + if Self::config_authorization_from_map(&config.headers).is_some() { + Some("headers") + } else if Self::config_authorization_from_map(&config.env).is_some() { + Some("env") + } else { + None + } + } + + pub fn has_remote_authorization(config: &MCPServerConfig) -> bool { + Self::get_remote_authorization_value(config).is_some() + } + + pub fn has_remote_oauth(config: &MCPServerConfig) -> bool { + config.oauth.is_some() + } + + pub fn has_remote_xaa(config: &MCPServerConfig) -> bool { + config.xaa.is_some() + } + /// Creates a new MCP configuration service. pub fn new(config_service: Arc) -> BitFunResult { Ok(Self { config_service }) @@ -20,28 +178,45 @@ impl MCPConfigService { /// Loads all MCP server configurations. pub async fn load_all_configs(&self) -> BitFunResult> { - let mut configs = Vec::new(); - - let builtin = self.load_builtin_configs().await?; - configs.extend(builtin); - - match self.load_user_configs().await { - Ok(user_configs) => { - configs.extend(user_configs); - } + let builtin_configs = self.load_builtin_configs().await?; + let user_configs = match self.load_user_configs().await { + Ok(user_configs) => user_configs, Err(e) => { warn!("Failed to load user-level MCP configs: {}", e); + Vec::new() } - } + }; - match self.load_project_configs().await { - Ok(project_configs) => { - configs.extend(project_configs); - } + let project_configs = match self.load_project_configs().await { + Ok(project_configs) => project_configs, Err(e) => { warn!("Failed to load project-level MCP configs: {}", e); + Vec::new() } - } + }; + + let mut configs = Vec::new(); + let mut signature_index = HashMap::new(); + let mut id_index = HashMap::new(); + + Self::merge_configs( + &mut configs, + builtin_configs, + &mut signature_index, + &mut id_index, + ); + Self::merge_configs( + &mut configs, + user_configs, + &mut signature_index, + &mut id_index, + ); + Self::merge_configs( + &mut configs, + project_configs, + &mut signature_index, + &mut id_index, + ); info!("Loaded {} MCP server config(s)", configs.len()); Ok(configs) @@ -70,19 +245,7 @@ impl MCPConfigService { } if let Some(servers) = config_value.as_array() { - let configs: Vec = servers - .iter() - .filter_map(|v| { - match serde_json::from_value::(v.clone()) { - Ok(config) => Some(config), - Err(e) => { - warn!("Failed to parse MCP config item: {}", e); - None - } - } - }) - .collect(); - return Ok(configs); + return Ok(self.parse_config_array(servers, ConfigLocation::User)); } warn!("Invalid MCP config format, returning empty list"); @@ -100,12 +263,20 @@ impl MCPConfigService { .await { Ok(config_value) => { + if config_value + .get("mcpServers") + .and_then(|v| v.as_object()) + .is_some() + { + let mut configs = super::cursor_format::parse_cursor_format(&config_value)?; + for config in &mut configs { + config.location = ConfigLocation::Project; + } + return Ok(configs); + } + if let Some(servers) = config_value.as_array() { - let configs: Vec = servers - .iter() - .filter_map(|v| serde_json::from_value(v.clone()).ok()) - .collect(); - Ok(configs) + Ok(self.parse_config_array(servers, ConfigLocation::Project)) } else { Ok(Vec::new()) } @@ -134,6 +305,58 @@ impl MCPConfigService { } } + pub async fn set_remote_authorization( + &self, + server_id: &str, + authorization_value: &str, + ) -> BitFunResult { + let mut config = self.get_server_config(server_id).await?.ok_or_else(|| { + BitFunError::NotFound(format!("MCP server config not found: {}", server_id)) + })?; + + if config.server_type != crate::service::mcp::server::MCPServerType::Remote { + return Err(BitFunError::Validation(format!( + "MCP server '{}' is not a remote server", + server_id + ))); + } + + let normalized = + Self::normalize_authorization_value(authorization_value).ok_or_else(|| { + BitFunError::Validation("Authorization value cannot be empty".to_string()) + })?; + + Self::remove_authorization_keys(&mut config.headers); + Self::remove_authorization_keys(&mut config.env); + config + .headers + .insert("Authorization".to_string(), normalized); + + self.save_server_config(&config).await?; + Ok(config) + } + + pub async fn clear_remote_authorization( + &self, + server_id: &str, + ) -> BitFunResult { + let mut config = self.get_server_config(server_id).await?.ok_or_else(|| { + BitFunError::NotFound(format!("MCP server config not found: {}", server_id)) + })?; + + if config.server_type != crate::service::mcp::server::MCPServerType::Remote { + return Err(BitFunError::Validation(format!( + "MCP server '{}' is not a remote server", + server_id + ))); + } + + Self::remove_authorization_keys(&mut config.headers); + Self::remove_authorization_keys(&mut config.env); + self.save_server_config(&config).await?; + Ok(config) + } + /// Saves user-level configuration. async fn save_user_config(&self, config: &MCPServerConfig) -> BitFunResult<()> { let current_value = self @@ -223,3 +446,142 @@ impl MCPConfigService { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::service::mcp::server::MCPServerType; + + fn make_config( + id: &str, + location: ConfigLocation, + server_type: MCPServerType, + command: Option<&str>, + url: Option<&str>, + ) -> MCPServerConfig { + MCPServerConfig { + id: id.to_string(), + name: id.to_string(), + server_type, + transport: None, + command: command.map(str::to_string), + args: Vec::new(), + env: HashMap::new(), + headers: HashMap::new(), + url: url.map(str::to_string), + auto_start: true, + enabled: true, + location, + capabilities: Vec::new(), + settings: Default::default(), + oauth: None, + xaa: None, + } + } + + #[test] + fn merge_configs_prefers_higher_precedence_when_ids_match() { + let mut merged = Vec::new(); + let mut signature_index = HashMap::new(); + let mut id_index = HashMap::new(); + + MCPConfigService::merge_configs( + &mut merged, + vec![make_config( + "github", + ConfigLocation::User, + MCPServerType::Remote, + None, + Some("https://example.com/mcp"), + )], + &mut signature_index, + &mut id_index, + ); + MCPConfigService::merge_configs( + &mut merged, + vec![make_config( + "github", + ConfigLocation::Project, + MCPServerType::Remote, + None, + Some("https://project.example.com/mcp"), + )], + &mut signature_index, + &mut id_index, + ); + + assert_eq!(merged.len(), 1); + assert_eq!(merged[0].location, ConfigLocation::Project); + assert_eq!( + merged[0].url.as_deref(), + Some("https://project.example.com/mcp") + ); + } + + #[test] + fn merge_configs_deduplicates_same_server_content_across_ids() { + let mut merged = Vec::new(); + let mut signature_index = HashMap::new(); + let mut id_index = HashMap::new(); + + MCPConfigService::merge_configs( + &mut merged, + vec![make_config( + "github-user", + ConfigLocation::User, + MCPServerType::Remote, + None, + Some("https://example.com/mcp"), + )], + &mut signature_index, + &mut id_index, + ); + MCPConfigService::merge_configs( + &mut merged, + vec![make_config( + "github-project", + ConfigLocation::Project, + MCPServerType::Remote, + None, + Some("https://example.com/mcp"), + )], + &mut signature_index, + &mut id_index, + ); + + assert_eq!(merged.len(), 1); + assert_eq!(merged[0].id, "github-project"); + assert_eq!(merged[0].location, ConfigLocation::Project); + } + + #[test] + fn remote_authorization_prefers_headers_and_normalizes_tokens() { + let mut config = make_config( + "remote-auth", + ConfigLocation::User, + MCPServerType::Remote, + None, + Some("https://example.com/mcp"), + ); + config + .env + .insert("Authorization".to_string(), "legacy-token".to_string()); + config.headers.insert( + "Authorization".to_string(), + "Bearer header-token".to_string(), + ); + + assert_eq!( + MCPConfigService::get_remote_authorization_value(&config).as_deref(), + Some("Bearer header-token") + ); + assert_eq!( + MCPConfigService::get_remote_authorization_source(&config), + Some("headers") + ); + assert_eq!( + MCPConfigService::normalize_authorization_value("plain-token").as_deref(), + Some("Bearer plain-token") + ); + } +} diff --git a/src/crates/core/src/service/mcp/mod.rs b/src/crates/core/src/service/mcp/mod.rs index 53b8943a..843550e2 100644 --- a/src/crates/core/src/service/mcp/mod.rs +++ b/src/crates/core/src/service/mcp/mod.rs @@ -10,11 +10,15 @@ //! - `config`: MCP configuration management pub mod adapter; +pub mod auth; pub mod config; pub mod protocol; pub mod server; -// Re-export main components. +use std::sync::Arc; +use std::sync::OnceLock; + +// Stable public surface for the MCP service. pub use protocol::{ MCPCapability, MCPMessage, MCPNotification, MCPProtocolVersion, MCPRequest, MCPResponse, MCPServerInfo, @@ -22,7 +26,7 @@ pub use protocol::{ pub use server::{ MCPConnection, MCPConnectionPool, MCPServerConfig, MCPServerManager, MCPServerStatus, - MCPServerType, + MCPServerTransport, MCPServerType, }; pub use adapter::{ @@ -33,19 +37,19 @@ pub use config::{ConfigLocation, MCPConfigService}; /// MCP service interface. pub struct MCPService { - server_manager: std::sync::Arc, - config_service: std::sync::Arc, - context_provider: std::sync::Arc, + server_manager: Arc, + config_service: Arc, + context_provider: Arc, } impl MCPService { /// Creates a new MCP service instance. pub fn new( - config_service: std::sync::Arc, + config_service: Arc, ) -> crate::util::errors::BitFunResult { - let mcp_config_service = std::sync::Arc::new(MCPConfigService::new(config_service)?); - let server_manager = std::sync::Arc::new(MCPServerManager::new(mcp_config_service.clone())); - let context_provider = std::sync::Arc::new(MCPContextProvider::new(server_manager.clone())); + let mcp_config_service = Arc::new(MCPConfigService::new(config_service)?); + let server_manager = Arc::new(MCPServerManager::new(mcp_config_service.clone())); + let context_provider = Arc::new(MCPContextProvider::new(server_manager.clone())); Ok(Self { server_manager, @@ -55,17 +59,29 @@ impl MCPService { } /// Returns the server manager. - pub fn server_manager(&self) -> std::sync::Arc { + pub fn server_manager(&self) -> Arc { self.server_manager.clone() } /// Returns the context provider. - pub fn context_provider(&self) -> std::sync::Arc { + pub fn context_provider(&self) -> Arc { self.context_provider.clone() } /// Returns the configuration service. - pub fn config_service(&self) -> std::sync::Arc { + pub fn config_service(&self) -> Arc { self.config_service.clone() } } + +static GLOBAL_MCP_SERVICE: OnceLock> = OnceLock::new(); + +/// Stores the global MCP service for code paths that cannot receive it via DI yet. +pub fn set_global_mcp_service(service: Arc) { + let _ = GLOBAL_MCP_SERVICE.set(service); +} + +/// Returns the global MCP service if it has been initialized. +pub fn get_global_mcp_service() -> Option> { + GLOBAL_MCP_SERVICE.get().cloned() +} diff --git a/src/crates/core/src/service/mcp/protocol/mod.rs b/src/crates/core/src/service/mcp/protocol/mod.rs index deaa22a1..1d6c903b 100644 --- a/src/crates/core/src/service/mcp/protocol/mod.rs +++ b/src/crates/core/src/service/mcp/protocol/mod.rs @@ -3,10 +3,10 @@ //! Implements the core protocol definitions of Model Context Protocol and JSON-RPC 2.0 //! communication. -pub mod jsonrpc; -pub mod transport; -pub mod transport_remote; -pub mod types; +mod jsonrpc; +mod transport; +mod transport_remote; +mod types; pub use jsonrpc::*; pub use transport::*; diff --git a/src/crates/core/src/service/mcp/protocol/transport_remote.rs b/src/crates/core/src/service/mcp/protocol/transport_remote.rs index b5a784ca..00caa192 100644 --- a/src/crates/core/src/service/mcp/protocol/transport_remote.rs +++ b/src/crates/core/src/service/mcp/protocol/transport_remote.rs @@ -3,11 +3,14 @@ //! Uses the official `rmcp` Rust SDK to implement the MCP Streamable HTTP client transport. use super::types::{ - InitializeResult as BitFunInitializeResult, MCPCapability, MCPPrompt, MCPPromptArgument, - MCPPromptMessage, MCPPromptMessageContent, MCPResource, MCPResourceContent, MCPServerInfo, - MCPTool, MCPToolResult, MCPToolResultContent, PromptsGetResult, PromptsListResult, - ResourcesListResult, ResourcesReadResult, ToolsListResult, + InitializeResult as BitFunInitializeResult, MCPCapability, MCPAnnotations, MCPPrompt, + MCPPromptArgument, MCPPromptMessage, MCPPromptMessageContent, MCPPromptMessageContentBlock, + MCPResource, MCPResourceContent, MCPResourceIcon, MCPServerInfo, MCPTool, + MCPToolAnnotations, MCPToolResult, MCPToolResultContent, + PromptsGetResult, PromptsListResult, ResourcesListResult, ResourcesReadResult, + ToolsListResult, }; +use crate::service::mcp::auth::build_authorization_manager; use crate::util::errors::{BitFunError, BitFunResult}; use futures::StreamExt; use log::{debug, error, info, warn}; @@ -21,6 +24,7 @@ use rmcp::model::{ ResourceContents, }; use rmcp::service::RunningService; +use rmcp::transport::auth::AuthorizationManager; use rmcp::transport::common::http_header::{ EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, }; @@ -32,6 +36,7 @@ use rmcp::transport::streamable_http_client::{ use rmcp::transport::StreamableHttpClientTransport; use rmcp::ClientHandler; use rmcp::RoleClient; +use serde::de::DeserializeOwned; use serde_json::Value; use std::collections::HashMap; use std::str::FromStr; @@ -111,6 +116,25 @@ enum ClientState { #[derive(Clone)] struct BitFunStreamableHttpClient { client: reqwest::Client, + oauth_manager: Option>>, +} + +impl BitFunStreamableHttpClient { + async fn resolve_auth_token( + &self, + auth_token: Option, + ) -> Result, StreamableHttpError> { + if auth_token.is_some() { + return Ok(auth_token); + } + + let Some(oauth_manager) = &self.oauth_manager else { + return Ok(None); + }; + + let token = oauth_manager.lock().await.get_access_token().await?; + Ok(Some(token)) + } } impl StreamableHttpClient for BitFunStreamableHttpClient { @@ -126,6 +150,7 @@ impl StreamableHttpClient for BitFunStreamableHttpClient { futures::stream::BoxStream<'static, Result>, StreamableHttpError, > { + let auth_token = self.resolve_auth_token(auth_token).await?; let mut request_builder = self .client .get(uri.as_ref()) @@ -169,6 +194,7 @@ impl StreamableHttpClient for BitFunStreamableHttpClient { session: StdArc, auth_token: Option, ) -> Result<(), StreamableHttpError> { + let auth_token = self.resolve_auth_token(auth_token).await?; let mut request_builder = self.client.delete(uri.as_ref()); if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); @@ -192,6 +218,7 @@ impl StreamableHttpClient for BitFunStreamableHttpClient { session_id: Option>, auth_token: Option, ) -> Result> { + let auth_token = self.resolve_auth_token(auth_token).await?; let mut request = self .client .post(uri.as_ref()) @@ -282,6 +309,7 @@ impl StreamableHttpClient for BitFunStreamableHttpClient { pub struct RemoteMCPTransport { url: String, default_headers: HeaderMap, + oauth_manager: Option>>, request_timeout: Duration, state: Mutex, } @@ -348,8 +376,30 @@ impl RemoteMCPTransport { } /// Creates a new streamable HTTP remote transport instance. - pub fn new(url: String, headers: HashMap, request_timeout: Duration) -> Self { + pub async fn new( + server_id: &str, + url: String, + headers: HashMap, + request_timeout: Duration, + oauth_enabled: bool, + ) -> BitFunResult { let default_headers = Self::build_default_headers(&headers); + let oauth_manager = if oauth_enabled + && !default_headers.contains_key(reqwest::header::AUTHORIZATION) + { + let (manager, initialized) = build_authorization_manager(server_id, &url).await?; + if initialized { + Some(Arc::new(Mutex::new(manager))) + } else { + info!( + "Remote MCP OAuth configured but credentials are not authorized yet: server_id={}", + server_id + ); + None + } + } else { + None + }; let http_client = reqwest::Client::builder() .connect_timeout(Duration::from_secs(10)) @@ -365,26 +415,41 @@ impl RemoteMCPTransport { let transport = StreamableHttpClientTransport::with_client( BitFunStreamableHttpClient { client: http_client, + oauth_manager: oauth_manager.clone(), }, StreamableHttpClientTransportConfig::with_uri(url.clone()), ); - Self { + Ok(Self { url, default_headers, + oauth_manager, request_timeout, state: Mutex::new(ClientState::Connecting { transport: Some(transport), }), - } + }) } /// Returns the auth token header value (if present). - pub fn get_auth_token(&self) -> Option { - self.default_headers + pub async fn get_auth_token(&self) -> Option { + if let Some(value) = self + .default_headers .get(reqwest::header::AUTHORIZATION) .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) + { + return Some(value); + } + + let oauth_manager = self.oauth_manager.as_ref()?; + oauth_manager + .lock() + .await + .get_access_token() + .await + .ok() + .map(|token| format!("Bearer {}", token)) } async fn service( @@ -402,7 +467,13 @@ impl RemoteMCPTransport { fn build_client_info(client_name: &str, client_version: &str) -> ClientInfo { ClientInfo { protocol_version: ProtocolVersion::LATEST, - capabilities: ClientCapabilities::default(), + capabilities: ClientCapabilities::builder() + .enable_roots() + .enable_sampling() + .enable_elicitation_with(rmcp::model::ElicitationCapability { + schema_validation: Some(true), + }) + .build(), client_info: Implementation { name: client_name.to_string(), title: None, @@ -657,13 +728,13 @@ fn map_tool(tool: rmcp::model::Tool) -> MCPTool { let schema = Value::Object((*tool.input_schema).clone()); MCPTool { name: tool.name.to_string(), - title: None, + title: tool.title, description: tool.description.map(|d| d.to_string()), input_schema: schema, - output_schema: None, - icons: None, - annotations: None, - meta: None, + output_schema: tool.output_schema.map(|schema| Value::Object((*schema).clone())), + icons: map_icons(tool.icons.as_ref()), + annotations: tool.annotations.map(map_tool_annotations), + meta: map_optional_via_json(tool.meta.as_ref()), } } @@ -671,13 +742,13 @@ fn map_resource(resource: rmcp::model::Resource) -> MCPResource { MCPResource { uri: resource.uri.clone(), name: resource.name.clone(), - title: None, + title: resource.title.clone(), description: resource.description.clone(), mime_type: resource.mime_type.clone(), - icons: None, - size: None, - annotations: None, - metadata: None, + icons: map_icons(resource.icons.as_ref()), + size: resource.size.map(u64::from), + annotations: map_annotations(resource.annotations.as_ref()), + metadata: map_meta_to_hash_map(resource.meta.as_ref()), } } @@ -687,6 +758,7 @@ fn map_resource_content(contents: ResourceContents) -> MCPResourceContent { uri, mime_type, text, + meta, .. } => MCPResourceContent { uri, @@ -694,12 +766,13 @@ fn map_resource_content(contents: ResourceContents) -> MCPResourceContent { blob: None, mime_type, annotations: None, - meta: None, + meta: map_optional_via_json(meta.as_ref()), }, ResourceContents::BlobResourceContents { uri, mime_type, blob, + meta, .. } => MCPResourceContent { uri, @@ -707,7 +780,7 @@ fn map_resource_content(contents: ResourceContents) -> MCPResourceContent { blob: Some(blob), mime_type, annotations: None, - meta: None, + meta: map_optional_via_json(meta.as_ref()), }, } } @@ -715,18 +788,19 @@ fn map_resource_content(contents: ResourceContents) -> MCPResourceContent { fn map_prompt(prompt: rmcp::model::Prompt) -> MCPPrompt { MCPPrompt { name: prompt.name, - title: None, + title: prompt.title, description: prompt.description, arguments: prompt.arguments.map(|args| { args.into_iter() .map(|a| MCPPromptArgument { name: a.name, + title: a.title, description: a.description, required: a.required.unwrap_or(false), }) .collect() }), - icons: None, + icons: map_icons(prompt.icons.as_ref()), } } @@ -738,35 +812,48 @@ fn map_prompt_message(message: rmcp::model::PromptMessage) -> MCPPromptMessage { .to_string(); let content = match message.content { - rmcp::model::PromptMessageContent::Text { text } => text, - rmcp::model::PromptMessageContent::Image { .. } => "[image]".to_string(), - rmcp::model::PromptMessageContent::Resource { resource } => resource.get_text(), + rmcp::model::PromptMessageContent::Text { text } => { + MCPPromptMessageContent::Block(MCPPromptMessageContentBlock::Text { text }) + } + rmcp::model::PromptMessageContent::Image { image } => { + MCPPromptMessageContent::Block(MCPPromptMessageContentBlock::Image { + data: image.data.clone(), + mime_type: image.mime_type.clone(), + }) + } + rmcp::model::PromptMessageContent::Resource { resource } => { + let mut mapped = map_resource_content(resource.resource.clone()); + if mapped.meta.is_none() { + mapped.meta = map_optional_via_json(resource.meta.as_ref()); + } + mapped.annotations = map_annotations(resource.annotations.as_ref()); + MCPPromptMessageContent::Block(MCPPromptMessageContentBlock::Resource { + resource: mapped, + }) + } rmcp::model::PromptMessageContent::ResourceLink { link } => { - format!("[resource_link] {}", link.uri) + MCPPromptMessageContent::Block(MCPPromptMessageContentBlock::ResourceLink { + uri: link.uri.clone(), + name: Some(link.name.clone()), + description: link.description.clone(), + mime_type: link.mime_type.clone(), + }) } }; MCPPromptMessage { role, - content: MCPPromptMessageContent::Plain(content), + content, } } fn map_tool_result(result: rmcp::model::CallToolResult) -> MCPToolResult { - let mut mapped: Vec = result + let mapped: Vec = result .content .into_iter() .filter_map(map_content_block) .collect(); - if mapped.is_empty() { - if let Some(value) = result.structured_content { - mapped.push(MCPToolResultContent::Text { - text: value.to_string(), - }); - } - } - MCPToolResult { content: if mapped.is_empty() { None @@ -774,7 +861,8 @@ fn map_tool_result(result: rmcp::model::CallToolResult) -> MCPToolResult { Some(mapped) }, is_error: result.is_error.unwrap_or(false), - structured_content: None, + structured_content: result.structured_content, + meta: map_optional_json_value(result.meta.as_ref()), } } @@ -788,11 +876,326 @@ fn map_content_block(content: Content) -> Option { rmcp::model::RawContent::Resource(resource) => Some(MCPToolResultContent::Resource { resource: map_resource_content(resource.resource), }), - rmcp::model::RawContent::Audio(audio) => Some(MCPToolResultContent::Text { - text: format!("[audio] mime_type={}", audio.mime_type), + rmcp::model::RawContent::Audio(audio) => Some(MCPToolResultContent::Audio { + data: audio.data, + mime_type: audio.mime_type, }), - rmcp::model::RawContent::ResourceLink(link) => Some(MCPToolResultContent::Text { - text: format!("[resource_link] {}", link.uri), + rmcp::model::RawContent::ResourceLink(link) => Some(MCPToolResultContent::ResourceLink { + uri: link.uri, + name: Some(link.name), + description: link.description, + mime_type: link.mime_type, }), } } + +fn map_icons(icons: Option<&Vec>) -> Option> { + icons.map(|icons| { + icons + .iter() + .map(|icon| MCPResourceIcon { + src: icon.src.clone(), + mime_type: icon.mime_type.clone(), + sizes: icon.sizes.as_ref().map(|sizes| { + Value::Array( + sizes + .iter() + .cloned() + .map(Value::String) + .collect::>(), + ) + }), + }) + .collect() + }) +} + +fn map_annotations(annotations: Option<&rmcp::model::Annotations>) -> Option { + annotations.map(|annotations| MCPAnnotations { + audience: annotations + .audience + .as_ref() + .map(|audience| audience.iter().map(map_role).collect()), + priority: annotations.priority.map(f64::from), + last_modified: annotations.last_modified.map(|timestamp| timestamp.to_rfc3339()), + }) +} + +fn map_tool_annotations(annotations: rmcp::model::ToolAnnotations) -> MCPToolAnnotations { + MCPToolAnnotations { + title: annotations.title, + read_only_hint: annotations.read_only_hint, + destructive_hint: annotations.destructive_hint, + idempotent_hint: annotations.idempotent_hint, + open_world_hint: annotations.open_world_hint, + } +} + +fn map_role(role: &rmcp::model::Role) -> String { + match role { + rmcp::model::Role::User => "user", + rmcp::model::Role::Assistant => "assistant", + } + .to_string() +} + +fn map_meta_to_hash_map(meta: Option<&rmcp::model::Meta>) -> Option> { + meta.and_then(|meta| match serde_json::to_value(meta.clone()).ok()? { + Value::Object(map) => Some(map.into_iter().collect()), + _ => None, + }) +} + +fn map_optional_json_value(value: Option<&T>) -> Option +where + T: serde::Serialize, +{ + value.and_then(|value| serde_json::to_value(value).ok()) +} + +fn map_optional_via_json(value: Option<&T>) -> Option +where + T: serde::Serialize, + U: DeserializeOwned, +{ + value + .and_then(|value| serde_json::to_value(value).ok()) + .and_then(|value| serde_json::from_value(value).ok()) +} + +#[cfg(test)] +mod tests { + use super::*; + use rmcp::model::{AnnotateAble, Annotations, Content, Icon, Meta, RawResource}; + use serde_json::json; + + #[test] + fn build_client_info_declares_supported_client_capabilities() { + let info = RemoteMCPTransport::build_client_info("BitFun", "1.0.0"); + + assert!(info.capabilities.roots.is_some()); + assert!(info.capabilities.sampling.is_some()); + assert!(info.capabilities.elicitation.is_some()); + assert_eq!( + info.capabilities + .elicitation + .as_ref() + .and_then(|cap| cap.schema_validation), + Some(true) + ); + } + + #[test] + fn mapping_preserves_remote_tool_resource_and_prompt_metadata() { + let mut tool_meta = Meta::default(); + tool_meta.insert("ui".to_string(), json!({ "resourceUri": "ui://widget" })); + let tool = rmcp::model::Tool { + name: "search".into(), + title: Some("Search".to_string()), + description: Some("Find items".into()), + input_schema: Arc::new(serde_json::Map::new()), + output_schema: Some(Arc::new(serde_json::Map::from_iter([( + "type".to_string(), + json!("object"), + )]))), + annotations: Some( + rmcp::model::ToolAnnotations::new() + .read_only(true) + .destructive(false) + .idempotent(true) + .open_world(true), + ), + icons: Some(vec![Icon { + src: "https://example.com/tool.png".to_string(), + mime_type: Some("image/png".to_string()), + sizes: Some(vec!["32x32".to_string()]), + }]), + meta: Some(tool_meta), + }; + let mapped_tool = map_tool(tool); + assert_eq!(mapped_tool.title.as_deref(), Some("Search")); + assert_eq!(mapped_tool.output_schema, Some(json!({ "type": "object" }))); + assert_eq!( + mapped_tool + .annotations + .as_ref() + .and_then(|annotations| annotations.read_only_hint), + Some(true) + ); + assert_eq!( + mapped_tool + .meta + .as_ref() + .and_then(|meta| meta.ui.as_ref()) + .and_then(|ui| ui.resource_uri.as_deref()), + Some("ui://widget") + ); + + let mut resource_meta = Meta::default(); + resource_meta.insert("source".to_string(), json!("catalog")); + let resource = RawResource { + uri: "file:///tmp/report.md".to_string(), + name: "report".to_string(), + title: Some("Quarterly Report".to_string()), + description: Some("Report".to_string()), + mime_type: Some("text/markdown".to_string()), + size: Some(42), + icons: Some(vec![Icon { + src: "https://example.com/resource.png".to_string(), + mime_type: Some("image/png".to_string()), + sizes: Some(vec!["64x64".to_string()]), + }]), + meta: Some(resource_meta), + } + .annotate(Annotations { + audience: Some(vec![rmcp::model::Role::User]), + priority: Some(0.9), + last_modified: None, + }); + let mapped_resource = map_resource(resource); + assert_eq!(mapped_resource.title.as_deref(), Some("Quarterly Report")); + assert_eq!(mapped_resource.size, Some(42)); + assert_eq!( + mapped_resource + .annotations + .as_ref() + .and_then(|annotations| annotations.audience.as_ref()) + .cloned(), + Some(vec!["user".to_string()]) + ); + assert_eq!( + mapped_resource + .metadata + .as_ref() + .and_then(|meta| meta.get("source")), + Some(&json!("catalog")) + ); + + let prompt = rmcp::model::Prompt { + name: "summarize".to_string(), + title: Some("Summarize".to_string()), + description: Some("Summarize content".to_string()), + arguments: Some(vec![rmcp::model::PromptArgument { + name: "topic".to_string(), + title: Some("Topic".to_string()), + description: Some("Topic to summarize".to_string()), + required: Some(true), + }]), + icons: Some(vec![Icon { + src: "https://example.com/prompt.png".to_string(), + mime_type: Some("image/png".to_string()), + sizes: Some(vec!["16x16".to_string()]), + }]), + meta: None, + }; + let mapped_prompt = map_prompt(prompt); + assert_eq!(mapped_prompt.title.as_deref(), Some("Summarize")); + assert_eq!( + mapped_prompt + .arguments + .as_ref() + .and_then(|arguments| arguments.first()) + .and_then(|argument| argument.title.as_deref()), + Some("Topic") + ); + assert!(mapped_prompt.icons.is_some()); + } + + #[test] + fn mapping_preserves_structured_results_and_resource_links() { + let resource_link = RawResource { + uri: "file:///tmp/output.json".to_string(), + name: "output".to_string(), + title: Some("Output".to_string()), + description: Some("Generated output".to_string()), + mime_type: Some("application/json".to_string()), + size: Some(7), + icons: None, + meta: None, + }; + let mut result_meta = Meta::default(); + result_meta.insert("traceId".to_string(), json!("abc123")); + let result = rmcp::model::CallToolResult { + content: vec![ + Content::text("done"), + Content::resource_link(resource_link), + Content::image("aGVsbG8=", "image/png"), + ], + structured_content: Some(json!({ "ok": true })), + is_error: Some(false), + meta: Some(result_meta), + }; + + let mapped = map_tool_result(result); + assert_eq!(mapped.structured_content, Some(json!({ "ok": true }))); + assert_eq!(mapped.meta, Some(json!({ "traceId": "abc123" }))); + assert!(matches!( + mapped.content.as_ref().and_then(|content| content.get(1)), + Some(MCPToolResultContent::ResourceLink { uri, .. }) if uri == "file:///tmp/output.json" + )); + assert!(matches!( + mapped.content.as_ref().and_then(|content| content.get(2)), + Some(MCPToolResultContent::Image { mime_type, .. }) if mime_type == "image/png" + )); + } + + #[test] + fn mapping_preserves_prompt_message_blocks() { + let prompt_message = rmcp::model::PromptMessage { + role: rmcp::model::PromptMessageRole::User, + content: rmcp::model::PromptMessageContent::Text { + text: "hello".to_string(), + }, + }; + let mapped = map_prompt_message(prompt_message); + assert!(matches!( + mapped.content, + MCPPromptMessageContent::Block(MCPPromptMessageContentBlock::Text { ref text }) if text == "hello" + )); + + let resource_link = RawResource { + uri: "file:///tmp/input.md".to_string(), + name: "input".to_string(), + title: None, + description: Some("input".to_string()), + mime_type: Some("text/markdown".to_string()), + size: None, + icons: None, + meta: None, + } + .no_annotation(); + let prompt_message = rmcp::model::PromptMessage { + role: rmcp::model::PromptMessageRole::Assistant, + content: rmcp::model::PromptMessageContent::ResourceLink { + link: resource_link, + }, + }; + let mapped = map_prompt_message(prompt_message); + assert!(matches!( + mapped.content, + MCPPromptMessageContent::Block(MCPPromptMessageContentBlock::ResourceLink { ref uri, .. }) + if uri == "file:///tmp/input.md" + )); + + let embedded = rmcp::model::RawEmbeddedResource { + meta: Some(Meta::default()), + resource: ResourceContents::TextResourceContents { + uri: "file:///tmp/embedded.txt".to_string(), + mime_type: Some("text/plain".to_string()), + text: "embedded".to_string(), + meta: None, + }, + } + .no_annotation(); + let prompt_message = rmcp::model::PromptMessage { + role: rmcp::model::PromptMessageRole::Assistant, + content: rmcp::model::PromptMessageContent::Resource { resource: embedded }, + }; + let mapped = map_prompt_message(prompt_message); + assert!(matches!( + mapped.content, + MCPPromptMessageContent::Block(MCPPromptMessageContentBlock::Resource { ref resource }) + if resource.uri == "file:///tmp/embedded.txt" + )); + } +} diff --git a/src/crates/core/src/service/mcp/protocol/types.rs b/src/crates/core/src/service/mcp/protocol/types.rs index 08ad55bb..b00351f5 100644 --- a/src/crates/core/src/service/mcp/protocol/types.rs +++ b/src/crates/core/src/service/mcp/protocol/types.rs @@ -237,6 +237,8 @@ pub struct MCPPrompt { pub struct MCPPromptArgument { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, #[serde(default)] pub required: bool, @@ -270,6 +272,16 @@ pub enum MCPPromptMessageContentBlock { Image { data: String, mime_type: String }, #[serde(rename = "audio")] Audio { data: String, mime_type: String }, + #[serde(rename = "resource_link")] + ResourceLink { + uri: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + mime_type: Option, + }, #[serde(rename = "resource")] Resource { resource: MCPResourceContent }, } @@ -294,6 +306,15 @@ impl MCPPromptMessageContent { }) => { format!("[Audio: {}]", mime_type) } + MCPPromptMessageContent::Block(MCPPromptMessageContentBlock::ResourceLink { + uri, + name, + .. + }) => name + .as_ref() + .map_or_else(|| format!("[Resource Link: {}]", uri), |n| { + format!("[Resource Link: {} ({})]", n, uri) + }), MCPPromptMessageContent::Block(MCPPromptMessageContentBlock::Resource { resource }) => { format!("[Resource: {}]", resource.uri) } @@ -356,6 +377,10 @@ pub struct MCPToolAnnotations { pub read_only_hint: Option, #[serde(skip_serializing_if = "Option::is_none")] pub destructive_hint: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub idempotent_hint: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub open_world_hint: Option, } /// MCP tool definition (2025-11-25 spec). @@ -395,6 +420,9 @@ pub struct MCPToolResult { /// Structured data for MCP App UI (ext-apps ontoolresult expects this). #[serde(skip_serializing_if = "Option::is_none")] pub structured_content: Option, + /// Optional protocol-level metadata returned by the server. + #[serde(skip_serializing_if = "Option::is_none", rename = "_meta")] + pub meta: Option, } /// MCP tool result content (2025-11-25 spec). diff --git a/src/crates/core/src/service/mcp/server/config.rs b/src/crates/core/src/service/mcp/server/config.rs new file mode 100644 index 00000000..0ba2797f --- /dev/null +++ b/src/crates/core/src/service/mcp/server/config.rs @@ -0,0 +1,170 @@ +//! MCP server configuration types. + +use super::MCPServerType; +use crate::service::mcp::config::ConfigLocation; +use crate::util::errors::{BitFunError, BitFunResult}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "kebab-case")] +pub enum MCPServerTransport { + Stdio, + StreamableHttp, + Sse, +} + +impl MCPServerTransport { + pub const fn as_str(&self) -> &'static str { + match self { + Self::Stdio => "stdio", + Self::StreamableHttp => "streamable-http", + Self::Sse => "sse", + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct MCPServerOAuthConfig { + #[serde(default)] + pub scopes: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub client_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub client_metadata_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub callback_host: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub callback_port: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub callback_path: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct MCPServerXaaConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub issuer: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub audience: Option, + #[serde(default)] + pub scopes: Vec, +} + +/// MCP server configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct MCPServerConfig { + pub id: String, + pub name: String, + #[serde(rename = "type")] + pub server_type: MCPServerType, + #[serde(skip_serializing_if = "Option::is_none")] + pub transport: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub command: Option, + #[serde(default)] + pub args: Vec, + #[serde(default)] + pub env: HashMap, + /// Additional HTTP headers for remote MCP servers (Cursor-style `headers`). + #[serde(default)] + pub headers: HashMap, + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, + #[serde(default = "default_true")] + pub auto_start: bool, + #[serde(default = "default_true")] + pub enabled: bool, + pub location: ConfigLocation, + #[serde(default)] + pub capabilities: Vec, + #[serde(default)] + pub settings: HashMap, + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub xaa: Option, +} + +fn default_true() -> bool { + true +} + +impl MCPServerConfig { + pub fn resolved_transport(&self) -> MCPServerTransport { + self.transport.unwrap_or(match self.server_type { + MCPServerType::Local => MCPServerTransport::Stdio, + MCPServerType::Remote => MCPServerTransport::StreamableHttp, + }) + } + + /// Validates the configuration. + pub fn validate(&self) -> BitFunResult<()> { + if self.id.is_empty() { + return Err(BitFunError::Configuration( + "MCP server id cannot be empty".to_string(), + )); + } + + if self.name.is_empty() { + return Err(BitFunError::Configuration( + "MCP server name cannot be empty".to_string(), + )); + } + + let transport = self.resolved_transport(); + match self.server_type { + MCPServerType::Local => { + if self.command.is_none() { + return Err(BitFunError::Configuration(format!( + "Local MCP server '{}' must have a command", + self.id + ))); + } + + if transport != MCPServerTransport::Stdio { + return Err(BitFunError::Configuration(format!( + "Local MCP server '{}' must use stdio transport, got '{}'", + self.id, + transport.as_str() + ))); + } + } + MCPServerType::Remote => { + if self.url.is_none() { + return Err(BitFunError::Configuration(format!( + "Remote MCP server '{}' must have a URL", + self.id + ))); + } + + if let Some(oauth) = &self.oauth { + if let Some(port) = oauth.callback_port { + if port == 0 { + return Err(BitFunError::Configuration(format!( + "Remote MCP server '{}' OAuth callbackPort must be greater than 0", + self.id + ))); + } + } + } + + if !matches!( + transport, + MCPServerTransport::StreamableHttp | MCPServerTransport::Sse + ) { + return Err(BitFunError::Configuration(format!( + "Remote MCP server '{}' must use streamable-http or sse transport, got '{}'", + self.id, + transport.as_str() + ))); + } + } + } + + Ok(()) + } +} diff --git a/src/crates/core/src/service/mcp/server/connection.rs b/src/crates/core/src/service/mcp/server/connection.rs index 04c08d58..b12723a6 100644 --- a/src/crates/core/src/service/mcp/server/connection.rs +++ b/src/crates/core/src/service/mcp/server/connection.rs @@ -5,10 +5,10 @@ use crate::service::mcp::protocol::{ create_initialize_request, create_ping_request, create_prompts_get_request, create_prompts_list_request, create_resources_list_request, create_resources_read_request, - create_tools_call_request, create_tools_list_request, parse_response_result, - transport::MCPTransport, transport_remote::RemoteMCPTransport, InitializeResult, MCPMessage, - MCPResponse, MCPToolResult, PromptsGetResult, PromptsListResult, ResourcesListResult, - ResourcesReadResult, ToolsListResult, + create_tools_call_request, create_tools_list_request, parse_response_result, InitializeResult, + MCPError, MCPMessage, MCPResponse, MCPToolResult, MCPTransport, PromptsGetResult, + PromptsListResult, RemoteMCPTransport, ResourcesListResult, ResourcesReadResult, + ToolsListResult, }; use crate::util::errors::{BitFunError, BitFunResult}; use log::{debug, warn}; @@ -17,7 +17,7 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::process::ChildStdin; -use tokio::sync::{mpsc, oneshot, RwLock}; +use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; /// Request/response waiter. type ResponseWaiter = oneshot::Sender; @@ -28,11 +28,27 @@ enum TransportType { Remote(Arc), } +/// Connection lifecycle / protocol events. +#[derive(Debug, Clone)] +pub enum MCPConnectionEvent { + Notification { + method: String, + params: Option, + }, + Request { + request_id: Value, + method: String, + params: Option, + }, + Closed, +} + /// MCP connection. pub struct MCPConnection { transport: TransportType, pending_requests: Arc>>, request_timeout: Duration, + event_tx: broadcast::Sender, } impl MCPConnection { @@ -40,36 +56,49 @@ impl MCPConnection { pub fn new_local(stdin: ChildStdin, message_rx: mpsc::UnboundedReceiver) -> Self { let transport = Arc::new(MCPTransport::new(stdin)); let pending_requests = Arc::new(RwLock::new(HashMap::new())); + let (event_tx, _) = broadcast::channel(64); let pending = pending_requests.clone(); + let event_tx_clone = event_tx.clone(); tokio::spawn(async move { - Self::handle_messages(message_rx, pending).await; + Self::handle_messages(message_rx, pending, event_tx_clone).await; }); Self { transport: TransportType::Local(transport), pending_requests, request_timeout: Duration::from_secs(180), + event_tx, } } /// Creates a new remote connection instance (Streamable HTTP). - pub fn new_remote(url: String, headers: HashMap) -> Self { + pub async fn new_remote( + server_id: &str, + url: String, + headers: HashMap, + oauth_enabled: bool, + ) -> BitFunResult { let request_timeout = Duration::from_secs(180); - let transport = Arc::new(RemoteMCPTransport::new(url, headers, request_timeout)); + let transport = Arc::new( + RemoteMCPTransport::new(server_id, url, headers, request_timeout, oauth_enabled) + .await?, + ); let pending_requests = Arc::new(RwLock::new(HashMap::new())); + let (event_tx, _) = broadcast::channel(64); - Self { + Ok(Self { transport: TransportType::Remote(transport), pending_requests, request_timeout, - } + event_tx, + }) } /// Returns the auth token for a remote connection. pub async fn get_auth_token(&self) -> Option { match &self.transport { - TransportType::Remote(transport) => transport.get_auth_token(), + TransportType::Remote(transport) => transport.get_auth_token().await, TransportType::Local(_) => None, } } @@ -79,10 +108,16 @@ impl MCPConnection { Self::new_local(stdin, message_rx) } + /// Subscribes to connection events. + pub fn subscribe_events(&self) -> broadcast::Receiver { + self.event_tx.subscribe() + } + /// Handles received messages. async fn handle_messages( mut rx: mpsc::UnboundedReceiver, pending_requests: Arc>>, + event_tx: broadcast::Sender, ) { while let Some(message) = rx.recv().await { match message { @@ -98,12 +133,23 @@ impl MCPConnection { } MCPMessage::Notification(notification) => { debug!("Received MCP notification: method={}", notification.method); + let _ = event_tx.send(MCPConnectionEvent::Notification { + method: notification.method, + params: notification.params, + }); } - MCPMessage::Request(_request) => { + MCPMessage::Request(request) => { warn!("Received unexpected request from MCP server"); + let _ = event_tx.send(MCPConnectionEvent::Request { + request_id: request.id, + method: request.method, + params: request.params, + }); } } } + + let _ = event_tx.send(MCPConnectionEvent::Closed); } /// Sends a request and waits for the response. @@ -272,6 +318,28 @@ impl MCPConnection { TransportType::Remote(transport) => transport.ping().await, } } + + /// Sends a JSON-RPC success response for a server-initiated request. + pub async fn send_response(&self, request_id: Value, result: Value) -> BitFunResult<()> { + match &self.transport { + TransportType::Local(transport) => transport.send_response(request_id, result).await, + TransportType::Remote(_) => Err(BitFunError::NotImplemented( + "Sending server-request responses is not supported for Streamable HTTP connections" + .to_string(), + )), + } + } + + /// Sends a JSON-RPC error response for a server-initiated request. + pub async fn send_error(&self, request_id: Value, error: MCPError) -> BitFunResult<()> { + match &self.transport { + TransportType::Local(transport) => transport.send_error(request_id, error).await, + TransportType::Remote(_) => Err(BitFunError::NotImplemented( + "Sending server-request errors is not supported for Streamable HTTP connections" + .to_string(), + )), + } + } } /// MCP connection pool. diff --git a/src/crates/core/src/service/mcp/server/manager/auth.rs b/src/crates/core/src/service/mcp/server/manager/auth.rs new file mode 100644 index 00000000..19c1e02f --- /dev/null +++ b/src/crates/core/src/service/mcp/server/manager/auth.rs @@ -0,0 +1,805 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use axum::{ + extract::{Query, State}, + http::HeaderMap, + response::{Html, IntoResponse}, + routing::get, + Router, +}; +use reqwest::Url; +use tokio::sync::{oneshot, Mutex}; +use tokio::time::{timeout, Duration}; + +use crate::service::config::app_language::get_app_language_code; +use crate::service::mcp::auth::{ + clear_stored_oauth_credentials, map_auth_error, prepare_remote_oauth_authorization, + MCPRemoteOAuthSessionSnapshot, MCPRemoteOAuthStatus, +}; +use crate::service::mcp::server::MCPServerType; +use crate::util::errors::{BitFunError, BitFunResult}; + +use super::{ActiveRemoteOAuthSession, MCPServerManager}; + +const OAUTH_CALLBACK_TIMEOUT: Duration = Duration::from_secs(300); + +#[derive(Debug)] +struct OAuthCallbackPayload { + code: Option, + state: Option, + error: Option, + error_description: Option, +} + +#[derive(Clone, Copy)] +enum OAuthCallbackLocale { + ZhCN, + EnUS, +} + +struct OAuthCallbackPageCopy { + html_lang: &'static str, + page_title: &'static str, + brand_label: &'static str, + badge_success: &'static str, + badge_warning: &'static str, + badge_error: &'static str, + success_title: &'static str, + success_message: &'static str, + success_detail_title: &'static str, + success_detail_body: &'static str, + warning_title: &'static str, + warning_message: &'static str, + warning_detail_title: &'static str, + error_title: &'static str, + error_message: &'static str, + error_detail_title: &'static str, + close_hint: &'static str, +} + +impl OAuthCallbackLocale { + fn from_language_code(value: &str) -> Option { + match value { + "zh-CN" | "zh" => Some(Self::ZhCN), + "en-US" | "en" => Some(Self::EnUS), + _ => None, + } + } + + fn from_accept_language(value: &str) -> Self { + value + .split(',') + .filter_map(|part| part.split(';').next()) + .find_map(|part| Self::from_language_code(part.trim())) + .unwrap_or(Self::ZhCN) + } + + fn copy(self) -> OAuthCallbackPageCopy { + match self { + Self::ZhCN => OAuthCallbackPageCopy { + html_lang: "zh-CN", + page_title: "BitFun OAuth 回调", + brand_label: "BitFun Desktop", + badge_success: "已收到授权", + badge_warning: "回调参数不完整", + badge_error: "授权失败", + success_title: "BitFun 已收到 OAuth 回调", + success_message: "可以返回 BitFun。应用正在交换授权码并重新连接 MCP 服务器。", + success_detail_title: "接下来会发生什么", + success_detail_body: + "这个页面可以直接关闭。如果 BitFun 没有自动完成重连,请回到 MCP 设置页后重试 OAuth。", + warning_title: "BitFun 收到的 OAuth 回调缺少必要参数", + warning_message: + "OAuth 提供方已跳转回来,但缺少必须的参数。请返回 BitFun 重新发起登录流程。", + warning_detail_title: "缺少的参数", + error_title: "BitFun 未能完成 OAuth 授权", + error_message: + "请返回 BitFun,并根据下面的提供方返回信息检查问题后重新发起 OAuth。", + error_detail_title: "提供方返回", + close_hint: "处理完成后,这个页面可以直接关闭。", + }, + Self::EnUS => OAuthCallbackPageCopy { + html_lang: "en-US", + page_title: "BitFun OAuth Callback", + brand_label: "BitFun Desktop", + badge_success: "Authorization received", + badge_warning: "Callback incomplete", + badge_error: "Authorization failed", + success_title: "BitFun received the OAuth callback", + success_message: + "You can return to BitFun now. The app is exchanging the authorization code and reconnecting the MCP server.", + success_detail_title: "What happens next", + success_detail_body: + "This page can be closed now. If BitFun does not finish reconnecting automatically, return to MCP settings and retry OAuth.", + warning_title: "BitFun received an OAuth callback with missing parameters", + warning_message: + "The provider redirected back, but required OAuth parameters were missing. Return to BitFun and start the sign-in flow again.", + warning_detail_title: "Missing parameters", + error_title: "BitFun could not finish the OAuth authorization", + error_message: + "Return to BitFun and review the provider response below before retrying OAuth.", + error_detail_title: "Provider response", + close_hint: "This page can be closed after you review the status.", + }, + } + } +} + +fn escape_html(input: &str) -> String { + input + .replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) + .replace('\'', "'") +} + +fn resolve_oauth_callback_locale( + preferred_language: Option<&str>, + accept_language: Option<&str>, +) -> OAuthCallbackLocale { + preferred_language + .and_then(OAuthCallbackLocale::from_language_code) + .or_else(|| accept_language.map(OAuthCallbackLocale::from_accept_language)) + .unwrap_or(OAuthCallbackLocale::ZhCN) +} + +fn render_oauth_callback_page( + payload: &OAuthCallbackPayload, + locale: OAuthCallbackLocale, +) -> String { + let copy = locale.copy(); + let (badge, badge_class, title, message, detail_title, detail_body, icon_label) = + if let Some(error) = payload.error.as_deref() { + let description = payload + .error_description + .as_deref() + .unwrap_or(match locale { + OAuthCallbackLocale::ZhCN => "OAuth 提供方拒绝了这次授权请求。", + OAuthCallbackLocale::EnUS => "The provider rejected the authorization request.", + }); + ( + copy.badge_error, + "is-error", + copy.error_title, + copy.error_message, + copy.error_detail_title, + format!("{}: {}", escape_html(error), escape_html(description)), + "!", + ) + } else if payload.code.is_some() && payload.state.is_some() { + ( + copy.badge_success, + "is-success", + copy.success_title, + copy.success_message, + copy.success_detail_title, + copy.success_detail_body.to_string(), + match locale { + OAuthCallbackLocale::ZhCN => "完成", + OAuthCallbackLocale::EnUS => "Done", + }, + ) + } else { + let mut missing = Vec::new(); + if payload.code.is_none() { + missing.push("code"); + } + if payload.state.is_none() { + missing.push("state"); + } + ( + copy.badge_warning, + "is-warning", + copy.warning_title, + copy.warning_message, + copy.warning_detail_title, + escape_html(&missing.join(", ")), + "?", + ) + }; + + format!( + r#" + + + + + {page_title} + + + +
+
+
+
+
+
BF
+
+ {brand_label} +

{title}

+
+
+
+
{badge}
+

{message}

+
+
{icon_label}
+
+

{detail_title}

+

{detail_body}

+
+
+
+

{close_hint}

+
+
+
+
+ +"#, + html_lang = copy.html_lang, + page_title = copy.page_title, + brand_label = copy.brand_label, + title = title, + badge = badge, + badge_class = badge_class, + message = message, + detail_title = detail_title, + detail_body = detail_body, + icon_label = icon_label, + close_hint = copy.close_hint, + ) +} + +#[derive(Clone)] +struct OAuthCallbackAppState { + callback_tx: Arc>>>, + preferred_language: String, +} + +impl MCPServerManager { + pub(super) async fn set_oauth_snapshot( + session: &Arc, + snapshot: MCPRemoteOAuthSessionSnapshot, + ) { + *session.snapshot.write().await = snapshot; + } + + pub(super) async fn update_oauth_snapshot( + session: &Arc, + update: F, + ) -> MCPRemoteOAuthSessionSnapshot + where + F: FnOnce(&mut MCPRemoteOAuthSessionSnapshot), + { + let mut snapshot = session.snapshot.write().await; + update(&mut snapshot); + snapshot.clone() + } + + pub(super) async fn insert_oauth_session( + &self, + server_id: &str, + session: Arc, + ) -> Option> { + self.oauth_sessions + .write() + .await + .insert(server_id.to_string(), session) + } + + pub(super) async fn shutdown_oauth_session(session: &Arc) { + if let Some(shutdown_tx) = session.shutdown_tx.lock().await.take() { + let _ = shutdown_tx.send(()); + } + } + + async fn fail_oauth_session( + session: &Arc, + message: String, + ) -> MCPRemoteOAuthSessionSnapshot { + let snapshot = MCPServerManager::update_oauth_snapshot(session, |snapshot| { + snapshot.status = MCPRemoteOAuthStatus::Failed; + snapshot.message = Some(message); + }) + .await; + Self::shutdown_oauth_session(session).await; + snapshot + } + + pub async fn start_remote_oauth_authorization( + &self, + server_id: &str, + ) -> BitFunResult { + let config = self + .config_service + .get_server_config(server_id) + .await? + .ok_or_else(|| { + BitFunError::NotFound(format!("MCP server config not found: {}", server_id)) + })?; + + if config.server_type != MCPServerType::Remote { + return Err(BitFunError::Validation(format!( + "MCP server '{}' is not a remote server", + server_id + ))); + } + + if let Some(existing) = self.oauth_sessions.write().await.remove(server_id) { + Self::shutdown_oauth_session(&existing).await; + } + + let prepared = prepare_remote_oauth_authorization(&config).await?; + let callback_path = Url::parse(&prepared.redirect_uri) + .map_err(|error| { + BitFunError::MCPError(format!( + "Invalid OAuth redirect URI for server '{}': {}", + server_id, error + )) + })? + .path() + .to_string(); + + let initial_snapshot = MCPRemoteOAuthSessionSnapshot::new( + server_id.to_string(), + MCPRemoteOAuthStatus::AwaitingBrowser, + Some(prepared.authorization_url.clone()), + Some(prepared.redirect_uri.clone()), + Some("Open the authorization URL to continue OAuth sign-in.".to_string()), + ); + + let (callback_tx, callback_rx) = oneshot::channel(); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let session = Arc::new(ActiveRemoteOAuthSession { + snapshot: Arc::new(tokio::sync::RwLock::new(initial_snapshot.clone())), + shutdown_tx: Mutex::new(Some(shutdown_tx)), + }); + + if let Some(previous) = self.insert_oauth_session(server_id, session.clone()).await { + Self::shutdown_oauth_session(&previous).await; + } + + let callback_state = OAuthCallbackAppState { + callback_tx: Arc::new(Mutex::new(Some(callback_tx))), + preferred_language: get_app_language_code().await, + }; + let router = Router::new() + .route(&callback_path, get(handle_oauth_callback)) + .with_state(callback_state); + let callback_server_session = session.clone(); + let callback_server_id = server_id.to_string(); + tokio::spawn(async move { + let server = + axum::serve(prepared.listener, router).with_graceful_shutdown(async move { + let _ = shutdown_rx.await; + }); + + if let Err(error) = server.await { + let _ = + MCPServerManager::update_oauth_snapshot(&callback_server_session, |snapshot| { + if matches!( + snapshot.status, + MCPRemoteOAuthStatus::Authorized | MCPRemoteOAuthStatus::Cancelled + ) { + return; + } + snapshot.status = MCPRemoteOAuthStatus::Failed; + snapshot.message = Some(format!( + "OAuth callback listener failed for server '{}': {}", + callback_server_id, error + )); + }) + .await; + } + }); + + let manager = self.clone(); + let callback_session = session.clone(); + let callback_server_id = server_id.to_string(); + let authorization_url = prepared.authorization_url.clone(); + let redirect_uri = prepared.redirect_uri.clone(); + let mut oauth_state = prepared.state; + tokio::spawn(async move { + let _ = MCPServerManager::update_oauth_snapshot(&callback_session, |snapshot| { + snapshot.status = MCPRemoteOAuthStatus::AwaitingCallback; + snapshot.message = + Some("Waiting for the OAuth provider to redirect back to BitFun.".to_string()); + }) + .await; + + let callback = match timeout(OAUTH_CALLBACK_TIMEOUT, callback_rx).await { + Ok(Ok(callback)) => callback, + Ok(Err(_)) => { + let _ = + MCPServerManager::update_oauth_snapshot(&callback_session, |snapshot| { + snapshot.status = MCPRemoteOAuthStatus::Cancelled; + snapshot.message = + Some("OAuth authorization was cancelled.".to_string()); + }) + .await; + Self::shutdown_oauth_session(&callback_session).await; + return; + } + Err(_) => { + let _ = MCPServerManager::fail_oauth_session( + &callback_session, + "OAuth authorization timed out before the provider redirected back." + .to_string(), + ) + .await; + return; + } + }; + + if let Some(error) = callback.error { + let description = callback + .error_description + .map(|value| format!(": {}", value)) + .unwrap_or_default(); + let _ = MCPServerManager::fail_oauth_session( + &callback_session, + format!("OAuth provider returned '{}{}'", error, description), + ) + .await; + return; + } + + let code = match callback.code { + Some(code) => code, + None => { + let _ = MCPServerManager::fail_oauth_session( + &callback_session, + "OAuth callback did not include an authorization code.".to_string(), + ) + .await; + return; + } + }; + + let state = match callback.state { + Some(state) => state, + None => { + let _ = MCPServerManager::fail_oauth_session( + &callback_session, + "OAuth callback did not include a state token.".to_string(), + ) + .await; + return; + } + }; + + let _ = MCPServerManager::update_oauth_snapshot(&callback_session, |snapshot| { + snapshot.status = MCPRemoteOAuthStatus::ExchangingToken; + snapshot.message = + Some("Exchanging the authorization code for an access token.".to_string()); + }) + .await; + + match oauth_state.handle_callback(&code, &state).await { + Ok(_) => { + let _ = MCPServerManager::set_oauth_snapshot( + &callback_session, + MCPRemoteOAuthSessionSnapshot::new( + callback_server_id.clone(), + MCPRemoteOAuthStatus::Authorized, + Some(authorization_url.clone()), + Some(redirect_uri.clone()), + Some( + "OAuth authorization completed. Reconnecting MCP server." + .to_string(), + ), + ), + ) + .await; + + if let Some(shutdown_tx) = callback_session.shutdown_tx.lock().await.take() { + let _ = shutdown_tx.send(()); + } + + manager.clear_reconnect_state(&callback_server_id).await; + let _ = manager.stop_server(&callback_server_id).await; + if let Err(error) = manager.start_server(&callback_server_id).await { + let _ = MCPServerManager::update_oauth_snapshot( + &callback_session, + |snapshot| { + snapshot.message = Some(format!( + "OAuth token saved, but reconnect failed: {}", + error + )); + }, + ) + .await; + } + } + Err(error) => { + let _ = MCPServerManager::fail_oauth_session( + &callback_session, + map_auth_error(error).to_string(), + ) + .await; + } + } + }); + + Ok(initial_snapshot) + } + + pub async fn get_remote_oauth_session( + &self, + server_id: &str, + ) -> Option { + let session = self.oauth_sessions.read().await.get(server_id).cloned()?; + let snapshot = session.snapshot.read().await.clone(); + Some(snapshot) + } + + pub async fn cancel_remote_oauth_authorization(&self, server_id: &str) -> BitFunResult<()> { + let session = self.oauth_sessions.write().await.remove(server_id); + if let Some(session) = session { + let _ = MCPServerManager::update_oauth_snapshot(&session, |snapshot| { + snapshot.status = MCPRemoteOAuthStatus::Cancelled; + snapshot.message = Some("OAuth authorization was cancelled.".to_string()); + }) + .await; + Self::shutdown_oauth_session(&session).await; + } + Ok(()) + } + + pub async fn clear_remote_oauth_credentials(&self, server_id: &str) -> BitFunResult<()> { + self.cancel_remote_oauth_authorization(server_id).await?; + clear_stored_oauth_credentials(server_id).await + } +} + +async fn handle_oauth_callback( + State(state): State, + headers: HeaderMap, + Query(params): Query>, +) -> impl IntoResponse { + let payload = OAuthCallbackPayload { + code: params.get("code").cloned(), + state: params.get("state").cloned(), + error: params.get("error").cloned(), + error_description: params.get("error_description").cloned(), + }; + let accept_language = headers + .get(axum::http::header::ACCEPT_LANGUAGE) + .and_then(|value| value.to_str().ok()); + let locale = + resolve_oauth_callback_locale(Some(state.preferred_language.as_str()), accept_language); + let page = render_oauth_callback_page(&payload, locale); + + if let Some(callback_tx) = state.callback_tx.lock().await.take() { + let _ = callback_tx.send(payload); + } + + Html(page) +} diff --git a/src/crates/core/src/service/mcp/server/manager/catalog.rs b/src/crates/core/src/service/mcp/server/manager/catalog.rs new file mode 100644 index 00000000..78db90ca --- /dev/null +++ b/src/crates/core/src/service/mcp/server/manager/catalog.rs @@ -0,0 +1,122 @@ +use super::*; + +impl MCPServerManager { + pub(super) async fn refresh_resources_catalog( + &self, + server_id: &str, + connection: Arc, + ) -> BitFunResult { + let mut resources = Vec::new(); + let mut cursor = None::; + let mut visited = HashSet::new(); + + loop { + let result = connection.list_resources(cursor.clone()).await?; + resources.extend(result.resources); + + match result.next_cursor { + Some(next) => { + if !visited.insert(next.clone()) { + break; + } + cursor = Some(next); + } + None => break, + } + } + + let count = resources.len(); + let mut cache = self.resource_catalog_cache.write().await; + cache.insert(server_id.to_string(), resources); + Ok(count) + } + + pub(super) async fn refresh_prompts_catalog( + &self, + server_id: &str, + connection: Arc, + ) -> BitFunResult { + let mut prompts = Vec::new(); + let mut cursor = None::; + let mut visited = HashSet::new(); + + loop { + let result = connection.list_prompts(cursor.clone()).await?; + prompts.extend(result.prompts); + + match result.next_cursor { + Some(next) => { + if !visited.insert(next.clone()) { + break; + } + cursor = Some(next); + } + None => break, + } + } + + let count = prompts.len(); + let mut cache = self.prompt_catalog_cache.write().await; + cache.insert(server_id.to_string(), prompts); + Ok(count) + } + + pub(super) async fn warm_catalog_caches( + &self, + server_id: &str, + connection: Arc, + ) { + if let Err(e) = self + .refresh_resources_catalog(server_id, connection.clone()) + .await + { + debug!( + "Skipping MCP resources catalog warmup: server_id={} error={}", + server_id, e + ); + } + + if let Err(e) = self.refresh_prompts_catalog(server_id, connection).await { + debug!( + "Skipping MCP prompts catalog warmup: server_id={} error={}", + server_id, e + ); + } + } + + /// Returns cached MCP resources for a server. + pub async fn get_cached_resources(&self, server_id: &str) -> Vec { + self.resource_catalog_cache + .read() + .await + .get(server_id) + .cloned() + .unwrap_or_default() + } + + /// Returns cached MCP prompts for a server. + pub async fn get_cached_prompts(&self, server_id: &str) -> Vec { + self.prompt_catalog_cache + .read() + .await + .get(server_id) + .cloned() + .unwrap_or_default() + } + + /// Refreshes resources catalog cache for one server. + pub async fn refresh_server_resource_catalog(&self, server_id: &str) -> BitFunResult { + let connection = self.get_connection(server_id).await.ok_or_else(|| { + BitFunError::NotFound(format!("MCP server connection not found: {}", server_id)) + })?; + self.refresh_resources_catalog(server_id, connection).await + } + + /// Refreshes prompts catalog cache for one server. + pub async fn refresh_server_prompt_catalog(&self, server_id: &str) -> BitFunResult { + let connection = self.get_connection(server_id).await.ok_or_else(|| { + BitFunError::NotFound(format!("MCP server connection not found: {}", server_id)) + })?; + self.refresh_prompts_catalog(server_id, connection).await + } +} diff --git a/src/crates/core/src/service/mcp/server/manager/interaction.rs b/src/crates/core/src/service/mcp/server/manager/interaction.rs new file mode 100644 index 00000000..24fbdb0b --- /dev/null +++ b/src/crates/core/src/service/mcp/server/manager/interaction.rs @@ -0,0 +1,403 @@ +use super::*; + +impl MCPServerManager { + pub(super) fn detect_list_changed_kind(method: &str) -> Option { + match method { + "notifications/tools/list_changed" + | "notifications/tools/listChanged" + | "tools/list_changed" => Some(ListChangedKind::Tools), + "notifications/prompts/list_changed" + | "notifications/prompts/listChanged" + | "prompts/list_changed" => Some(ListChangedKind::Prompts), + "notifications/resources/list_changed" + | "notifications/resources/listChanged" + | "resources/list_changed" => Some(ListChangedKind::Resources), + _ => None, + } + } + + fn path_to_file_uri(path: &Path) -> Option { + reqwest::Url::from_directory_path(path) + .ok() + .map(|u| u.to_string()) + } + + fn build_roots_list_result() -> Value { + let mut candidate_roots = Vec::new(); + + if let Some(workspace_service) = get_global_workspace_service() { + if let Some(workspace_root) = workspace_service.try_get_current_workspace_path() { + candidate_roots.push(workspace_root); + } + } + + if candidate_roots.is_empty() { + if let Ok(current_dir) = std::env::current_dir() { + candidate_roots.push(current_dir); + } + } + + let mut seen_uris = HashSet::new(); + let mut roots = Vec::new(); + for root in candidate_roots { + let Some(uri) = Self::path_to_file_uri(&root) else { + continue; + }; + if !seen_uris.insert(uri.clone()) { + continue; + } + let name = root + .file_name() + .and_then(|v| v.to_str()) + .filter(|v| !v.is_empty()) + .unwrap_or("BitFun Workspace") + .to_string(); + roots.push(json!({ + "uri": uri, + "name": name, + })); + } + + json!({ "roots": roots }) + } + + async fn handle_server_request( + &self, + server_id: &str, + server_name: &str, + connection: Arc, + request_id: Value, + method: String, + params: Option, + ) { + match method.as_str() { + "ping" => { + if let Err(e) = connection.send_response(request_id, json!({})).await { + warn!( + "Failed to respond to MCP ping request: server_name={} server_id={} error={}", + server_name, server_id, e + ); + } + } + "roots/list" => { + let result = Self::build_roots_list_result(); + if let Err(e) = connection.send_response(request_id, result).await { + warn!( + "Failed to respond to MCP roots/list request: server_name={} server_id={} error={}", + server_name, server_id, e + ); + } else { + info!( + "Handled MCP roots/list request: server_name={} server_id={}", + server_name, server_id + ); + } + } + "elicitation/create" | "sampling/createMessage" => { + self.handle_interactive_server_request( + server_id, + server_name, + connection, + request_id, + method, + params, + ) + .await; + } + _ => { + let error = MCPError::method_not_found(method.clone()); + if let Err(e) = connection.send_error(request_id, error).await { + warn!( + "Failed to respond with method_not_found for MCP request: server_name={} server_id={} method={} error={}", + server_name, server_id, method, e + ); + } else { + warn!( + "Rejected unsupported MCP server request: server_name={} server_id={} method={}", + server_name, server_id, method + ); + } + } + } + } + + async fn handle_interactive_server_request( + &self, + server_id: &str, + server_name: &str, + connection: Arc, + request_id: Value, + method: String, + params: Option, + ) { + let interaction_id = format!("mcp_interaction_{}", uuid::Uuid::new_v4()); + let (tx, rx) = oneshot::channel(); + + { + let mut pending = self.pending_interactions.write().await; + pending.insert(interaction_id.clone(), PendingMCPInteraction { sender: tx }); + } + + let event_payload = json!({ + "interactionId": interaction_id, + "serverId": server_id, + "serverName": server_name, + "method": method.clone(), + "params": params, + }); + + let event_system = get_global_event_system(); + if let Err(e) = event_system + .emit(BackendEvent::Custom { + event_name: "backend-event-mcpinteractionrequest".to_string(), + payload: event_payload, + }) + .await + { + warn!( + "Failed to emit MCP interaction request event: server_name={} server_id={} method={} error={}", + server_name, server_id, method, e + ); + } + + let wait_timeout = Duration::from_secs(600); + let decision = tokio::time::timeout(wait_timeout, rx).await; + { + let mut pending = self.pending_interactions.write().await; + pending.remove(&interaction_id); + } + + match decision { + Ok(Ok(MCPInteractionDecision::Accept { result })) => { + if let Err(e) = connection.send_response(request_id, result).await { + warn!( + "Failed to send interactive MCP response: server_name={} server_id={} method={} error={}", + server_name, server_id, method, e + ); + } else { + info!( + "Handled interactive MCP request: server_name={} server_id={} method={}", + server_name, server_id, method + ); + } + } + Ok(Ok(MCPInteractionDecision::Reject { error })) => { + if let Err(e) = connection.send_error(request_id, error).await { + warn!( + "Failed to send interactive MCP rejection: server_name={} server_id={} method={} error={}", + server_name, server_id, method, e + ); + } else { + info!( + "Rejected interactive MCP request: server_name={} server_id={} method={}", + server_name, server_id, method + ); + } + } + Ok(Err(_)) => { + let error = MCPError::internal_error(format!( + "MCP interaction channel closed before response: {}", + method + )); + if let Err(e) = connection.send_error(request_id, error).await { + warn!( + "Failed to send interaction channel-closed error: server_name={} server_id={} method={} error={}", + server_name, server_id, method, e + ); + } + } + Err(_) => { + let error = MCPError::internal_error(format!( + "Timed out waiting for user interaction response for method: {}", + method + )); + if let Err(e) = connection.send_error(request_id, error).await { + warn!( + "Failed to send interaction timeout error: server_name={} server_id={} method={} error={}", + server_name, server_id, method, e + ); + } else { + warn!( + "Timed out waiting for interactive MCP request: server_name={} server_id={} method={} timeout={}s", + server_name, server_id, method, wait_timeout.as_secs() + ); + } + } + } + } + + pub async fn submit_interaction_response( + &self, + interaction_id: &str, + approve: bool, + result: Option, + error_message: Option, + error_code: Option, + error_data: Option, + ) -> BitFunResult<()> { + let pending = { + let mut interactions = self.pending_interactions.write().await; + interactions.remove(interaction_id) + }; + + let Some(pending) = pending else { + return Err(BitFunError::NotFound(format!( + "MCP interaction not found: {}", + interaction_id + ))); + }; + + let decision = if approve { + MCPInteractionDecision::Accept { + result: result.unwrap_or_else(|| json!({})), + } + } else { + MCPInteractionDecision::Reject { + error: MCPError { + code: error_code.unwrap_or(MCPError::INVALID_REQUEST), + message: error_message + .unwrap_or_else(|| "User rejected MCP interaction request".to_string()), + data: error_data, + }, + } + }; + + pending.sender.send(decision).map_err(|_| { + BitFunError::MCPError(format!( + "Failed to deliver MCP interaction response (receiver dropped): {}", + interaction_id + )) + })?; + + Ok(()) + } + + pub(super) async fn start_connection_event_listener( + &self, + server_id: &str, + server_name: &str, + connection: Arc, + ) { + self.stop_connection_event_listener(server_id).await; + + let manager = self.clone(); + let server_id_owned = server_id.to_string(); + let server_name_owned = server_name.to_string(); + let mut rx = connection.subscribe_events(); + let connection_for_refresh = connection.clone(); + + let handle = tokio::spawn(async move { + loop { + match rx.recv().await { + Ok(MCPConnectionEvent::Notification { method, .. }) => { + match Self::detect_list_changed_kind(&method) { + Some(ListChangedKind::Tools) => { + info!( + "Received MCP tools list-changed notification: server_name={} server_id={}", + server_name_owned, server_id_owned + ); + if let Err(e) = manager + .refresh_mcp_tools( + &server_id_owned, + &server_name_owned, + connection_for_refresh.clone(), + ) + .await + { + warn!( + "Failed to refresh MCP tools after list-changed notification: server_name={} server_id={} error={}", + server_name_owned, server_id_owned, e + ); + } + } + Some(ListChangedKind::Prompts) => { + info!( + "Received MCP prompts list-changed notification: server_name={} server_id={}", + server_name_owned, server_id_owned + ); + if let Err(e) = manager + .refresh_prompts_catalog( + &server_id_owned, + connection_for_refresh.clone(), + ) + .await + { + warn!( + "Failed to refresh MCP prompts catalog after list-changed notification: server_name={} server_id={} error={}", + server_name_owned, server_id_owned, e + ); + } + } + Some(ListChangedKind::Resources) => { + info!( + "Received MCP resources list-changed notification: server_name={} server_id={}", + server_name_owned, server_id_owned + ); + if let Err(e) = manager + .refresh_resources_catalog( + &server_id_owned, + connection_for_refresh.clone(), + ) + .await + { + warn!( + "Failed to refresh MCP resources catalog after list-changed notification: server_name={} server_id={} error={}", + server_name_owned, server_id_owned, e + ); + } + } + None => { + debug!( + "Ignoring MCP notification from server: server_name={} server_id={} method={}", + server_name_owned, server_id_owned, method + ); + } + } + } + Ok(MCPConnectionEvent::Request { + request_id, + method, + params, + }) => { + manager + .handle_server_request( + &server_id_owned, + &server_name_owned, + connection_for_refresh.clone(), + request_id, + method, + params, + ) + .await; + } + Ok(MCPConnectionEvent::Closed) => { + warn!( + "MCP connection event stream closed: server_name={} server_id={}", + server_name_owned, server_id_owned + ); + break; + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(count)) => { + warn!( + "Dropped MCP connection events due to lag: server_name={} server_id={} dropped={}", + server_name_owned, server_id_owned, count + ); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + }); + + let mut tasks = self.connection_event_tasks.write().await; + tasks.insert(server_id.to_string(), handle); + } + + pub(super) async fn stop_connection_event_listener(&self, server_id: &str) { + let mut tasks = self.connection_event_tasks.write().await; + if let Some(handle) = tasks.remove(server_id) { + handle.abort(); + } + } +} diff --git a/src/crates/core/src/service/mcp/server/manager.rs b/src/crates/core/src/service/mcp/server/manager/lifecycle.rs similarity index 73% rename from src/crates/core/src/service/mcp/server/manager.rs rename to src/crates/core/src/service/mcp/server/manager/lifecycle.rs index 08b04c44..5bccdf5e 100644 --- a/src/crates/core/src/service/mcp/server/manager.rs +++ b/src/crates/core/src/service/mcp/server/manager/lifecycle.rs @@ -1,35 +1,9 @@ -//! MCP server manager -//! -//! Manages the lifecycle of all MCP servers. - -use super::connection::{MCPConnection, MCPConnectionPool}; -use super::{MCPServerConfig, MCPServerRegistry, MCPServerStatus}; -use crate::service::mcp::adapter::tool::MCPToolAdapter; -use crate::service::mcp::config::MCPConfigService; -use crate::service::runtime::{RuntimeManager, RuntimeSource}; -use crate::util::errors::{BitFunError, BitFunResult}; -use log::{debug, error, info, warn}; -use std::sync::Arc; - -/// MCP server manager. -pub struct MCPServerManager { - registry: Arc, - connection_pool: Arc, - config_service: Arc, -} +use super::*; impl MCPServerManager { - /// Creates a new server manager. - pub fn new(config_service: Arc) -> Self { - Self { - registry: Arc::new(MCPServerRegistry::new()), - connection_pool: Arc::new(MCPConnectionPool::new()), - config_service, - } - } - /// Initializes all servers. pub async fn initialize_all(&self) -> BitFunResult<()> { + self.start_reconnect_monitor_if_needed(); info!("Initializing all MCP servers"); let existing_server_ids = self.registry.get_all_server_ids().await; @@ -107,6 +81,7 @@ impl MCPServerManager { /// /// This is safe to call multiple times (e.g., from multiple frontend windows). pub async fn initialize_non_destructive(&self) -> BitFunResult<()> { + self.start_reconnect_monitor_if_needed(); info!("Initializing MCP servers (non-destructive)"); let configs = self.config_service.load_all_configs().await?; @@ -133,7 +108,6 @@ impl MCPServerManager { continue; } - // Start only when not already running. if let Ok(status) = self.get_server_status(&config.id).await { if matches!( status, @@ -175,6 +149,7 @@ impl MCPServerManager { /// Starts a server. pub async fn start_server(&self, server_id: &str) -> BitFunResult<()> { + self.start_reconnect_monitor_if_needed(); info!("Starting MCP server: id={}", server_id); let config = self @@ -215,7 +190,7 @@ impl MCPServerManager { } match config.server_type { - super::MCPServerType::Local => { + super::super::MCPServerType::Local => { let command = config.command.as_ref().ok_or_else(|| { error!("Missing command for local MCP server: id={}", server_id); BitFunError::Configuration("Missing command for local MCP server".to_string()) @@ -246,36 +221,43 @@ impl MCPServerManager { error!( "Failed to start local MCP server process: id={} command={} source={} error={}", server_id, resolved.command, source_label, e - ); - e - })?; + ); + e + })?; } - super::MCPServerType::Remote => { + super::super::MCPServerType::Remote => { + let transport = config.resolved_transport(); + if transport != crate::service::mcp::server::MCPServerTransport::StreamableHttp { + error!( + "Remote MCP transport not supported yet: id={} transport={}", + server_id, + transport.as_str() + ); + return Err(BitFunError::NotImplemented(format!( + "Remote MCP transport '{}' is not yet supported", + transport.as_str() + ))); + } + let url = config.url.as_ref().ok_or_else(|| { error!("Missing URL for remote MCP server: id={}", server_id); BitFunError::Configuration("Missing URL for remote MCP server".to_string()) })?; info!( - "Connecting to remote MCP server: url={} id={}", - url, server_id + "Connecting to remote MCP server: transport={} url={} id={}", + transport.as_str(), + url, + server_id ); - proc.start_remote(url, &config.env, &config.headers) - .await - .map_err(|e| { - error!( - "Failed to connect to remote MCP server: url={} id={} error={}", - url, server_id, e - ); - e - })?; - } - super::MCPServerType::Container => { - error!("Container MCP servers not supported: id={}", server_id); - return Err(BitFunError::NotImplemented( - "Container MCP servers not yet supported".to_string(), - )); + proc.start_remote(&config).await.map_err(|e| { + error!( + "Failed to connect to remote MCP server: url={} id={} error={}", + url, server_id, e + ); + e + })?; } } @@ -284,7 +266,7 @@ impl MCPServerManager { .add_connection(server_id.to_string(), connection.clone()) .await; - match Self::register_mcp_tools(server_id, &config.name, connection).await { + match Self::register_mcp_tools(server_id, &config.name, connection.clone()).await { Ok(count) => { info!( "Registered {} MCP tools: server_name={} server_id={}", @@ -298,6 +280,10 @@ impl MCPServerManager { ); } } + + self.start_connection_event_listener(server_id, &config.name, connection.clone()) + .await; + self.warm_catalog_caches(server_id, connection).await; } else { warn!( "Connection not available, server may not have started correctly: id={}", @@ -306,6 +292,7 @@ impl MCPServerManager { } info!("MCP server started successfully: id={}", server_id); + self.clear_reconnect_state(server_id).await; Ok(()) } @@ -313,6 +300,8 @@ impl MCPServerManager { pub async fn stop_server(&self, server_id: &str) -> BitFunResult<()> { info!("Stopping MCP server: id={}", server_id); + self.stop_connection_event_listener(server_id).await; + let process = self.registry.get_process(server_id).await.ok_or_else(|| { BitFunError::NotFound(format!("MCP server not found: {}", server_id)) @@ -322,6 +311,8 @@ impl MCPServerManager { let stop_result = proc.stop().await; self.connection_pool.remove_connection(server_id).await; + self.resource_catalog_cache.write().await.remove(server_id); + self.prompt_catalog_cache.write().await.remove(server_id); Self::unregister_mcp_tools(server_id).await; @@ -341,7 +332,7 @@ impl MCPServerManager { })?; match config.server_type { - super::MCPServerType::Local => { + super::super::MCPServerType::Local => { self.ensure_registered(server_id).await?; let process = self.registry.get_process(server_id).await.ok_or_else(|| { @@ -355,17 +346,11 @@ impl MCPServerManager { .ok_or_else(|| BitFunError::Configuration("Missing command".to_string()))?; proc.restart(command, &config.args, &config.env).await?; } - super::MCPServerType::Remote => { - // Treat restart as reconnect for remote servers. + super::super::MCPServerType::Remote => { self.ensure_registered(server_id).await?; let _ = self.stop_server(server_id).await; self.start_server(server_id).await?; } - _ => { - return Err(BitFunError::NotImplemented( - "Restart not supported for this server type".to_string(), - )); - } } Ok(()) @@ -374,8 +359,6 @@ impl MCPServerManager { /// Returns server status. pub async fn get_server_status(&self, server_id: &str) -> BitFunResult { if !self.registry.contains(server_id).await { - // If the server exists in config but isn't registered yet, register it so status - // reflects reality (Uninitialized) instead of heuristics in the UI. let _ = self.ensure_registered(server_id).await; } @@ -388,6 +371,21 @@ impl MCPServerManager { Ok(proc.status().await) } + /// Returns the current status detail/message for one server. + pub async fn get_server_status_message(&self, server_id: &str) -> BitFunResult> { + if !self.registry.contains(server_id).await { + let _ = self.ensure_registered(server_id).await; + } + + let process = + self.registry.get_process(server_id).await.ok_or_else(|| { + BitFunError::NotFound(format!("MCP server not found: {}", server_id)) + })?; + + let proc = process.read().await; + Ok(proc.status_message().await) + } + /// Returns statuses of all servers. pub async fn get_all_server_statuses(&self) -> Vec<(String, MCPServerStatus)> { let processes = self.registry.get_all_processes().await; @@ -418,7 +416,6 @@ impl MCPServerManager { config.validate()?; self.config_service.save_server_config(&config).await?; - self.registry.register(&config).await?; if config.enabled && config.auto_start { @@ -432,6 +429,9 @@ impl MCPServerManager { pub async fn remove_server(&self, server_id: &str) -> BitFunResult<()> { info!("Removing MCP server: id={}", server_id); + let _ = self.clear_remote_oauth_credentials(server_id).await; + self.stop_connection_event_listener(server_id).await; + match self.registry.unregister(server_id).await { Ok(_) => { info!("Unregistered MCP server: id={}", server_id); @@ -445,6 +445,9 @@ impl MCPServerManager { } self.config_service.delete_server_config(server_id).await?; + self.clear_reconnect_state(server_id).await; + self.resource_catalog_cache.write().await.remove(server_id); + self.prompt_catalog_cache.write().await.remove(server_id); info!("Deleted MCP server config: id={}", server_id); Ok(()) @@ -456,21 +459,70 @@ impl MCPServerManager { self.config_service.save_server_config(&config).await?; - let status = self.get_server_status(&config.id).await; + let status = self.get_server_status(&config.id).await?; if matches!( status, - Ok(MCPServerStatus::Connected | MCPServerStatus::Healthy) + MCPServerStatus::Connected | MCPServerStatus::Healthy ) { info!( "Restarting MCP server to apply new configuration: id={}", config.id ); self.restart_server(&config.id).await?; + } else if config.enabled + && config.auto_start + && matches!( + status, + MCPServerStatus::NeedsAuth + | MCPServerStatus::Failed + | MCPServerStatus::Reconnecting + | MCPServerStatus::Stopped + | MCPServerStatus::Uninitialized + ) + { + info!( + "Starting MCP server after configuration update: id={} previous_status={:?}", + config.id, status + ); + let _ = self.start_server(&config.id).await; } Ok(()) } + /// Updates remote MCP authorization and immediately retries the connection. + pub async fn reauthenticate_remote_server( + &self, + server_id: &str, + authorization_value: &str, + ) -> BitFunResult<()> { + self.clear_remote_oauth_credentials(server_id).await?; + let config = self + .config_service + .set_remote_authorization(server_id, authorization_value) + .await?; + + let _ = self.stop_server(server_id).await; + self.clear_reconnect_state(server_id).await; + + if config.enabled { + self.start_server(server_id).await?; + } + + Ok(()) + } + + /// Clears remote MCP authorization and stops the current connection so stale credentials are dropped. + pub async fn clear_remote_server_auth(&self, server_id: &str) -> BitFunResult<()> { + self.clear_remote_oauth_credentials(server_id).await?; + self.config_service + .clear_remote_authorization(server_id) + .await?; + let _ = self.stop_server(server_id).await; + self.clear_reconnect_state(server_id).await; + Ok(()) + } + /// Shuts down all servers. pub async fn shutdown(&self) -> BitFunResult<()> { info!("Shutting down all MCP servers"); @@ -483,66 +535,26 @@ impl MCPServerManager { } self.registry.clear().await?; - - info!("All MCP servers shut down"); - Ok(()) - } - - /// Registers MCP tools into the global tool registry. - async fn register_mcp_tools( - server_id: &str, - server_name: &str, - connection: Arc, - ) -> BitFunResult { - info!( - "Registering MCP tools: server_name={} server_id={}", - server_name, server_id - ); - - let mut adapter = MCPToolAdapter::new(); - - adapter - .load_tools_from_server(server_id, server_name, connection) + self.reconnect_states.write().await.clear(); + self.resource_catalog_cache.write().await.clear(); + self.prompt_catalog_cache.write().await.clear(); + self.pending_interactions.write().await.clear(); + let oauth_sessions: Vec<_> = self + .oauth_sessions + .write() .await - .map_err(|e| { - error!( - "Failed to load tools from MCP server: server_name={} server_id={} error={}", - server_name, server_id, e - ); - e - })?; - - let tools = adapter.get_tools(); - let tool_count = tools.len(); - - for tool in tools { - debug!( - "Loaded MCP tool: name={} server={}", - tool.name(), - server_name - ); + .drain() + .map(|(_, session)| session) + .collect(); + for session in oauth_sessions { + Self::shutdown_oauth_session(&session).await; + } + let mut event_tasks = self.connection_event_tasks.write().await; + for (_, handle) in event_tasks.drain() { + handle.abort(); } - let registry = crate::agentic::tools::registry::get_global_tool_registry(); - let mut registry_lock = registry.write().await; - - let tools_to_register = adapter.get_tools().to_vec(); - registry_lock.register_mcp_tools(tools_to_register); - drop(registry_lock); - - info!( - "Registered {} MCP tools: server_name={} server_id={}", - tool_count, server_name, server_id - ); - - Ok(tool_count) - } - - /// Unregisters MCP tools from the global tool registry. - async fn unregister_mcp_tools(server_id: &str) { - let registry = crate::agentic::tools::registry::get_global_tool_registry(); - let mut registry_lock = registry.write().await; - registry_lock.unregister_mcp_server_tools(server_id); - info!("Unregistered MCP tools: server_id={}", server_id); + info!("All MCP servers shut down"); + Ok(()) } } diff --git a/src/crates/core/src/service/mcp/server/manager/mod.rs b/src/crates/core/src/service/mcp/server/manager/mod.rs new file mode 100644 index 00000000..43c9e804 --- /dev/null +++ b/src/crates/core/src/service/mcp/server/manager/mod.rs @@ -0,0 +1,128 @@ +//! MCP server manager +//! +//! The manager is split into focused submodules so lifecycle, reconnect, +//! catalog, interaction, and tool-registration logic can evolve independently. + +mod auth; +mod catalog; +mod interaction; +mod lifecycle; +mod reconnect; +#[cfg(test)] +mod tests; +mod tools; + +use super::connection::{MCPConnection, MCPConnectionEvent, MCPConnectionPool}; +use super::{MCPServerConfig, MCPServerRegistry, MCPServerStatus}; +use crate::infrastructure::events::event_system::{get_global_event_system, BackendEvent}; +use crate::service::mcp::auth::MCPRemoteOAuthSessionSnapshot; +use crate::service::mcp::adapter::MCPToolAdapter; +use crate::service::mcp::config::MCPConfigService; +use crate::service::mcp::protocol::{MCPError, MCPPrompt, MCPResource}; +use crate::service::runtime::{RuntimeManager, RuntimeSource}; +use crate::service::workspace::get_global_workspace_service; +use crate::util::errors::{BitFunError, BitFunResult}; +use log::{debug, error, info, warn}; +use serde_json::{json, Value}; +use std::collections::{HashMap, HashSet}; +use std::path::Path; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::{Mutex, oneshot}; +use tokio::task::JoinHandle; + +/// Reconnect policy for unhealthy MCP servers. +#[derive(Debug, Clone, Copy)] +struct ReconnectPolicy { + poll_interval: Duration, + base_delay: Duration, + max_delay: Duration, + max_attempts: u32, +} + +impl Default for ReconnectPolicy { + fn default() -> Self { + Self { + poll_interval: Duration::from_secs(5), + base_delay: Duration::from_secs(2), + max_delay: Duration::from_secs(60), + max_attempts: 6, + } + } +} + +#[derive(Debug, Clone)] +struct ReconnectAttemptState { + attempts: u32, + next_retry_at: Instant, + exhausted_logged: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ListChangedKind { + Tools, + Prompts, + Resources, +} + +#[derive(Debug)] +enum MCPInteractionDecision { + Accept { result: Value }, + Reject { error: MCPError }, +} + +#[derive(Debug)] +struct PendingMCPInteraction { + sender: oneshot::Sender, +} + +struct ActiveRemoteOAuthSession { + snapshot: Arc>, + shutdown_tx: Mutex>>, +} + +impl ReconnectAttemptState { + fn new(now: Instant) -> Self { + Self { + attempts: 0, + next_retry_at: now, + exhausted_logged: false, + } + } +} + +/// MCP server manager. +#[derive(Clone)] +pub struct MCPServerManager { + registry: Arc, + connection_pool: Arc, + config_service: Arc, + reconnect_policy: ReconnectPolicy, + reconnect_states: Arc>>, + reconnect_monitor_started: Arc, + connection_event_tasks: Arc>>>, + resource_catalog_cache: Arc>>>, + prompt_catalog_cache: Arc>>>, + pending_interactions: Arc>>, + oauth_sessions: Arc>>>, +} + +impl MCPServerManager { + /// Creates a new server manager. + pub fn new(config_service: Arc) -> Self { + Self { + registry: Arc::new(MCPServerRegistry::new()), + connection_pool: Arc::new(MCPConnectionPool::new()), + config_service, + reconnect_policy: ReconnectPolicy::default(), + reconnect_states: Arc::new(tokio::sync::RwLock::new(HashMap::new())), + reconnect_monitor_started: Arc::new(AtomicBool::new(false)), + connection_event_tasks: Arc::new(tokio::sync::RwLock::new(HashMap::new())), + resource_catalog_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())), + prompt_catalog_cache: Arc::new(tokio::sync::RwLock::new(HashMap::new())), + pending_interactions: Arc::new(tokio::sync::RwLock::new(HashMap::new())), + oauth_sessions: Arc::new(tokio::sync::RwLock::new(HashMap::new())), + } + } +} diff --git a/src/crates/core/src/service/mcp/server/manager/reconnect.rs b/src/crates/core/src/service/mcp/server/manager/reconnect.rs new file mode 100644 index 00000000..ee3dfbdd --- /dev/null +++ b/src/crates/core/src/service/mcp/server/manager/reconnect.rs @@ -0,0 +1,147 @@ +use super::*; + +impl MCPServerManager { + pub(super) fn start_reconnect_monitor_if_needed(&self) { + if self.reconnect_monitor_started.swap(true, Ordering::SeqCst) { + return; + } + + let manager = self.clone(); + tokio::spawn(async move { + manager.run_reconnect_monitor().await; + }); + info!("Started MCP reconnect monitor"); + } + + async fn run_reconnect_monitor(self) { + let mut interval = tokio::time::interval(self.reconnect_policy.poll_interval); + loop { + interval.tick().await; + if let Err(e) = self.reconnect_once().await { + warn!("MCP reconnect monitor tick failed: {}", e); + } + } + } + + async fn reconnect_once(&self) -> BitFunResult<()> { + let configs = self.config_service.load_all_configs().await?; + + for config in configs { + if !(config.enabled && config.auto_start) { + self.clear_reconnect_state(&config.id).await; + continue; + } + + let status = self + .get_server_status(&config.id) + .await + .unwrap_or(MCPServerStatus::Uninitialized); + + if matches!( + status, + MCPServerStatus::Connected | MCPServerStatus::Healthy | MCPServerStatus::Starting + ) { + self.clear_reconnect_state(&config.id).await; + continue; + } + + if matches!(status, MCPServerStatus::NeedsAuth) { + self.clear_reconnect_state(&config.id).await; + continue; + } + + if !matches!( + status, + MCPServerStatus::Reconnecting | MCPServerStatus::Failed + ) { + continue; + } + + self.try_reconnect_server(&config.id, &config.name, status) + .await; + } + + Ok(()) + } + + async fn try_reconnect_server( + &self, + server_id: &str, + server_name: &str, + status: MCPServerStatus, + ) { + let now = Instant::now(); + + let (attempt_number, next_delay) = { + let mut reconnect_states = self.reconnect_states.write().await; + let state = reconnect_states + .entry(server_id.to_string()) + .or_insert_with(|| ReconnectAttemptState::new(now)); + + if state.attempts >= self.reconnect_policy.max_attempts { + if !state.exhausted_logged { + warn!( + "MCP reconnect attempts exhausted: server_name={} server_id={} max_attempts={} status={:?}", + server_name, server_id, self.reconnect_policy.max_attempts, status + ); + state.exhausted_logged = true; + } + return; + } + + if now < state.next_retry_at { + return; + } + + state.attempts += 1; + let delay = Self::compute_backoff_delay( + self.reconnect_policy.base_delay, + self.reconnect_policy.max_delay, + state.attempts, + ); + state.next_retry_at = now + delay; + (state.attempts, delay) + }; + + info!( + "Attempting MCP reconnect: server_name={} server_id={} attempt={}/{} status={:?}", + server_name, server_id, attempt_number, self.reconnect_policy.max_attempts, status + ); + + let _ = self.stop_server(server_id).await; + match self.start_server(server_id).await { + Ok(_) => { + self.clear_reconnect_state(server_id).await; + info!( + "MCP reconnect succeeded: server_name={} server_id={} attempt={}", + server_name, server_id, attempt_number + ); + } + Err(e) => { + warn!( + "MCP reconnect failed: server_name={} server_id={} attempt={}/{} next_retry_in={}s error={}", + server_name, + server_id, + attempt_number, + self.reconnect_policy.max_attempts, + next_delay.as_secs(), + e + ); + } + } + } + + pub(super) fn compute_backoff_delay(base: Duration, max: Duration, attempt: u32) -> Duration { + let shift = attempt.saturating_sub(1).min(20); + let factor = 1u64 << shift; + let base_ms = base.as_millis() as u64; + let max_ms = max.as_millis() as u64; + let delay_ms = base_ms.saturating_mul(factor).min(max_ms); + Duration::from_millis(delay_ms) + } + + pub(super) async fn clear_reconnect_state(&self, server_id: &str) { + let mut reconnect_states = self.reconnect_states.write().await; + reconnect_states.remove(server_id); + } +} diff --git a/src/crates/core/src/service/mcp/server/manager/tests.rs b/src/crates/core/src/service/mcp/server/manager/tests.rs new file mode 100644 index 00000000..b4f9f798 --- /dev/null +++ b/src/crates/core/src/service/mcp/server/manager/tests.rs @@ -0,0 +1,45 @@ +use super::{ListChangedKind, MCPServerManager}; +use std::time::Duration; + +#[test] +fn backoff_delay_grows_exponentially_and_caps() { + let base = Duration::from_secs(2); + let max = Duration::from_secs(60); + + assert_eq!( + MCPServerManager::compute_backoff_delay(base, max, 1), + Duration::from_secs(2) + ); + assert_eq!( + MCPServerManager::compute_backoff_delay(base, max, 2), + Duration::from_secs(4) + ); + assert_eq!( + MCPServerManager::compute_backoff_delay(base, max, 5), + Duration::from_secs(32) + ); + assert_eq!( + MCPServerManager::compute_backoff_delay(base, max, 10), + Duration::from_secs(60) + ); +} + +#[test] +fn detect_list_changed_kind_supports_three_catalogs() { + assert_eq!( + MCPServerManager::detect_list_changed_kind("notifications/tools/list_changed"), + Some(ListChangedKind::Tools) + ); + assert_eq!( + MCPServerManager::detect_list_changed_kind("notifications/prompts/list_changed"), + Some(ListChangedKind::Prompts) + ); + assert_eq!( + MCPServerManager::detect_list_changed_kind("notifications/resources/list_changed"), + Some(ListChangedKind::Resources) + ); + assert_eq!( + MCPServerManager::detect_list_changed_kind("notifications/unknown"), + None + ); +} diff --git a/src/crates/core/src/service/mcp/server/manager/tools.rs b/src/crates/core/src/service/mcp/server/manager/tools.rs new file mode 100644 index 00000000..b0ebc0d0 --- /dev/null +++ b/src/crates/core/src/service/mcp/server/manager/tools.rs @@ -0,0 +1,71 @@ +use super::*; + +impl MCPServerManager { + pub(super) async fn refresh_mcp_tools( + &self, + server_id: &str, + server_name: &str, + connection: Arc, + ) -> BitFunResult { + Self::unregister_mcp_tools(server_id).await; + Self::register_mcp_tools(server_id, server_name, connection).await + } + + /// Registers MCP tools into the global tool registry. + pub(super) async fn register_mcp_tools( + server_id: &str, + server_name: &str, + connection: Arc, + ) -> BitFunResult { + info!( + "Registering MCP tools: server_name={} server_id={}", + server_name, server_id + ); + + let mut adapter = MCPToolAdapter::new(); + + adapter + .load_tools_from_server(server_id, server_name, connection) + .await + .map_err(|e| { + error!( + "Failed to load tools from MCP server: server_name={} server_id={} error={}", + server_name, server_id, e + ); + e + })?; + + let tools = adapter.get_tools(); + let tool_count = tools.len(); + + for tool in tools { + debug!( + "Loaded MCP tool: name={} server={}", + tool.name(), + server_name + ); + } + + let registry = crate::agentic::tools::registry::get_global_tool_registry(); + let mut registry_lock = registry.write().await; + + let tools_to_register = adapter.get_tools().to_vec(); + registry_lock.register_mcp_tools(tools_to_register); + drop(registry_lock); + + info!( + "Registered {} MCP tools: server_name={} server_id={}", + tool_count, server_name, server_id + ); + + Ok(tool_count) + } + + /// Unregisters MCP tools from the global tool registry. + pub(super) async fn unregister_mcp_tools(server_id: &str) { + let registry = crate::agentic::tools::registry::get_global_tool_registry(); + let mut registry_lock = registry.write().await; + registry_lock.unregister_mcp_server_tools(server_id); + info!("Unregistered MCP tools: server_id={}", server_id); + } +} diff --git a/src/crates/core/src/service/mcp/server/mod.rs b/src/crates/core/src/service/mcp/server/mod.rs index b3e5bdf4..858e9b55 100644 --- a/src/crates/core/src/service/mcp/server/mod.rs +++ b/src/crates/core/src/service/mcp/server/mod.rs @@ -2,92 +2,14 @@ //! //! Manages MCP server process lifecycles, connections, and registration. -pub mod connection; -pub mod manager; -pub mod process; -pub mod registry; +mod config; +mod connection; +mod manager; +mod process; +mod registry; +pub use config::{MCPServerConfig, MCPServerOAuthConfig, MCPServerTransport, MCPServerXaaConfig}; pub use connection::{MCPConnection, MCPConnectionPool}; pub use manager::MCPServerManager; pub use process::{MCPServerProcess, MCPServerStatus, MCPServerType}; pub use registry::MCPServerRegistry; - -/// MCP server configuration. -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct MCPServerConfig { - pub id: String, - pub name: String, - #[serde(rename = "type")] - pub server_type: MCPServerType, - #[serde(skip_serializing_if = "Option::is_none")] - pub command: Option, - #[serde(default)] - pub args: Vec, - #[serde(default)] - pub env: std::collections::HashMap, - /// Additional HTTP headers for remote MCP servers (Cursor-style `headers`). - #[serde(default)] - pub headers: std::collections::HashMap, - #[serde(skip_serializing_if = "Option::is_none")] - pub url: Option, - #[serde(default = "default_true")] - pub auto_start: bool, - #[serde(default = "default_true")] - pub enabled: bool, - pub location: crate::service::mcp::config::ConfigLocation, - #[serde(default)] - pub capabilities: Vec, - #[serde(default)] - pub settings: std::collections::HashMap, -} - -fn default_true() -> bool { - true -} - -impl MCPServerConfig { - /// Validates the configuration. - pub fn validate(&self) -> crate::util::errors::BitFunResult<()> { - if self.id.is_empty() { - return Err(crate::util::errors::BitFunError::Configuration( - "MCP server id cannot be empty".to_string(), - )); - } - - if self.name.is_empty() { - return Err(crate::util::errors::BitFunError::Configuration( - "MCP server name cannot be empty".to_string(), - )); - } - - match self.server_type { - MCPServerType::Local => { - if self.command.is_none() { - return Err(crate::util::errors::BitFunError::Configuration(format!( - "Local MCP server '{}' must have a command", - self.id - ))); - } - } - MCPServerType::Remote => { - if self.url.is_none() { - return Err(crate::util::errors::BitFunError::Configuration(format!( - "Remote MCP server '{}' must have a URL", - self.id - ))); - } - } - MCPServerType::Container => { - if self.command.is_none() { - return Err(crate::util::errors::BitFunError::Configuration(format!( - "Container MCP server '{}' must have a command", - self.id - ))); - } - } - } - - Ok(()) - } -} diff --git a/src/crates/core/src/service/mcp/server/process.rs b/src/crates/core/src/service/mcp/server/process.rs index b5a75a86..7a4b6e36 100644 --- a/src/crates/core/src/service/mcp/server/process.rs +++ b/src/crates/core/src/service/mcp/server/process.rs @@ -3,7 +3,9 @@ //! Handles starting, stopping, monitoring, and restarting MCP server processes. use super::connection::MCPConnection; -use crate::service::mcp::protocol::{InitializeResult, MCPMessage, MCPServerInfo}; +use super::MCPServerConfig; +use crate::service::mcp::protocol::{InitializeResult, MCPMessage, MCPServerInfo, MCPTransport}; +use crate::service::mcp::server::MCPServerTransport; use crate::util::errors::{BitFunError, BitFunResult}; use log::{debug, error, info, warn}; use std::sync::Arc; @@ -15,9 +17,8 @@ use tokio::sync::{mpsc, RwLock}; #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[serde(rename_all = "lowercase")] pub enum MCPServerType { - Local, // Local executable - Remote, // Remote HTTP/WebSocket server - Container, // Docker container + Local, // Command-driven stdio server, including docker/podman wrappers + Remote, // Remote HTTP/WebSocket server } /// MCP server status. @@ -28,6 +29,7 @@ pub enum MCPServerStatus { Starting, // Starting Connected, // Connected Healthy, // Healthy (heartbeat OK) + NeedsAuth, // Authentication required / token expired Reconnecting, // Reconnecting Failed, // Failed Stopping, // Stopping @@ -48,10 +50,37 @@ pub struct MCPServerProcess { max_restarts: u32, health_check_interval: Duration, last_ping_time: Arc>>, + last_error_message: Arc>>, message_rx: Option>, } impl MCPServerProcess { + fn is_auth_error(error: &BitFunError) -> bool { + let msg = error.to_string().to_ascii_lowercase(); + let patterns = [ + "unauthorized", + "forbidden", + "auth required", + "authorization required", + "authentication required", + "authentication failed", + "oauth authorization required", + "oauth token refresh failed", + "token refresh failed", + "www-authenticate", + "invalid token", + "token expired", + "access token expired", + "refresh token", + "session expired", + "status code: 401", + "status code: 403", + " 401 ", + " 403 ", + ]; + patterns.iter().any(|p| msg.contains(p)) + } + /// Creates a new server process instance. pub fn new(id: String, name: String, server_type: MCPServerType) -> Self { Self { @@ -67,6 +96,7 @@ impl MCPServerProcess { max_restarts: 3, health_check_interval: Duration::from_secs(30), last_ping_time: Arc::new(RwLock::new(None)), + last_error_message: Arc::new(RwLock::new(None)), message_rx: None, } } @@ -121,7 +151,8 @@ impl MCPServerProcess { let mut child = match child { Ok(c) => c, Err(e) => { - self.set_status(MCPServerStatus::Failed).await; + self.set_status_with_error(MCPServerStatus::Failed, Some(e.to_string())) + .await; return Err(e); } }; @@ -140,7 +171,7 @@ impl MCPServerProcess { let connection = Arc::new(MCPConnection::new(stdin, rx)); self.message_rx = None; // The connection already owns rx - crate::service::mcp::protocol::transport::MCPTransport::start_receive_loop(stdout, tx); + MCPTransport::start_receive_loop(stdout, tx); self.connection = Some(connection.clone()); self.child = Some(child); @@ -152,11 +183,14 @@ impl MCPServerProcess { self.name, self.id, e ); let _ = self.stop().await; - self.set_status(MCPServerStatus::Failed).await; + self.set_status_with_error(MCPServerStatus::Failed, Some(e.to_string())) + .await; return Err(e); } - self.set_status(MCPServerStatus::Connected).await; + self.set_status_with_error(MCPServerStatus::Connected, None) + .await; + self.restart_count = 0; info!( "MCP server started successfully: name={} id={}", self.name, self.id @@ -168,34 +202,45 @@ impl MCPServerProcess { } /// Starts a remote server (Streamable HTTP). - pub async fn start_remote( - &mut self, - url: &str, - env: &std::collections::HashMap, - headers: &std::collections::HashMap, - ) -> BitFunResult<()> { + pub async fn start_remote(&mut self, config: &MCPServerConfig) -> BitFunResult<()> { + let url = config.url.as_deref().ok_or_else(|| { + BitFunError::Configuration(format!("Remote MCP server '{}' is missing a URL", self.id)) + })?; + let transport = config.resolved_transport(); + if transport != MCPServerTransport::StreamableHttp { + return Err(BitFunError::NotImplemented(format!( + "Remote MCP transport '{}' is not yet supported", + transport.as_str() + ))); + } info!( - "Starting remote MCP server: name={} id={} url={}", - self.name, self.id, url + "Starting remote MCP server: name={} id={} transport={} url={}", + self.name, + self.id, + transport.as_str(), + url ); self.set_status(MCPServerStatus::Starting).await; - let mut merged_headers = headers.clone(); + let mut merged_headers = config.headers.clone(); if !merged_headers.contains_key("Authorization") && !merged_headers.contains_key("authorization") && !merged_headers.contains_key("AUTHORIZATION") { // Backward compatibility: older BitFun configs store `Authorization` under `env`. - if let Some(value) = env + if let Some(value) = config + .env .get("Authorization") - .or_else(|| env.get("authorization")) - .or_else(|| env.get("AUTHORIZATION")) + .or_else(|| config.env.get("authorization")) + .or_else(|| config.env.get("AUTHORIZATION")) { merged_headers.insert("Authorization".to_string(), value.clone()); } } - let connection = Arc::new(MCPConnection::new_remote(url.to_string(), merged_headers)); + let connection = Arc::new( + MCPConnection::new_remote(&self.id, url.to_string(), merged_headers, true).await?, + ); self.connection = Some(connection.clone()); self.start_time = Some(Instant::now()); @@ -208,11 +253,19 @@ impl MCPServerProcess { self.message_rx = None; self.child = None; self.server_info = None; - self.set_status(MCPServerStatus::Failed).await; + if Self::is_auth_error(&e) { + self.set_status_with_error(MCPServerStatus::NeedsAuth, Some(e.to_string())) + .await; + } else { + self.set_status_with_error(MCPServerStatus::Failed, Some(e.to_string())) + .await; + } return Err(e); } - self.set_status(MCPServerStatus::Connected).await; + self.set_status_with_error(MCPServerStatus::Connected, None) + .await; + self.restart_count = 0; info!( "Remote MCP server started successfully: name={} id={}", self.name, self.id @@ -286,7 +339,14 @@ impl MCPServerProcess { "Max restart attempts reached: name={} id={} max_restarts={}", self.name, self.id, self.max_restarts ); - self.set_status(MCPServerStatus::Failed).await; + self.set_status_with_error( + MCPServerStatus::Failed, + Some(format!( + "Max restart attempts ({}) reached", + self.max_restarts + )), + ) + .await; return Err(BitFunError::MCPError(format!( "Max restart attempts ({}) reached", self.max_restarts @@ -306,8 +366,14 @@ impl MCPServerProcess { /// Sets status. async fn set_status(&self, status: MCPServerStatus) { + self.set_status_with_error(status, None).await; + } + + async fn set_status_with_error(&self, status: MCPServerStatus, error: Option) { let mut current_status = self.status.write().await; *current_status = status; + let mut last_error_message = self.last_error_message.write().await; + *last_error_message = error; } /// Gets status. @@ -315,6 +381,11 @@ impl MCPServerProcess { *self.status.read().await } + /// Returns the last status/error detail associated with the process. + pub async fn status_message(&self) -> Option { + self.last_error_message.read().await.clone() + } + /// Returns the connection. pub fn connection(&self) -> Option> { self.connection.clone() @@ -329,6 +400,7 @@ impl MCPServerProcess { fn start_health_check(&self) { let status = self.status.clone(); let last_ping = self.last_ping_time.clone(); + let last_error_message = self.last_error_message.clone(); let connection = self.connection.clone(); let interval = self.health_check_interval; let server_name = self.name.clone(); @@ -356,13 +428,19 @@ impl MCPServerProcess { Ok(_) => { *status.write().await = MCPServerStatus::Healthy; *last_ping.write().await = Some(Instant::now()); + *last_error_message.write().await = None; } Err(e) => { warn!( "Health check failed: server_name={} error={}", server_name, e ); - *status.write().await = MCPServerStatus::Reconnecting; + if MCPServerProcess::is_auth_error(&e) { + *status.write().await = MCPServerStatus::NeedsAuth; + } else { + *status.write().await = MCPServerStatus::Reconnecting; + } + *last_error_message.write().await = Some(e.to_string()); } } } else { @@ -400,3 +478,24 @@ impl Drop for MCPServerProcess { } } } + +#[cfg(test)] +mod tests { + use super::MCPServerProcess; + use crate::util::errors::BitFunError; + + #[test] + fn detect_auth_error_patterns() { + let unauthorized = + BitFunError::MCPError("Handshake failed: Unauthorized (401)".to_string()); + assert!(MCPServerProcess::is_auth_error(&unauthorized)); + + let oauth_refresh = BitFunError::MCPError( + "Ping failed: OAuth token refresh failed: no refresh token available".to_string(), + ); + assert!(MCPServerProcess::is_auth_error(&oauth_refresh)); + + let generic = BitFunError::MCPError("Handshake failed: connection reset".to_string()); + assert!(!MCPServerProcess::is_auth_error(&generic)); + } +} diff --git a/src/crates/core/tests/remote_mcp_streamable_http.rs b/src/crates/core/tests/remote_mcp_streamable_http.rs index af660490..2c45d6d0 100644 --- a/src/crates/core/tests/remote_mcp_streamable_http.rs +++ b/src/crates/core/tests/remote_mcp_streamable_http.rs @@ -24,6 +24,9 @@ struct TestState { sse_connected: Arc, sse_connected_notify: Arc, saw_session_header: Arc, + saw_roots_capability: Arc, + saw_sampling_capability: Arc, + saw_elicitation_capability: Arc, } async fn sse_handler( @@ -64,6 +67,25 @@ async fn post_handler( match method { "initialize" => { + let capabilities = body + .get("params") + .and_then(|params| params.get("capabilities")) + .cloned() + .unwrap_or(Value::Null); + if capabilities.get("roots").is_some() { + state.saw_roots_capability.store(true, Ordering::SeqCst); + } + if capabilities.get("sampling").is_some() { + state + .saw_sampling_capability + .store(true, Ordering::SeqCst); + } + if capabilities.get("elicitation").is_some() { + state + .saw_elicitation_capability + .store(true, Ordering::SeqCst); + } + let response = json!({ "jsonrpc": "2.0", "id": id, @@ -102,8 +124,28 @@ async fn post_handler( "tools": [ { "name": "hello", + "title": "Hello Tool", "description": "test tool", - "inputSchema": { "type": "object", "properties": {} } + "inputSchema": { "type": "object", "properties": {} }, + "outputSchema": { "type": "object", "properties": { "message": { "type": "string" } } }, + "annotations": { + "title": "Hello", + "readOnlyHint": true, + "destructiveHint": false, + "openWorldHint": true + }, + "icons": [ + { + "src": "https://example.com/tool.png", + "mimeType": "image/png", + "sizes": ["32x32"] + } + ], + "_meta": { + "ui": { + "resourceUri": "ui://hello/widget" + } + } } ], "nextCursor": null @@ -147,7 +189,14 @@ async fn remote_mcp_streamable_http_accepts_202_and_delivers_response_via_sse() }); let url = format!("http://{addr}/mcp"); - let connection = MCPConnection::new_remote(url, Default::default()); + let connection = MCPConnection::new_remote( + "test-server", + url, + Default::default(), + false, + ) + .await + .expect("remote connection should be created"); connection .initialize("BitFunTest", "0.0.0") @@ -172,9 +221,37 @@ async fn remote_mcp_streamable_http_accepts_202_and_delivers_response_via_sse() .expect("tools/list should resolve via SSE"); assert_eq!(tools.tools.len(), 1); assert_eq!(tools.tools[0].name, "hello"); + assert_eq!(tools.tools[0].title.as_deref(), Some("Hello Tool")); + assert_eq!( + tools.tools[0] + .annotations + .as_ref() + .and_then(|annotations| annotations.read_only_hint), + Some(true) + ); + assert_eq!( + tools.tools[0] + .meta + .as_ref() + .and_then(|meta| meta.ui.as_ref()) + .and_then(|ui| ui.resource_uri.as_deref()), + Some("ui://hello/widget") + ); assert!( state.saw_session_header.load(Ordering::SeqCst), "client should forward session id header on subsequent requests" ); + assert!( + state.saw_roots_capability.load(Ordering::SeqCst), + "client should advertise roots capability" + ); + assert!( + state.saw_sampling_capability.load(Ordering::SeqCst), + "client should advertise sampling capability" + ); + assert!( + state.saw_elicitation_capability.load(Ordering::SeqCst), + "client should advertise elicitation capability" + ); } diff --git a/src/web-ui/src/app/components/MCPInteractionDialog/MCPInteractionDialog.scss b/src/web-ui/src/app/components/MCPInteractionDialog/MCPInteractionDialog.scss new file mode 100644 index 00000000..ed69310d --- /dev/null +++ b/src/web-ui/src/app/components/MCPInteractionDialog/MCPInteractionDialog.scss @@ -0,0 +1,71 @@ +.mcp-interaction-dialog { + display: flex; + flex-direction: column; + gap: 12px; + min-height: 420px; +} + +.mcp-interaction-dialog__meta { + display: flex; + justify-content: space-between; + align-items: center; + gap: 8px; + font-size: 12px; + color: var(--text-secondary, #9ca3af); +} + +.mcp-interaction-dialog__server { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.mcp-interaction-dialog__queue { + color: var(--text-tertiary, #6b7280); +} + +.mcp-interaction-dialog__section { + display: flex; + flex-direction: column; + gap: 6px; +} + +.mcp-interaction-dialog__label { + font-size: 12px; + color: var(--text-secondary, #9ca3af); +} + +.mcp-interaction-dialog__params { + margin: 0; + padding: 10px; + border-radius: 8px; + background: var(--bg-tertiary, #111827); + border: 1px solid var(--border-primary, #374151); + color: var(--text-primary, #e5e7eb); + font-size: 12px; + line-height: 1.5; + white-space: pre-wrap; + word-break: break-word; + max-height: 180px; + overflow: auto; +} + +.mcp-interaction-dialog__editor { + width: 100%; + min-height: 190px; + padding: 10px; + border-radius: 8px; + background: var(--bg-secondary, #0b1220); + border: 1px solid var(--border-primary, #374151); + color: var(--text-primary, #e5e7eb); + font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; + font-size: 12px; + line-height: 1.5; + resize: vertical; +} + +.mcp-interaction-dialog__actions { + display: flex; + justify-content: flex-end; + gap: 8px; +} diff --git a/src/web-ui/src/app/components/MCPInteractionDialog/MCPInteractionDialog.tsx b/src/web-ui/src/app/components/MCPInteractionDialog/MCPInteractionDialog.tsx new file mode 100644 index 00000000..8955585d --- /dev/null +++ b/src/web-ui/src/app/components/MCPInteractionDialog/MCPInteractionDialog.tsx @@ -0,0 +1,191 @@ +import React, { useCallback, useEffect, useMemo, useState } from 'react'; +import { Button, Modal } from '@/component-library'; +import { globalEventBus } from '@/infrastructure/event-bus'; +import { MCPAPI } from '@/infrastructure/api/service-api/MCPAPI'; +import { notificationService } from '@/shared/notification-system'; +import { createLogger } from '@/shared/utils/logger'; +import './MCPInteractionDialog.scss'; + +const log = createLogger('MCPInteractionDialog'); + +interface MCPInteractionRequestEvent { + interactionId: string; + serverId: string; + serverName: string; + method: string; + params?: unknown; +} + +function buildDefaultResult(method: string): string { + if (method === 'sampling/createMessage') { + return '{\n "role": "assistant",\n "content": [\n {\n "type": "text",\n "text": ""\n }\n ]\n}'; + } + return '{\n "action": "accept"\n}'; +} + +function stringifyParams(params: unknown): string { + try { + return JSON.stringify(params ?? null, null, 2); + } catch { + return String(params ?? ''); + } +} + +export const MCPInteractionDialog: React.FC = () => { + const [queue, setQueue] = useState([]); + const [editorValue, setEditorValue] = useState(''); + const [isSubmitting, setIsSubmitting] = useState(false); + + const currentRequest = queue[0] ?? null; + const isOpen = !!currentRequest; + const queueCount = queue.length; + + useEffect(() => { + const handleRequest = (event: MCPInteractionRequestEvent) => { + if (!event?.interactionId || !event?.method) { + log.warn('Ignoring invalid MCP interaction event payload', { event }); + return; + } + + setQueue((prev) => { + if (prev.some((item) => item.interactionId === event.interactionId)) { + return prev; + } + return [...prev, event]; + }); + }; + + globalEventBus.on('mcp:interaction:request', handleRequest); + return () => { + globalEventBus.off('mcp:interaction:request', handleRequest); + }; + }, []); + + useEffect(() => { + if (!currentRequest) { + setEditorValue(''); + return; + } + setEditorValue(buildDefaultResult(currentRequest.method)); + }, [currentRequest]); + + const popCurrentRequest = useCallback(() => { + setQueue((prev) => prev.slice(1)); + }, []); + + const handleReject = useCallback(async () => { + if (!currentRequest || isSubmitting) return; + setIsSubmitting(true); + try { + await MCPAPI.submitMCPInteractionResponse({ + interactionId: currentRequest.interactionId, + approve: false, + error: { + message: 'User rejected MCP interaction request', + }, + }); + popCurrentRequest(); + } catch (error) { + log.error('Failed to submit MCP interaction rejection', { error, currentRequest }); + notificationService.error(`Failed to reject MCP request: ${currentRequest.method}`); + } finally { + setIsSubmitting(false); + } + }, [currentRequest, isSubmitting, popCurrentRequest]); + + const handleApprove = useCallback(async () => { + if (!currentRequest || isSubmitting) return; + + const trimmed = editorValue.trim(); + let parsedResult: unknown = {}; + if (trimmed.length > 0) { + try { + parsedResult = JSON.parse(trimmed); + } catch { + notificationService.error('Result must be valid JSON'); + return; + } + } + + setIsSubmitting(true); + try { + await MCPAPI.submitMCPInteractionResponse({ + interactionId: currentRequest.interactionId, + approve: true, + result: parsedResult as any, + }); + popCurrentRequest(); + } catch (error) { + log.error('Failed to submit MCP interaction approval', { error, currentRequest }); + notificationService.error(`Failed to approve MCP request: ${currentRequest.method}`); + } finally { + setIsSubmitting(false); + } + }, [currentRequest, editorValue, isSubmitting, popCurrentRequest]); + + const paramsPreview = useMemo(() => { + if (!currentRequest) return ''; + return stringifyParams(currentRequest.params); + }, [currentRequest]); + + return ( + {}} + title={currentRequest ? `MCP Interaction: ${currentRequest.method}` : 'MCP Interaction'} + size="large" + showCloseButton={false} + > + {currentRequest && ( +
+
+ + Server: {currentRequest.serverName || currentRequest.serverId} + + {queueCount > 1 && ( + Queue: {queueCount} + )} +
+ +
+
Request Params
+
{paramsPreview}
+
+ +
+
Response JSON
+