Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion crates/forge_api/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ pub trait API: Sync + Send {
/// environment
async fn get_tools(&self) -> anyhow::Result<ToolsOverview>;

/// Provides a list of models available in the current environment
/// Provides a list of models available in the current environment (with
/// caching)
async fn get_models(&self) -> Result<Vec<Model>>;

/// Provides a list of agents available in the current environment
async fn get_agents(&self) -> Result<Vec<Agent>>;
/// Provides a list of providers available in the current environment
Expand Down
1 change: 1 addition & 0 deletions crates/forge_api/src/forge_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ impl<
async fn get_models(&self) -> Result<Vec<Model>> {
self.app().get_models().await
}

/// Lists the available agents by forwarding the call to the services layer.
async fn get_agents(&self) -> Result<Vec<Agent>> {
self.services.get_agents().await
}
Expand Down
84 changes: 74 additions & 10 deletions crates/forge_app/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,18 +253,82 @@ impl<S: Services> ForgeApp<S> {
self.tool_registry.tools_overview().await
}

/// Gets available models from all configured LLM providers, with automatic
/// credential refresh and caching.
///
/// Returns the cached aggregate when one is available. Otherwise every
/// configured LLM provider is queried concurrently and the results are
/// concatenated in provider order. A provider whose credential refresh or
/// model fetch fails is logged and skipped, so one broken provider does not
/// hide the models of the others.
///
/// The aggregate is cached only when every provider succeeded. A partial
/// result is still returned, but it is not cached, so a transient failure is
/// retried on the next call instead of being served until the cache expires.
///
/// # Errors
///
/// Returns an error only if the list of providers cannot be loaded.
pub async fn get_models(&self) -> Result<Vec<Model>> {
    use forge_domain::ProviderType;
    use futures::future::join_all;

    // Check cache first
    if let Some(cached_models) = self.services.get_cached_all_models().await {
        return Ok(cached_models);
    }

    // One fetch future per provider that is both an LLM provider and configured.
    let fetches = self
        .services
        .get_all_providers()
        .await?
        .into_iter()
        .filter(|p| {
            let is_llm = match p {
                forge_domain::AnyProvider::Url(provider) => {
                    provider.provider_type == ProviderType::Llm
                }
                forge_domain::AnyProvider::Template(provider) => {
                    provider.provider_type == ProviderType::Llm
                }
            };
            is_llm && p.is_configured()
        })
        .filter_map(|p| p.into_configured())
        .map(|provider| {
            let services = self.services.clone();
            async move {
                // Captured before the refresh so that a refresh failure can be
                // attributed to its provider too.
                // NOTE(review): assumes refreshing credentials does not change
                // the provider id — confirm.
                let provider_id = provider.id.clone();

                let fetched = async {
                    let provider = services
                        .provider_auth_service()
                        .refresh_provider_credential(provider)
                        .await?;
                    services.models(provider).await
                }
                .await;

                // Single logging point: covers both refresh and fetch errors.
                if let Err(error) = &fetched {
                    tracing::warn!(
                        provider_id = %provider_id,
                        error = %error,
                        "Failed to fetch models from provider"
                    );
                }
                fetched
            }
        });

    // Runs all fetches concurrently; results come back in provider order.
    let results = join_all(fetches).await;

    let every_provider_succeeded = results.iter().all(|result| result.is_ok());
    let all_models: Vec<Model> = results
        .into_iter()
        .filter_map(|result| result.ok())
        .flatten()
        .collect();

    // Never cache a partial aggregate: it would mask the failed providers'
    // models until the cache entry expires.
    if every_provider_succeeded {
        self.services.cache_all_models(all_models.clone()).await;
    }

    Ok(all_models)
}
pub async fn login(&self, init_auth: &InitAuth) -> Result<()> {
self.authenticator.login(init_auth).await
Expand Down
8 changes: 8 additions & 0 deletions crates/forge_app/src/command_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,14 @@ mod tests {
async fn migrate_env_credentials(&self) -> anyhow::Result<Option<MigrationResult>> {
Ok(None)
}

// Test stub: the models are discarded, nothing is cached.
async fn cache_all_models(&self, _models: Vec<forge_domain::Model>) {}

// Test stub: always reports a cache miss.
async fn get_cached_all_models(&self) -> Option<Vec<forge_domain::Model>> {
None
}

// Test stub: there is no cache to invalidate, so this does nothing.
async fn invalidate_caches(&self) {}
}

#[async_trait::async_trait]
Expand Down
24 changes: 14 additions & 10 deletions crates/forge_app/src/dto/anthropic/response.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use forge_domain::{
ChatCompletionMessage, Content, ModelId, Reasoning, ReasoningPart, TokenCount, ToolCallId,
ToolCallPart, ToolName,
ChatCompletionMessage, Content, ModelId, ProviderId, Reasoning, ReasoningPart, TokenCount,
ToolCallId, ToolCallPart, ToolName,
};
use serde::Deserialize;

Expand All @@ -18,12 +18,14 @@ pub struct Model {
pub display_name: String,
}

impl From<Model> for forge_domain::Model {
fn from(value: Model) -> Self {
let context_length = get_context_length(&value.id);
Self {
id: ModelId::new(value.id),
name: Some(value.display_name),
impl Model {
/// Converts this DTO model to a domain model with the specified provider_id
pub fn into_domain_model(self, provider_id: ProviderId) -> forge_domain::Model {
let context_length = get_context_length(&self.id);
forge_domain::Model {
id: ModelId::new(self.id),
provider_id: Some(provider_id),
name: Some(self.display_name),
description: None,
context_length,
tools_supported: Some(true),
Expand Down Expand Up @@ -661,7 +663,8 @@ mod tests {
display_name: "Claude 3.5 Sonnet (New)".to_string(),
};

let actual: forge_domain::Model = fixture.into();
let actual: forge_domain::Model =
fixture.into_domain_model(forge_domain::ProviderId::ANTHROPIC);

assert_eq!(actual.context_length, Some(200_000));
assert_eq!(actual.id.as_str(), "claude-sonnet-4-5-20250929");
Expand All @@ -675,7 +678,8 @@ mod tests {
display_name: "Unknown Model".to_string(),
};

let actual: forge_domain::Model = fixture.into();
let actual: forge_domain::Model =
fixture.into_domain_model(forge_domain::ProviderId::ANTHROPIC);

assert_eq!(actual.context_length, None);
assert_eq!(actual.id.as_str(), "unknown-claude-model");
Expand Down
23 changes: 12 additions & 11 deletions crates/forge_app/src/dto/openai/model.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use forge_domain::ModelId;
use forge_domain::{ModelId, ProviderId};
use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize, Clone)]
Expand Down Expand Up @@ -88,11 +88,11 @@ pub struct ListModelResponse {
pub data: Vec<Model>,
}

impl From<Model> for forge_domain::Model {
fn from(value: Model) -> Self {
impl Model {
/// Converts this DTO model to a domain model with the specified provider_id
pub fn into_domain_model(self, provider_id: ProviderId) -> forge_domain::Model {
let has_param = |name: &str| {
value
.supported_parameters
self.supported_parameters
.as_ref()
.map(|params| params.iter().any(|p| p == name))
};
Expand All @@ -102,10 +102,11 @@ impl From<Model> for forge_domain::Model {
let supports_reasoning = has_param("reasoning");

forge_domain::Model {
id: value.id,
name: value.name,
description: value.description,
context_length: value.context_length,
id: self.id,
provider_id: Some(provider_id),
name: self.name,
description: self.description,
context_length: self.context_length,
tools_supported,
supports_parallel_tool_calls,
supports_reasoning,
Expand Down Expand Up @@ -263,7 +264,7 @@ mod tests {
supported_parameters: None, // No supported_parameters field
};

let domain_model: forge_domain::Model = model.into();
let domain_model = model.into_domain_model(forge_domain::ProviderId::OPENAI);

// When supported_parameters is None, capabilities should be None (unknown)
assert_eq!(domain_model.tools_supported, None);
Expand All @@ -290,7 +291,7 @@ mod tests {
]),
};

let domain_model: forge_domain::Model = model.into();
let domain_model = model.into_domain_model(forge_domain::ProviderId::OPENAI);

// Should reflect what's actually in supported_parameters
assert_eq!(domain_model.tools_supported, Some(true));
Expand Down
1 change: 1 addition & 0 deletions crates/forge_app/src/orch_spec/orch_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ impl Default for TestContext {
override_model: None,
override_provider: None,
enable_permissions: false,
model_cache_ttl_seconds: 3600,
},
title: Some("test-conversation".into()),
agent: Agent::new(
Expand Down
21 changes: 21 additions & 0 deletions crates/forge_app/src/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,15 @@ pub trait ProviderService: Send + Sync {
async fn migrate_env_credentials(
&self,
) -> anyhow::Result<Option<forge_domain::MigrationResult>>;

/// Cache aggregated models from all providers
async fn cache_all_models(&self, models: Vec<Model>);

/// Get cached aggregated models if available and not expired
async fn get_cached_all_models(&self) -> Option<Vec<Model>>;

/// Invalidate all caches (per-provider and aggregated)
async fn invalidate_caches(&self);
}

/// Manages user preferences for default providers and models.
Expand Down Expand Up @@ -656,6 +665,18 @@ impl<I: Services> ProviderService for I {
) -> anyhow::Result<Option<forge_domain::MigrationResult>> {
self.provider_service().migrate_env_credentials().await
}

// Forwards the aggregated model list to the underlying provider service.
async fn cache_all_models(&self, models: Vec<Model>) {
self.provider_service().cache_all_models(models).await
}

// Forwards the cache lookup to the underlying provider service and returns
// whatever it reports.
async fn get_cached_all_models(&self) -> Option<Vec<Model>> {
self.provider_service().get_cached_all_models().await
}

// Forwards the invalidation request to the underlying provider service.
async fn invalidate_caches(&self) {
self.provider_service().invalidate_caches().await
}
}

#[async_trait::async_trait]
Expand Down
11 changes: 11 additions & 0 deletions crates/forge_domain/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ pub struct Environment {
/// Controlled by FORGE_ENABLE_PERMISSIONS environment variable.
/// When enabled, tools will check policies before execution.
pub enable_permissions: bool,
/// Cache TTL in seconds for model lists.
/// Controlled by FORGE_MODEL_CACHE_TTL environment variable.
/// Defaults to 3600 seconds (1 hour).
pub model_cache_ttl_seconds: u64,
}

impl Environment {
Expand Down Expand Up @@ -154,6 +158,10 @@ impl Environment {
pub fn cache_dir(&self) -> PathBuf {
self.base_path.join("cache")
}
/// Location of the model cache file: `models.json` inside the directory
/// returned by `cache_dir`.
pub fn model_cache_path(&self) -> PathBuf {
    let mut path = self.cache_dir();
    path.push("models.json");
    path
}

/// Returns the global skills directory path (~/forge/skills)
pub fn global_skills_path(&self) -> PathBuf {
Expand Down Expand Up @@ -303,6 +311,7 @@ fn test_command_path() {
override_model: None,
override_provider: None,
enable_permissions: false,
model_cache_ttl_seconds: 3600,
};

let actual = fixture.command_path();
Expand Down Expand Up @@ -343,6 +352,7 @@ fn test_command_cwd_path() {
override_model: None,
override_provider: None,
enable_permissions: false,
model_cache_ttl_seconds: 3600,
};

let actual = fixture.command_cwd_path();
Expand Down Expand Up @@ -383,6 +393,7 @@ fn test_command_cwd_path_independent_from_command_path() {
override_model: None,
override_provider: None,
enable_permissions: false,
model_cache_ttl_seconds: 3600,
};

let command_path = fixture.command_path();
Expand Down
5 changes: 4 additions & 1 deletion crates/forge_domain/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ use derive_setters::Setters;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use crate::ProviderId;

#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, Setters)]
pub struct Model {
pub id: ModelId,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub provider_id: Option<ProviderId>,
pub name: Option<String>,
pub description: Option<String>,
pub context_length: Option<u64>,
// TODO: add provider information to the model
pub tools_supported: Option<bool>,
/// Whether the model supports parallel tool calls
pub supports_parallel_tool_calls: Option<bool>,
Expand Down
2 changes: 2 additions & 0 deletions crates/forge_infra/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ impl ForgeEnvironmentInfra {
.and_then(|s| ProviderId::from_str(&s).ok());
let enable_permissions =
parse_env::<bool>("FORGE_ENABLE_PERMISSIONS").unwrap_or(cfg!(debug_assertions));
let model_cache_ttl_seconds = parse_env::<u64>("FORGE_MODEL_CACHE_TTL").unwrap_or(3600);

Environment {
os: std::env::consts::OS.to_string(),
Expand Down Expand Up @@ -97,6 +98,7 @@ impl ForgeEnvironmentInfra {
override_model,
override_provider,
enable_permissions,
model_cache_ttl_seconds,
}
}

Expand Down
Loading
Loading