diff --git a/crates/forge_api/src/api.rs b/crates/forge_api/src/api.rs index b2bfbb732d..9fc962ce33 100644 --- a/crates/forge_api/src/api.rs +++ b/crates/forge_api/src/api.rs @@ -20,8 +20,10 @@ pub trait API: Sync + Send { /// environment async fn get_tools(&self) -> anyhow::Result; - /// Provides a list of models available in the current environment + /// Provides a list of models available in the current environment (with + /// caching) async fn get_models(&self) -> Result>; + /// Provides a list of agents available in the current environment async fn get_agents(&self) -> Result>; /// Provides a list of providers available in the current environment diff --git a/crates/forge_api/src/forge_api.rs b/crates/forge_api/src/forge_api.rs index 0ccb035c59..c0c4b8c08c 100644 --- a/crates/forge_api/src/forge_api.rs +++ b/crates/forge_api/src/forge_api.rs @@ -73,6 +73,7 @@ impl< async fn get_models(&self) -> Result> { self.app().get_models().await } + async fn get_agents(&self) -> Result> { self.services.get_agents().await } diff --git a/crates/forge_app/src/app.rs b/crates/forge_app/src/app.rs index 82a16a118a..e58b22a14f 100644 --- a/crates/forge_app/src/app.rs +++ b/crates/forge_app/src/app.rs @@ -253,18 +253,82 @@ impl ForgeApp { self.tool_registry.tools_overview().await } - /// Gets available models for the default provider with automatic credential - /// refresh. + /// Gets available models from all configured LLM providers with automatic + /// credential refresh and caching. Fetches models in parallel and + /// aggregates results. Errors from individual providers are logged but + /// don't prevent returning models from other providers. 
pub async fn get_models(&self) -> Result> { - let agent_provider_resolver = AgentProviderResolver::new(self.services.clone()); - let provider = agent_provider_resolver.get_provider(None).await?; - let provider = self - .services - .provider_auth_service() - .refresh_provider_credential(provider) - .await?; + use forge_domain::ProviderType; + use futures::stream::{FuturesUnordered, StreamExt}; + + // Check cache first + if let Some(cached_models) = self.services.get_cached_all_models().await { + return Ok(cached_models); + } + + // Get all providers + let all_providers = self.services.get_all_providers().await?; + + // Filter to only LLM providers that are configured + let llm_providers: Vec<_> = all_providers + .into_iter() + .filter(|p| { + let is_llm = match p { + forge_domain::AnyProvider::Url(provider) => { + provider.provider_type == ProviderType::Llm + } + forge_domain::AnyProvider::Template(provider) => { + provider.provider_type == ProviderType::Llm + } + }; + is_llm && p.is_configured() + }) + .collect(); + + // Fetch models from all providers in parallel + let mut fetch_futures = FuturesUnordered::new(); + + for any_provider in llm_providers { + let provider = match any_provider.into_configured() { + Some(p) => p, + None => continue, + }; + + // Refresh credentials for this provider + let services = self.services.clone(); + let fetch_task = async move { + let provider = services + .provider_auth_service() + .refresh_provider_credential(provider) + .await?; + let provider_id = provider.id.clone(); + + services.models(provider).await.map_err(|e| { + tracing::warn!( + provider_id = %provider_id, + error = %e, + "Failed to fetch models from provider" + ); + e + }) + }; + + fetch_futures.push(fetch_task); + } + + // Collect all successful results + let mut all_models = Vec::new(); + while let Some(result) = fetch_futures.next().await { + if let Ok(mut models) = result { + all_models.append(&mut models); + } + // Errors are already logged in the task above + } 
+ + // Cache the aggregated results + self.services.cache_all_models(all_models.clone()).await; - self.services.models(provider).await + Ok(all_models) } pub async fn login(&self, init_auth: &InitAuth) -> Result<()> { self.authenticator.login(init_auth).await diff --git a/crates/forge_app/src/command_generator.rs b/crates/forge_app/src/command_generator.rs index 429ade4fc6..0faebd4e43 100644 --- a/crates/forge_app/src/command_generator.rs +++ b/crates/forge_app/src/command_generator.rs @@ -197,6 +197,14 @@ mod tests { async fn migrate_env_credentials(&self) -> anyhow::Result> { Ok(None) } + + async fn cache_all_models(&self, _models: Vec) {} + + async fn get_cached_all_models(&self) -> Option> { + None + } + + async fn invalidate_caches(&self) {} } #[async_trait::async_trait] diff --git a/crates/forge_app/src/dto/anthropic/response.rs b/crates/forge_app/src/dto/anthropic/response.rs index ebd7e60051..c496f731de 100644 --- a/crates/forge_app/src/dto/anthropic/response.rs +++ b/crates/forge_app/src/dto/anthropic/response.rs @@ -1,6 +1,6 @@ use forge_domain::{ - ChatCompletionMessage, Content, ModelId, Reasoning, ReasoningPart, TokenCount, ToolCallId, - ToolCallPart, ToolName, + ChatCompletionMessage, Content, ModelId, ProviderId, Reasoning, ReasoningPart, TokenCount, + ToolCallId, ToolCallPart, ToolName, }; use serde::Deserialize; @@ -18,12 +18,14 @@ pub struct Model { pub display_name: String, } -impl From for forge_domain::Model { - fn from(value: Model) -> Self { - let context_length = get_context_length(&value.id); - Self { - id: ModelId::new(value.id), - name: Some(value.display_name), +impl Model { + /// Converts this DTO model to a domain model with the specified provider_id + pub fn into_domain_model(self, provider_id: ProviderId) -> forge_domain::Model { + let context_length = get_context_length(&self.id); + forge_domain::Model { + id: ModelId::new(self.id), + provider_id: Some(provider_id), + name: Some(self.display_name), description: None, context_length, 
tools_supported: Some(true), @@ -661,7 +663,8 @@ mod tests { display_name: "Claude 3.5 Sonnet (New)".to_string(), }; - let actual: forge_domain::Model = fixture.into(); + let actual: forge_domain::Model = + fixture.into_domain_model(forge_domain::ProviderId::ANTHROPIC); assert_eq!(actual.context_length, Some(200_000)); assert_eq!(actual.id.as_str(), "claude-sonnet-4-5-20250929"); @@ -675,7 +678,8 @@ mod tests { display_name: "Unknown Model".to_string(), }; - let actual: forge_domain::Model = fixture.into(); + let actual: forge_domain::Model = + fixture.into_domain_model(forge_domain::ProviderId::ANTHROPIC); assert_eq!(actual.context_length, None); assert_eq!(actual.id.as_str(), "unknown-claude-model"); diff --git a/crates/forge_app/src/dto/openai/model.rs b/crates/forge_app/src/dto/openai/model.rs index 622fbeefce..c8ee834c15 100644 --- a/crates/forge_app/src/dto/openai/model.rs +++ b/crates/forge_app/src/dto/openai/model.rs @@ -1,4 +1,4 @@ -use forge_domain::ModelId; +use forge_domain::{ModelId, ProviderId}; use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize, Clone)] @@ -88,11 +88,11 @@ pub struct ListModelResponse { pub data: Vec, } -impl From for forge_domain::Model { - fn from(value: Model) -> Self { +impl Model { + /// Converts this DTO model to a domain model with the specified provider_id + pub fn into_domain_model(self, provider_id: ProviderId) -> forge_domain::Model { let has_param = |name: &str| { - value - .supported_parameters + self.supported_parameters .as_ref() .map(|params| params.iter().any(|p| p == name)) }; @@ -102,10 +102,11 @@ impl From for forge_domain::Model { let supports_reasoning = has_param("reasoning"); forge_domain::Model { - id: value.id, - name: value.name, - description: value.description, - context_length: value.context_length, + id: self.id, + provider_id: Some(provider_id), + name: self.name, + description: self.description, + context_length: self.context_length, tools_supported, 
supports_parallel_tool_calls, supports_reasoning, @@ -263,7 +264,7 @@ mod tests { supported_parameters: None, // No supported_parameters field }; - let domain_model: forge_domain::Model = model.into(); + let domain_model = model.into_domain_model(forge_domain::ProviderId::OPENAI); // When supported_parameters is None, capabilities should be None (unknown) assert_eq!(domain_model.tools_supported, None); @@ -290,7 +291,7 @@ mod tests { ]), }; - let domain_model: forge_domain::Model = model.into(); + let domain_model = model.into_domain_model(forge_domain::ProviderId::OPENAI); // Should reflect what's actually in supported_parameters assert_eq!(domain_model.tools_supported, Some(true)); diff --git a/crates/forge_app/src/orch_spec/orch_setup.rs b/crates/forge_app/src/orch_spec/orch_setup.rs index b691dec6fb..6bb3380ad9 100644 --- a/crates/forge_app/src/orch_spec/orch_setup.rs +++ b/crates/forge_app/src/orch_spec/orch_setup.rs @@ -90,6 +90,7 @@ impl Default for TestContext { override_model: None, override_provider: None, enable_permissions: false, + model_cache_ttl_seconds: 3600, }, title: Some("test-conversation".into()), agent: Agent::new( diff --git a/crates/forge_app/src/services.rs b/crates/forge_app/src/services.rs index 542599d261..70328c6810 100644 --- a/crates/forge_app/src/services.rs +++ b/crates/forge_app/src/services.rs @@ -147,6 +147,15 @@ pub trait ProviderService: Send + Sync { async fn migrate_env_credentials( &self, ) -> anyhow::Result>; + + /// Cache aggregated models from all providers + async fn cache_all_models(&self, models: Vec); + + /// Get cached aggregated models if available and not expired + async fn get_cached_all_models(&self) -> Option>; + + /// Invalidate all caches (per-provider and aggregated) + async fn invalidate_caches(&self); } /// Manages user preferences for default providers and models. 
@@ -656,6 +665,18 @@ impl ProviderService for I { ) -> anyhow::Result> { self.provider_service().migrate_env_credentials().await } + + async fn cache_all_models(&self, models: Vec) { + self.provider_service().cache_all_models(models).await + } + + async fn get_cached_all_models(&self) -> Option> { + self.provider_service().get_cached_all_models().await + } + + async fn invalidate_caches(&self) { + self.provider_service().invalidate_caches().await + } } #[async_trait::async_trait] diff --git a/crates/forge_domain/src/env.rs b/crates/forge_domain/src/env.rs index 2e143ff0e7..ff3907f1da 100644 --- a/crates/forge_domain/src/env.rs +++ b/crates/forge_domain/src/env.rs @@ -94,6 +94,10 @@ pub struct Environment { /// Controlled by FORGE_ENABLE_PERMISSIONS environment variable. /// When enabled, tools will check policies before execution. pub enable_permissions: bool, + /// Cache TTL in seconds for model lists. + /// Controlled by FORGE_MODEL_CACHE_TTL environment variable. + /// Defaults to 3600 seconds (1 hour). 
+ pub model_cache_ttl_seconds: u64, } impl Environment { @@ -154,6 +158,10 @@ impl Environment { pub fn cache_dir(&self) -> PathBuf { self.base_path.join("cache") } + /// Returns the path to the model cache file + pub fn model_cache_path(&self) -> PathBuf { + self.cache_dir().join("models.json") + } /// Returns the global skills directory path (~/forge/skills) pub fn global_skills_path(&self) -> PathBuf { @@ -303,6 +311,7 @@ fn test_command_path() { override_model: None, override_provider: None, enable_permissions: false, + model_cache_ttl_seconds: 3600, }; let actual = fixture.command_path(); @@ -343,6 +352,7 @@ fn test_command_cwd_path() { override_model: None, override_provider: None, enable_permissions: false, + model_cache_ttl_seconds: 3600, }; let actual = fixture.command_cwd_path(); @@ -383,6 +393,7 @@ fn test_command_cwd_path_independent_from_command_path() { override_model: None, override_provider: None, enable_permissions: false, + model_cache_ttl_seconds: 3600, }; let command_path = fixture.command_path(); diff --git a/crates/forge_domain/src/model.rs b/crates/forge_domain/src/model.rs index 74d2433825..4a509c2c7b 100644 --- a/crates/forge_domain/src/model.rs +++ b/crates/forge_domain/src/model.rs @@ -3,13 +3,16 @@ use derive_setters::Setters; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use crate::ProviderId; + #[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Setters)] pub struct Model { pub id: ModelId, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub provider_id: Option, pub name: Option, pub description: Option, pub context_length: Option, - // TODO: add provider information to the model pub tools_supported: Option, /// Whether the model supports parallel tool calls pub supports_parallel_tool_calls: Option, diff --git a/crates/forge_infra/src/env.rs b/crates/forge_infra/src/env.rs index fa8d16f6c2..b04728dfaf 100644 --- a/crates/forge_infra/src/env.rs +++ b/crates/forge_infra/src/env.rs @@ -60,6 +60,7 @@ 
impl ForgeEnvironmentInfra { .and_then(|s| ProviderId::from_str(&s).ok()); let enable_permissions = parse_env::("FORGE_ENABLE_PERMISSIONS").unwrap_or(cfg!(debug_assertions)); + let model_cache_ttl_seconds = parse_env::("FORGE_MODEL_CACHE_TTL").unwrap_or(3600); Environment { os: std::env::consts::OS.to_string(), @@ -97,6 +98,7 @@ impl ForgeEnvironmentInfra { override_model, override_provider, enable_permissions, + model_cache_ttl_seconds, } } diff --git a/crates/forge_main/src/model.rs b/crates/forge_main/src/model.rs index aa5556d0a9..53a59b58b1 100644 --- a/crates/forge_main/src/model.rs +++ b/crates/forge_main/src/model.rs @@ -46,6 +46,11 @@ impl Display for CliModel { write!(f, " {}", info.dimmed())?; } + // Add provider information + if let Some(provider_id) = &self.0.provider_id { + write!(f, " {}", format!("[{}]", provider_id).dimmed())?; + } + Ok(()) } } @@ -920,6 +925,7 @@ mod tests { ) -> Model { Model { id: ModelId::new(id), + provider_id: Some(ProviderId::OPENAI), name: None, description: None, context_length, @@ -934,7 +940,7 @@ mod tests { let fixture = create_model_fixture("gpt-4", Some(128000), Some(true)); let formatted = format!("{}", CliModel(fixture)); let actual = strip_ansi_codes(&formatted); - let expected = "gpt-4 [ 128k 🛠️ ]"; + let expected = "gpt-4 [ 128k 🛠️ ] [OpenAI]"; assert_eq!(actual, expected); } @@ -943,7 +949,7 @@ mod tests { let fixture = create_model_fixture("claude-3", Some(2000000), Some(true)); let formatted = format!("{}", CliModel(fixture)); let actual = strip_ansi_codes(&formatted); - let expected = "claude-3 [ 2M 🛠️ ]"; + let expected = "claude-3 [ 2M 🛠️ ] [OpenAI]"; assert_eq!(actual, expected); } @@ -952,7 +958,7 @@ mod tests { let fixture = create_model_fixture("small-model", Some(512), Some(false)); let formatted = format!("{}", CliModel(fixture)); let actual = strip_ansi_codes(&formatted); - let expected = "small-model [ 512 ]"; + let expected = "small-model [ 512 ] [OpenAI]"; assert_eq!(actual, expected); } @@ -961,7 
+967,7 @@ mod tests { let fixture = create_model_fixture("text-model", Some(4096), Some(false)); let formatted = format!("{}", CliModel(fixture)); let actual = strip_ansi_codes(&formatted); - let expected = "text-model [ 4k ]"; + let expected = "text-model [ 4k ] [OpenAI]"; assert_eq!(actual, expected); } @@ -970,7 +976,7 @@ mod tests { let fixture = create_model_fixture("tool-model", None, Some(true)); let formatted = format!("{}", CliModel(fixture)); let actual = strip_ansi_codes(&formatted); - let expected = "tool-model [ 🛠️ ]"; + let expected = "tool-model [ 🛠️ ] [OpenAI]"; assert_eq!(actual, expected); } @@ -979,7 +985,7 @@ mod tests { let fixture = create_model_fixture("basic-model", None, Some(false)); let formatted = format!("{}", CliModel(fixture)); let actual = strip_ansi_codes(&formatted); - let expected = "basic-model"; + let expected = "basic-model [OpenAI]"; assert_eq!(actual, expected); } @@ -988,7 +994,7 @@ mod tests { let fixture = create_model_fixture("unknown-model", None, None); let formatted = format!("{}", CliModel(fixture)); let actual = strip_ansi_codes(&formatted); - let expected = "unknown-model"; + let expected = "unknown-model [OpenAI]"; assert_eq!(actual, expected); } @@ -997,7 +1003,7 @@ mod tests { let fixture = create_model_fixture("exact-k", Some(8000), Some(true)); let formatted = format!("{}", CliModel(fixture)); let actual = strip_ansi_codes(&formatted); - let expected = "exact-k [ 8k 🛠️ ]"; + let expected = "exact-k [ 8k 🛠️ ] [OpenAI]"; assert_eq!(actual, expected); } @@ -1006,7 +1012,7 @@ mod tests { let fixture = create_model_fixture("exact-m", Some(1000000), Some(true)); let formatted = format!("{}", CliModel(fixture)); let actual = strip_ansi_codes(&formatted); - let expected = "exact-m [ 1M 🛠️ ]"; + let expected = "exact-m [ 1M 🛠️ ] [OpenAI]"; assert_eq!(actual, expected); } @@ -1015,7 +1021,7 @@ mod tests { let fixture = create_model_fixture("edge-999", Some(999), None); let formatted = format!("{}", CliModel(fixture)); 
let actual = strip_ansi_codes(&formatted); - let expected = "edge-999 [ 999 ]"; + let expected = "edge-999 [ 999 ] [OpenAI]"; assert_eq!(actual, expected); } @@ -1024,7 +1030,7 @@ mod tests { let fixture = create_model_fixture("edge-1001", Some(1001), None); let formatted = format!("{}", CliModel(fixture)); let actual = strip_ansi_codes(&formatted); - let expected = "edge-1001 [ 1k ]"; + let expected = "edge-1001 [ 1k ] [OpenAI]"; assert_eq!(actual, expected); } diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 2e2910f252..60e3732b35 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -126,6 +126,15 @@ impl A + Send + Sync> UI { Ok(models) } + /// Retrieve available models (with caching and spinner when needed) + async fn get_models_with_spinner(&mut self) -> Result> { + // Show spinner while fetching (caching is handled internally) + self.spinner.start(Some("Loading models"))?; + let models = self.api.get_models().await?; + self.spinner.stop(None)?; + Ok(models) + } + /// Helper to get provider for an optional agent, defaulting to the current /// active agent's provider async fn get_provider(&self, agent_id: Option) -> Result> { @@ -347,8 +356,13 @@ impl A + Send + Sync> UI { // Improve startup time by hydrating caches fn hydrate_caches(&self) { + // Hydrate models from all providers let api = self.api.clone(); - tokio::spawn(async move { api.get_models().await }); + tokio::spawn(async move { + if let Err(e) = api.get_models().await { + tracing::warn!("Failed to hydrate multi-provider model cache: {}", e); + } + }); let api = self.api.clone(); tokio::spawn(async move { api.get_tools().await }); let api = self.api.clone(); @@ -1026,20 +1040,37 @@ impl A + Send + Sync> UI { /// Lists all the models async fn on_show_models(&mut self, porcelain: bool) -> anyhow::Result<()> { - let models = self.get_models().await?; + // Fetch models from all providers + let mut models = self.get_models_with_spinner().await?; if 
models.is_empty() { return Ok(()); } + // Sort by provider first, then by model name + models.sort_by(|a, b| { + let provider_cmp = a.provider_id.cmp(&b.provider_id); + if provider_cmp == std::cmp::Ordering::Equal { + a.id.to_string().cmp(&b.id.to_string()) + } else { + provider_cmp + } + }); + let mut info = Info::new(); for model in models.iter() { let id = model.id.to_string(); + let provider = model + .provider_id + .as_ref() + .map(|p| p.to_string()) + .unwrap_or_else(|| "unknown".to_string()); info = info .add_title(model.name.as_ref().unwrap_or(&id)) - .add_key_value("Id", id); + .add_key_value("Id", id) + .add_key_value("Provider", provider); // Add context length if available, otherwise use "unknown" if let Some(limit) = model.context_length { @@ -1735,10 +1766,10 @@ impl A + Send + Sync> UI { } /// Select a model from the available models - /// Returns Some(ModelId) if a model was selected, or None if selection was + /// Returns Some(Model) if a model was selected, or None if selection was /// canceled #[async_recursion::async_recursion] - async fn select_model(&mut self) -> Result> { + async fn select_model(&mut self) -> Result> { // Check if provider is set otherwise first ask to select a provider if self.api.get_default_provider().await.is_err() { self.on_provider_selection().await?; @@ -1751,7 +1782,7 @@ impl A + Send + Sync> UI { } } - // Fetch available models + // Fetch available models from all configured providers let mut models = self .get_models() .await? 
@@ -1759,8 +1790,23 @@ impl A + Send + Sync> UI { .map(CliModel) .collect::>(); - // Sort the models by their names in ascending order - models.sort_by(|a, b| a.0.name.cmp(&b.0.name)); + // Sort by provider first, then by model name + models.sort_by(|a, b| { + let provider_a = + a.0.provider_id + .as_ref() + .map(|p| p.to_string()) + .unwrap_or_default(); + let provider_b = + b.0.provider_id + .as_ref() + .map(|p| p.to_string()) + .unwrap_or_default(); + match provider_a.cmp(&provider_b) { + std::cmp::Ordering::Equal => a.0.name.cmp(&b.0.name), + other => other, + } + }); // Find the index of the current model let current_model = self @@ -1777,7 +1823,7 @@ impl A + Send + Sync> UI { .with_help_message("Type a name or use arrow keys to navigate and Enter to select") .prompt()? { - Some(model) => Ok(Some(model.0.id)), + Some(model) => Ok(Some(model.0)), None => Ok(None), } } @@ -2158,15 +2204,44 @@ impl A + Send + Sync> UI { None => return Ok(None), }; + // Check if we need to switch providers + let current_provider = self.api.get_default_provider().await.ok(); + let needs_provider_switch = model + .provider_id + .as_ref() + .and_then(|model_provider| current_provider.as_ref().map(|p| p.id != *model_provider)) + .unwrap_or(false); + + // Switch provider if needed + if needs_provider_switch && let Some(provider_id) = &model.provider_id { + self.api.set_default_provider(provider_id.clone()).await?; + } + // Update the operating model via API - self.api.set_default_model(model.clone()).await?; + self.api.set_default_model(model.id.clone()).await?; // Update the UI state with the new model - self.update_model(Some(model.clone())); + self.update_model(Some(model.id.clone())); - self.writeln_title(TitleFormat::action(format!("Switched to model: {model}")))?; + // Show appropriate message + if needs_provider_switch { + let provider_name = model + .provider_id + .as_ref() + .map(|p| p.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + 
self.writeln_title(TitleFormat::action(format!( + "Switched to model: {} on provider: {}", + model.id, provider_name + )))?; + } else { + self.writeln_title(TitleFormat::action(format!( + "Switched to model: {}", + model.id + )))?; + } - Ok(Some(model)) + Ok(Some(model.id)) } async fn on_provider_selection(&mut self) -> Result<()> { @@ -2775,10 +2850,49 @@ impl A + Send + Sync> UI { } ConfigField::Model => { let model_id = self.validate_model(&args.value).await?; + + // Get the full model object to check provider + let models = self.api.get_models().await?; + let model = models + .iter() + .find(|m| m.id == model_id) + .ok_or_else(|| anyhow::anyhow!("Model not found after validation"))?; + + // Check if we need to switch providers + let current_provider = self.api.get_default_provider().await.ok(); + let needs_provider_switch = model + .provider_id + .as_ref() + .and_then(|model_provider| { + current_provider.as_ref().map(|p| p.id != *model_provider) + }) + .unwrap_or(false); + + // Switch provider if needed + if needs_provider_switch && let Some(provider_id) = &model.provider_id { + self.api.set_default_provider(provider_id.clone()).await?; + } + + // Set the model self.api.set_default_model(model_id.clone()).await?; - self.writeln_title( - TitleFormat::action(model_id.as_str()).sub_title("is now the default model"), - )?; + + // Show appropriate message + if needs_provider_switch { + let provider_name = model + .provider_id + .as_ref() + .map(|p| p.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + self.writeln_title(TitleFormat::action(format!( + "Switched to model: {} on provider: {}", + model_id, provider_name + )))?; + } else { + self.writeln_title( + TitleFormat::action(model_id.as_str()) + .sub_title("is now the default model"), + )?; + } } } @@ -2856,7 +2970,7 @@ impl A + Send + Sync> UI { Some(rprompt.to_string()) } - /// Validate model exists + /// Validate model exists across all configured providers async fn validate_model(&self, model_str: 
&str) -> Result { let models = self.api.get_models().await?; let model_id = ModelId::new(model_str); diff --git a/crates/forge_services/src/app_config.rs b/crates/forge_services/src/app_config.rs index 545509c51d..95ccc18e27 100644 --- a/crates/forge_services/src/app_config.rs +++ b/crates/forge_services/src/app_config.rs @@ -138,6 +138,7 @@ mod tests { tools_supported: Some(true), supports_parallel_tool_calls: Some(true), supports_reasoning: Some(false), + provider_id: Some(ProviderId::OPENAI), }])), }, Provider { @@ -162,6 +163,7 @@ mod tests { tools_supported: Some(true), supports_parallel_tool_calls: Some(true), supports_reasoning: Some(true), + provider_id: Some(ProviderId::ANTHROPIC), }])), }, ], diff --git a/crates/forge_services/src/provider/anthropic.rs b/crates/forge_services/src/provider/anthropic.rs index f2d74c7ce9..d45ebf34bf 100644 --- a/crates/forge_services/src/provider/anthropic.rs +++ b/crates/forge_services/src/provider/anthropic.rs @@ -110,7 +110,10 @@ impl Anthropic { Ok(Box::pin(stream)) } - pub async fn models(&self) -> anyhow::Result> { + pub async fn models( + &self, + provider_id: forge_domain::ProviderId, + ) -> anyhow::Result> { match &self.models { forge_domain::ModelSource::Url(url) => { debug!(url = %url, "Fetching models"); @@ -134,7 +137,11 @@ impl Anthropic { let response: ListModelResponse = serde_json::from_str(&text) .with_context(|| ctx_msg) .with_context(|| "Failed to deserialize models response")?; - Ok(response.data.into_iter().map(Into::into).collect()) + Ok(response + .data + .into_iter() + .map(|m| m.into_domain_model(provider_id.clone())) + .collect()) } else { // treat non 200 response as error. 
Err(anyhow::anyhow!(text)) @@ -144,7 +151,12 @@ impl Anthropic { } forge_domain::ModelSource::Hardcoded(models) => { debug!("Using hardcoded models"); - Ok(models.clone()) + let mut models = models.clone(); + // Set provider_id on all hardcoded models + for model in &mut models { + model.provider_id = Some(provider_id.clone()); + } + Ok(models) } } } @@ -321,7 +333,9 @@ mod tests { .mock_models(create_mock_models_response(), 200) .await; let anthropic = create_anthropic(&fixture.url())?; - let actual = anthropic.models().await?; + let actual = anthropic + .models(forge_domain::ProviderId::ANTHROPIC) + .await?; mock.assert_async().await; @@ -339,7 +353,7 @@ mod tests { .await; let anthropic = create_anthropic(&fixture.url())?; - let actual = anthropic.models().await; + let actual = anthropic.models(forge_domain::ProviderId::ANTHROPIC).await; mock.assert_async().await; @@ -357,7 +371,7 @@ mod tests { .await; let anthropic = create_anthropic(&fixture.url())?; - let actual = anthropic.models().await; + let actual = anthropic.models(forge_domain::ProviderId::ANTHROPIC).await; mock.assert_async().await; @@ -374,7 +388,9 @@ mod tests { let mock = fixture.mock_models(create_empty_response(), 200).await; let anthropic = create_anthropic(&fixture.url())?; - let actual = anthropic.models().await?; + let actual = anthropic + .models(forge_domain::ProviderId::ANTHROPIC) + .await?; mock.assert_async().await; assert!(actual.is_empty()); diff --git a/crates/forge_services/src/provider/bedrock/provider.rs b/crates/forge_services/src/provider/bedrock/provider.rs index f006efdd54..0b7508048e 100644 --- a/crates/forge_services/src/provider/bedrock/provider.rs +++ b/crates/forge_services/src/provider/bedrock/provider.rs @@ -250,11 +250,18 @@ impl BedrockProvider { } /// Get available models - pub async fn models(&self) -> Result> { + pub async fn models(&self, provider_id: forge_domain::ProviderId) -> Result> { // Bedrock doesn't have a models list API // Return hardcoded models from 
configuration match &self.provider.models { - Some(forge_domain::ModelSource::Hardcoded(models)) => Ok(models.clone()), + Some(forge_domain::ModelSource::Hardcoded(models)) => { + let mut models = models.clone(); + // Set provider_id on all hardcoded models + for model in &mut models { + model.provider_id = Some(provider_id.clone()); + } + Ok(models) + } _ => Ok(vec![]), } } @@ -1216,6 +1223,7 @@ mod tests { tools_supported: None, supports_parallel_tool_calls: None, supports_reasoning: None, + provider_id: Some(forge_domain::ProviderId::BEDROCK), }, Model { id: ModelId::from("claude-3-sonnet".to_string()), @@ -1225,6 +1233,7 @@ mod tests { tools_supported: None, supports_parallel_tool_calls: None, supports_reasoning: None, + provider_id: Some(forge_domain::ProviderId::BEDROCK), }, ]; fixture_provider.models = Some(ModelSource::Hardcoded(fixture_models.clone())); @@ -1236,7 +1245,10 @@ mod tests { _phantom: std::marker::PhantomData::, }; - let actual = bedrock.models().await.unwrap(); + let actual = bedrock + .models(forge_domain::ProviderId::BEDROCK) + .await + .unwrap(); let expected = fixture_models; assert_eq!(actual, expected); } @@ -1251,7 +1263,10 @@ mod tests { _phantom: std::marker::PhantomData::, }; - let actual = bedrock.models().await.unwrap(); + let actual = bedrock + .models(forge_domain::ProviderId::BEDROCK) + .await + .unwrap(); let expected: Vec = vec![]; assert_eq!(actual, expected); } diff --git a/crates/forge_services/src/provider/client.rs b/crates/forge_services/src/provider/client.rs index 30a91b1fde..e93b739a54 100644 --- a/crates/forge_services/src/provider/client.rs +++ b/crates/forge_services/src/provider/client.rs @@ -105,6 +105,7 @@ impl ClientBuilder { inner: Arc::new(inner), retry_config, models_cache: Arc::new(RwLock::new(HashMap::new())), + provider_id: provider.id.clone(), }) } } @@ -113,6 +114,7 @@ pub struct Client { retry_config: Arc, inner: Arc>, models_cache: Arc>>, + provider_id: forge_domain::ProviderId, } impl Clone for 
Client { @@ -121,6 +123,7 @@ impl Clone for Client { retry_config: self.retry_config.clone(), inner: self.inner.clone(), models_cache: self.models_cache.clone(), + provider_id: self.provider_id.clone(), } } } @@ -138,10 +141,11 @@ impl Client { } pub async fn refresh_models(&self) -> anyhow::Result> { + let provider_id = self.provider_id.clone(); let models = self.clone().retry(match self.inner.as_ref() { - InnerClient::OpenAICompat(provider) => provider.models().await, - InnerClient::Anthropic(provider) => provider.models().await, - InnerClient::Bedrock(provider) => provider.models().await, + InnerClient::OpenAICompat(provider) => provider.models(provider_id).await, + InnerClient::Anthropic(provider) => provider.models(provider_id).await, + InnerClient::Bedrock(provider) => provider.models(provider_id).await, })?; // Update the cache with all fetched models diff --git a/crates/forge_services/src/provider/openai.rs b/crates/forge_services/src/provider/openai.rs index 5d2f2ffbc8..2de10f9062 100644 --- a/crates/forge_services/src/provider/openai.rs +++ b/crates/forge_services/src/provider/openai.rs @@ -120,11 +120,19 @@ impl OpenAIProvider { Ok(Box::pin(stream)) } - async fn inner_models(&self) -> Result> { + async fn inner_models( + &self, + provider_id: forge_domain::ProviderId, + ) -> Result> { // For Vertex AI, load models from static JSON file using VertexProvider logic if self.provider.id == ProviderId::VERTEX_AI { debug!("Loading Vertex AI models from static JSON file"); - Ok(self.inner_vertex_models()) + let mut models = self.inner_vertex_models(); + // Set provider_id on all models + for model in &mut models { + model.provider_id = Some(provider_id.clone()); + } + Ok(models) } else { let models = self .provider @@ -143,13 +151,22 @@ impl OpenAIProvider { let data: ListModelResponse = serde_json::from_str(&response) .with_context(|| format_http_context(None, "GET", url)) .with_context(|| "Failed to deserialize models response")?; - 
Ok(data.data.into_iter().map(Into::into).collect()) + Ok(data + .data + .into_iter() + .map(|m| m.into_domain_model(provider_id.clone())) + .collect()) } } } forge_domain::ModelSource::Hardcoded(models) => { debug!("Using hardcoded models"); - Ok(models.clone()) + let mut models = models.clone(); + // Set provider_id on all hardcoded models + for model in &mut models { + model.provider_id = Some(provider_id.clone()); + } + Ok(models) } } } @@ -207,8 +224,11 @@ impl OpenAIProvider { self.inner_chat(model, context).await } - pub async fn models(&self) -> Result> { - self.inner_models().await + pub async fn models( + &self, + provider_id: forge_domain::ProviderId, + ) -> Result> { + self.inner_models(provider_id).await } } @@ -408,7 +428,7 @@ mod tests { .mock_models(create_mock_models_response(), 200) .await; let provider = create_provider(&fixture.url())?; - let actual = provider.models().await?; + let actual = provider.models(ProviderId::OPENAI).await?; mock.assert_async().await; insta::assert_json_snapshot!(actual); @@ -423,7 +443,7 @@ mod tests { .await; let provider = create_provider(&fixture.url())?; - let actual = provider.models().await; + let actual = provider.models(ProviderId::OPENAI).await; mock.assert_async().await; @@ -441,7 +461,7 @@ mod tests { .await; let provider = create_provider(&fixture.url())?; - let actual = provider.models().await; + let actual = provider.models(ProviderId::OPENAI).await; mock.assert_async().await; @@ -457,7 +477,7 @@ mod tests { let mock = fixture.mock_models(create_empty_response(), 200).await; let provider = create_provider(&fixture.url())?; - let actual = provider.models().await?; + let actual = provider.models(ProviderId::OPENAI).await?; mock.assert_async().await; assert!(actual.is_empty()); @@ -491,7 +511,7 @@ mod tests { let mock = fixture.mock_models(detailed_error, 401).await; let provider = create_provider(&fixture.url())?; - let actual = provider.models().await; + let actual = 
provider.models(ProviderId::OPENAI).await; mock.assert_async().await; assert!(actual.is_err()); diff --git a/crates/forge_services/src/provider/service.rs b/crates/forge_services/src/provider/service.rs index df061cdd6c..a99ca1cff8 100644 --- a/crates/forge_services/src/provider/service.rs +++ b/crates/forge_services/src/provider/service.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; +use std::time::{Duration, SystemTime}; use anyhow::{Context, Result}; use forge_app::domain::{ @@ -13,11 +14,60 @@ use url::Url; use crate::http::HttpClient; use crate::provider::client::{Client, ClientBuilder}; + +/// Flat cache structure for all models +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +struct FlatModelCache { + models: Vec, + #[serde(with = "systemtime_serde")] + cached_at: SystemTime, +} + +/// Custom serialization for SystemTime +mod systemtime_serde { + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use serde::{Deserialize, Deserializer, Serializer}; + + pub fn serialize(time: &SystemTime, serializer: S) -> Result + where + S: Serializer, + { + let duration = time + .duration_since(UNIX_EPOCH) + .map_err(serde::ser::Error::custom)?; + serializer.serialize_u64(duration.as_secs()) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let secs = u64::deserialize(deserializer)?; + Ok(UNIX_EPOCH + Duration::from_secs(secs)) + } +} + +impl FlatModelCache { + fn new(models: Vec) -> Self { + Self { models, cached_at: SystemTime::now() } + } + + /// Check if cache has expired based on TTL + fn is_expired(&self, ttl: Duration) -> bool { + SystemTime::now() + .duration_since(self.cached_at) + .map(|age| age > ttl) + .unwrap_or(true) + } +} + #[derive(Clone)] pub struct ForgeProviderService { retry_config: Arc, cached_clients: Arc>>>>, - cached_models: Arc>>>, + flat_cache: Arc>>, + cache_ttl: Duration, version: String, timeout_config: HttpConfig, infra: Arc, @@ -27,16 +77,57 @@ impl 
ForgeProviderService { pub fn new(infra: Arc) -> Self { let env = infra.get_environment(); let version = env.version(); + let cache_ttl = Duration::from_secs(env.model_cache_ttl_seconds); + + // Load flat cache from disk if available + let flat_cache = Self::load_cache_from_disk(&env); + let retry_config = Arc::new(env.retry_config); + let timeout_config = env.http; + Self { retry_config, cached_clients: Arc::new(Mutex::new(HashMap::new())), - cached_models: Arc::new(Mutex::new(HashMap::new())), + flat_cache: Arc::new(Mutex::new(flat_cache)), + cache_ttl, version, - timeout_config: env.http, + timeout_config, infra, } } + + /// Create a new service with custom cache TTL + pub fn with_cache_ttl(infra: Arc, cache_ttl: Duration) -> Self { + let mut service = Self::new(infra); + service.cache_ttl = cache_ttl; + service + } + + /// Load flat cache from disk + fn load_cache_from_disk(env: &forge_domain::Environment) -> Option { + let cache_path = env.model_cache_path(); + if !cache_path.exists() { + return None; + } + + let content = std::fs::read_to_string(&cache_path).ok()?; + serde_json::from_str(&content).ok() + } + + /// Save flat cache to disk + fn save_cache_to_disk(env: &forge_domain::Environment, cache: &FlatModelCache) { + let cache_path = env.model_cache_path(); + + // Ensure cache directory exists + if let Some(parent) = cache_path.parent() { + let _ = std::fs::create_dir_all(parent); + } + + // Serialize and write cache + if let Ok(content) = serde_json::to_string_pretty(cache) { + let _ = std::fs::write(&cache_path, content); + } + } } impl ForgeProviderService { @@ -88,27 +179,9 @@ impl ProviderService } async fn models(&self, provider: Provider) -> Result> { - let provider_id = provider.id.clone(); - - // Check cache first - { - let models_guard = self.cached_models.lock().await; - if let Some(cached_models) = models_guard.get(&provider_id) { - return Ok(cached_models.clone()); - } - } - - // Models not in cache, fetch from client + // Fetch models from 
client (no per-provider caching) let client = self.client(provider).await?; - let models = client.models().await?; - - // Cache the models for this provider - { - let mut models_guard = self.cached_models.lock().await; - models_guard.insert(provider_id, models.clone()); - } - - Ok(models) + client.models().await } async fn get_provider(&self, id: ProviderId) -> Result> { @@ -132,6 +205,9 @@ impl ProviderService clients_guard.remove(&provider_id); } + // Invalidate flat cache since credentials changed + self.invalidate_caches().await; + Ok(()) } @@ -144,10 +220,42 @@ impl ProviderService clients_guard.remove(id); } + // Invalidate flat cache since credentials removed + self.invalidate_caches().await; + Ok(()) } async fn migrate_env_credentials(&self) -> Result> { self.infra.migrate_env_credentials().await } + + async fn cache_all_models(&self, models: Vec) { + let mut cache_guard = self.flat_cache.lock().await; + *cache_guard = Some(FlatModelCache::new(models)); + + // Save cache to disk + if let Some(cache) = cache_guard.as_ref() { + let env = self.infra.get_environment(); + Self::save_cache_to_disk(&env, cache); + } + } + + async fn get_cached_all_models(&self) -> Option> { + let cache_guard = self.flat_cache.lock().await; + cache_guard + .as_ref() + .filter(|cached| !cached.is_expired(self.cache_ttl)) + .map(|cached| cached.models.clone()) + } + + async fn invalidate_caches(&self) { + let mut cache_guard = self.flat_cache.lock().await; + *cache_guard = None; + + // Delete cache file from disk + let env = self.infra.get_environment(); + let cache_path = env.model_cache_path(); + let _ = std::fs::remove_file(cache_path); + } } diff --git a/crates/forge_services/src/provider/snapshots/forge_services__provider__anthropic__tests__fetch_models_success.snap b/crates/forge_services/src/provider/snapshots/forge_services__provider__anthropic__tests__fetch_models_success.snap index 8099e44d39..81685ae634 100644 --- 
a/crates/forge_services/src/provider/snapshots/forge_services__provider__anthropic__tests__fetch_models_success.snap +++ b/crates/forge_services/src/provider/snapshots/forge_services__provider__anthropic__tests__fetch_models_success.snap @@ -5,6 +5,7 @@ expression: actual [ { "id": "claude-3-5-sonnet-20241022", + "provider_id": "anthropic", "name": "Claude 3.5 Sonnet (New)", "description": null, "context_length": 200000, @@ -14,6 +15,7 @@ expression: actual }, { "id": "claude-3-5-haiku-20241022", + "provider_id": "anthropic", "name": "Claude 3.5 Haiku", "description": null, "context_length": 200000, diff --git a/crates/forge_services/src/provider/snapshots/forge_services__provider__openai__tests__fetch_models_success.snap b/crates/forge_services/src/provider/snapshots/forge_services__provider__openai__tests__fetch_models_success.snap index 3eed58c319..3340c7b4e2 100644 --- a/crates/forge_services/src/provider/snapshots/forge_services__provider__openai__tests__fetch_models_success.snap +++ b/crates/forge_services/src/provider/snapshots/forge_services__provider__openai__tests__fetch_models_success.snap @@ -5,6 +5,7 @@ expression: actual [ { "id": "model-1", + "provider_id": "openai", "name": "Test Model 1", "description": "A test model", "context_length": 4096, @@ -14,6 +15,7 @@ expression: actual }, { "id": "model-2", + "provider_id": "openai", "name": "Test Model 2", "description": "Another test model", "context_length": 8192, diff --git a/plans/2025-12-20-unified-model-provider-selection-v1.md b/plans/2025-12-20-unified-model-provider-selection-v1.md new file mode 100644 index 0000000000..06734c9add --- /dev/null +++ b/plans/2025-12-20-unified-model-provider-selection-v1.md @@ -0,0 +1,199 @@ +# Unified Model and Provider Selection with Caching + +## Objective + +Implement a unified model and provider selection interface that allows users to select both model and provider simultaneously. 
The system should fetch and cache models from all logged-in providers at once, display provider information alongside models, automatically set the provider when a model is selected, and maintain high performance through intelligent caching that works seamlessly in both zsh and REPL environments. + +## Implementation Plan + +- [x] 1. **Add provider_id field to Model domain structure** + - Modify the Model struct in `crates/forge_domain/src/model.rs:7` to include `pub provider_id: ProviderId` field + - This addresses the TODO comment at `crates/forge_domain/src/model.rs:12` about adding provider information to the model + - Update all Model constructors and tests to include provider_id + - Rationale: Models need to carry provider context through the entire system for unified selection and proper provider switching + +- [x] 2. **Update DTO to Domain conversions to preserve provider context** + - Modify the From implementation in `crates/forge_app/src/dto/openai/model.rs:91-114` to accept and set provider_id during conversion + - Update the From implementation in `crates/forge_app/src/dto/anthropic/response.rs:21-33` similarly + - Change the conversion functions to take provider_id as a parameter or use a builder pattern + - Rationale: Provider context is currently lost during DTO conversion, this ensures it flows through to the domain model + +- [x] 3. **Enhance provider service to attach provider_id when fetching models** + - Modify the models method in `crates/forge_services/src/provider/service.rs:90-112` to set provider_id on each returned Model + - Update the client's models method in `crates/forge_services/src/provider/client.rs` to accept and propagate provider_id + - Update the inner_models method in `crates/forge_services/src/provider/openai.rs:123-156` to set provider_id on fetched models + - Rationale: Ensures provider context is attached at the point where models are fetched from provider APIs + +- [x] 4. 
**Create new API method to fetch models from all configured providers** + - Add get_all_models method in `crates/forge_app/src/app.rs` that fetches models from all providers returned by get_all_providers + - Use parallel async fetching with tokio join_all or FuturesUnordered to fetch from all providers concurrently + - Filter to only include LLM providers using ProviderType filter similar to `crates/forge_main/src/ui.rs:2114` + - Handle errors gracefully per provider - if one provider fails, still return models from others + - Return a flattened Vec of all models with provider_id populated + - Rationale: Current get_models only fetches from default provider, we need models from all logged-in providers for unified selection + +- [x] 5. **Add API endpoint for get_all_models in forge_api layer** + - Create get_all_models method in `crates/forge_api/src/api.rs:24` trait + - Implement it in `crates/forge_api/src/forge_api.rs` to call the app layer's get_all_models + - Ensure proper error handling and result mapping + - Rationale: Maintains clean architecture by exposing the new functionality through the API layer + +- [x] 6. **Enhance cached_models to support multi-provider aggregation** + - Consider renaming or adding a new cache field in `crates/forge_services/src/provider/service.rs:20` for aggregated models + - Add a new cache entry that stores all models across providers with a special key or separate field + - Implement cache invalidation logic - when any provider's credentials change, clear the aggregated cache + - Add timestamp to cache entries to implement TTL (time-to-live) for auto-expiration + - **IMPLEMENTED**: Added file-based persistent caching to `~/forge/cache/models.json` for zsh compatibility + - Rationale: Current cache is per-provider, we need efficient caching for the aggregated multi-provider model list + +- [x] 7. 
**Implement cache TTL and refresh strategy** + - Add a timestamp field to cached entries in the cached_models HashMap + - Implement a configurable TTL (e.g., 1 hour) that can be set via environment or config + - Add a check before returning cached models to verify they haven't exceeded TTL + - If TTL expired, trigger background refresh while returning stale data, or block and fetch fresh data based on configuration + - Rationale: Prevents showing outdated model lists when providers add new models, balancing freshness with performance + +- [x] 8. **Update CliModel display to show provider information** + - Modify the Display implementation in `crates/forge_main/src/model.rs:18-51` to include provider name or identifier + - Format could be: "gpt-4 [ 128k 🛠️ ] [OpenAI]" or use provider as a prefix: "[OpenAI] gpt-4 [ 128k 🛠️ ]" + - Retrieve provider name from the model's provider_id field + - Consider adding a helper method to format provider display name + - Rationale: Users need to see which provider each model belongs to for informed selection + +- [x] 9. **Create new get_all_models method in UI layer** + - Add get_all_models method in `crates/forge_main/src/ui.rs` similar to the existing get_models at line 122-127 + - Call the new API endpoint api.get_all_models + - Handle the spinner for loading indication + - Cache the results locally in the UI for quick subsequent access during the same session + - Rationale: Provides UI-level access to multi-provider model list with proper loading UX + +- [x] 10. 
**Update select_model to use multi-provider model list** + - Modify the select_model method in `crates/forge_main/src/ui.rs:1741-1783` to call get_all_models instead of get_models + - Update sorting logic to sort by provider first, then by model name, or add a filter option + - Consider adding a provider group header in the selection UI to visually separate models by provider + - Ensure starting_cursor logic still works correctly with the expanded model list + - Rationale: Enables unified model selection showing all available models across providers + +- [x] 11. **Implement automatic provider switching on model selection** + - After user selects a model in on_model_selection method at `crates/forge_main/src/ui.rs:2151-2170`, extract the provider_id from the selected model + - Check if the selected model's provider differs from the current default provider + - If different, call api.set_default_provider with the model's provider_id before calling api.set_default_model + - Update UI state to reflect both the new model and provider + - Display a message like "Switched to model: {model} on provider: {provider}" + - Rationale: Seamless user experience - selecting a model from a different provider automatically switches to that provider + +- [x] 12. **Update model list command to show provider column** + - Modify on_show_models method in `crates/forge_main/src/ui.rs:1028-1076` to fetch all models from all providers + - Add a provider column to the Info output using add_key_value for the provider name + - For porcelain format, add provider as a column in the table + - Sort output by provider first, then by model name for better readability + - Rationale: Provides clear visibility of which models belong to which provider in list commands + +- [x] 13. 
**Optimize cache hydration for multi-provider scenario** + - Update hydrate_caches method in `crates/forge_main/src/ui.rs:348-360` to spawn a task for get_all_models instead of get_models + - Consider adding a priority system - fetch models from default provider first, then others + - Add error logging if multi-provider fetch fails, but don't block app startup + - Rationale: Pre-warms the aggregated cache on startup for instant model selection + +- [x] 14. **Add configuration option for model list caching behavior** + - Add a new config field in forge.yaml schema for model_cache_ttl_seconds + - Add a config field for model_cache_strategy with options like "aggressive" (cache indefinitely), "moderate" (1 hour TTL), "fresh" (always fetch) + - Update the config loading in `crates/forge_app/src/app.rs` to read and apply these settings + - Use these settings in the provider service caching logic + - Rationale: Gives users control over cache behavior based on their needs and network conditions + +- [ ] 15. **Implement ForgeSelect enhancement for provider grouping** + - Consider enhancing ForgeSelect in `crates/forge_select/src/select.rs` to support grouped items with section headers + - If grouped selection is complex, alternatively pre-format the model list with provider headers as non-selectable items + - Or use a simpler approach with provider prefix in the display string + - Rationale: Improves UX by visually organizing models by provider in the selection interface + +- [ ] 16. **Add cache statistics and diagnostics** + - Add methods to report cache statistics like hit/miss rates, number of cached entries, cache memory usage + - Consider adding a diagnostic command like "forge debug cache" to show cache status + - Log cache hits and misses at debug level for troubleshooting + - Rationale: Helps monitor cache effectiveness and diagnose performance issues + +- [ ] 17. 
**Update tests to handle provider_id in Model** + - Update all model fixtures in tests to include provider_id field + - Add tests for multi-provider model fetching scenarios + - Test cache invalidation when credentials change + - Test TTL expiration and refresh logic + - Test automatic provider switching when selecting models from different providers + - Rationale: Ensures the new functionality works correctly and prevents regressions + +- [ ] 18. **Handle migration for existing stored models** + - Check if models are persisted in the database in `crates/forge_repo` + - If yes, create a migration to add provider_id column to the models table + - Implement backward compatibility for models without provider_id + - Consider a data migration script to populate provider_id for existing models based on current provider configuration + - Rationale: Ensures existing installations can upgrade smoothly without data loss + +- [ ] 19. **Update documentation for new model selection behavior** + - Update user-facing documentation to explain the unified model and provider selection + - Document the new cache configuration options + - Add examples of how provider is automatically set when selecting a model + - Document the cache TTL behavior and how to customize it + - Rationale: Users need to understand the new behavior and configuration options + +## Verification Criteria + +- Model struct includes provider_id field and all conversions preserve this information +- Calling get_all_models returns models from all configured LLM providers with provider_id populated +- Models are cached per provider and also in an aggregated cache for multi-provider access +- Cache respects configured TTL and refreshes expired entries appropriately +- Model selection UI displays provider information alongside each model +- Selecting a model from provider A while provider B is active automatically switches to provider A +- The forge list model command shows provider column in output +- Cache is 
pre-warmed on startup without blocking user interaction +- Configuration options for cache TTL and strategy are respected +- Tests pass including new tests for multi-provider scenarios +- Performance is maintained or improved - model selection is fast due to caching +- Works correctly in both zsh completion scenarios and interactive REPL mode +- Cache invalidation works correctly when credentials are added, updated, or removed + +## Potential Risks and Mitigations + +1. **Performance degradation when fetching from many providers** + Mitigation: Use parallel async fetching with tokio join_all to fetch from all providers concurrently. Add timeout configuration to prevent slow providers from blocking the entire operation. Implement aggressive caching with reasonable TTL. + +2. **Memory usage increase with multi-provider caching** + Mitigation: Implement cache size limits and LRU eviction if memory becomes a concern. Monitor cache statistics. Make TTL configurable so users can tune based on their environment. Consider storing Arc references instead of cloning model data. + +3. **Breaking changes to Model struct affecting existing code** + Mitigation: Carefully review all usages of Model struct across the codebase. Update all constructors, builders, and tests. Use derive_setters to provide flexible construction. Implement database migration for any persisted models. + +4. **Cache invalidation complexity with multiple providers** + Mitigation: Clear aggregated cache whenever any provider's credentials change. Keep invalidation logic simple and conservative - when in doubt, invalidate. Log cache operations at debug level for troubleshooting. + +5. **UI becoming cluttered with too many models** + Mitigation: Implement smart sorting and grouping by provider. Consider adding filtering options in the future. Use ForgeSelect's search functionality to help users find models quickly. Show provider headers or prefixes to organize the list. + +6. 
**Race conditions in concurrent cache access** + Mitigation: Continue using Mutex for service-level caches to ensure thread-safety. Consider upgrading to RwLock if read-heavy workloads show contention. Ensure proper lock ordering to prevent deadlocks. + +7. **Inconsistent provider state when auto-switching** + Mitigation: Update both model and provider atomically in the API layer. Ensure UI state is updated consistently. Add clear user feedback when provider is automatically switched. Consider adding a confirmation prompt if desired. + +8. **Backward compatibility with existing configurations** + Mitigation: Make provider_id optional (Option) initially if needed for migration. Provide sensible defaults for new config options. Test upgrade path from previous version. Document migration steps clearly. + +## Alternative Approaches + +1. **Separate model and provider selection instead of unified** + Trade-offs: Simpler implementation, less change to existing code, but requires two-step selection process and doesn't solve the core UX issue. Unified selection is more intuitive. + +2. **Fetch models on-demand per provider instead of all at once** + Trade-offs: Lower initial memory footprint, but slower UX when browsing models. Caching strategy becomes more complex. The all-at-once approach is better for interactive selection. + +3. **Use a separate ModelWithProvider struct instead of adding field to Model** + Trade-offs: Avoids changing core Model domain, but creates inconsistency and requires more wrapper types. Adding field to Model is cleaner and more maintainable. + +4. **Implement provider as a property of ModelId instead of Model** + Trade-offs: Could encode provider in model identifier (e.g., "openai:gpt-4"), but this conflates identity with provider relationship. Separate field is more flexible and clearer. + +5. **Create a federated model registry service** + Trade-offs: More sophisticated architecture with a dedicated service managing models across providers. 
Higher complexity but better scalability. Overkill for current needs but could be a future enhancement. + +6. **Use a database for model caching instead of in-memory** + Trade-offs: A persistent cache survives restarts, but adds database overhead and complexity. The current in-memory approach with smart hydration is faster for interactive use. A hybrid approach could be adopted in the future.