diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e70feff3..6719a2630 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1638,6 +1638,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **Google Gemini OAuth provider** (`gemini-oauth`): Added browser-based Authorization Code + PKCE authentication using Google accounts + - Integrates with provider setup UI and token storage + - Supports model discovery from Gemini OAuth credentials with fallback catalog + - API usage is billed to the authenticated user's Google account - **Voice Provider Management UI**: Configure TTS and STT providers from Settings > Voice - Auto-detection of API keys from environment variables and LLM provider configs - Toggle switches to enable/disable providers without removing configuration diff --git a/Cargo.toml b/Cargo.toml index 3d315bf16..0aff31a65 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -309,7 +309,7 @@ moltis-plugins = { path = "crates/plugins" } moltis-projects = { path = "crates/projects" } moltis-protocol = { path = "crates/protocol" } moltis-provider-setup = { path = "crates/provider-setup" } -moltis-providers = { features = ["provider-github-copilot", "provider-kimi-code", "provider-openai-codex"], path = "crates/providers" } +moltis-providers = { features = ["provider-gemini-oauth", "provider-github-copilot", "provider-kimi-code", "provider-openai-codex"], path = "crates/providers" } moltis-qmd = { path = "crates/qmd" } moltis-routing = { path = "crates/routing" } moltis-schema-export = { path = "crates/schema-export" } diff --git a/crates/config/src/schema.rs b/crates/config/src/schema.rs index 24df83d78..4d3b7b211 100644 --- a/crates/config/src/schema.rs +++ b/crates/config/src/schema.rs @@ -2053,7 +2053,10 @@ pub struct ProvidersConfig { pub offered: Vec, /// Provider-specific settings keyed by provider name. - /// Known keys: "anthropic", "openai", "gemini", "groq", "xai", "deepseek" + /// Known keys include "anthropic", "openai", "gemini", "gemini-oauth", + /// "groq", "xai", "deepseek", "mistral", "openrouter", "cerebras", + /// "minimax", "moonshot", "venice", "ollama", "lmstudio", + /// "openai-codex", "github-copilot", and "kimi-code". #[serde(flatten)] pub providers: HashMap, diff --git a/crates/config/src/template.rs b/crates/config/src/template.rs index 3313bcba2..65167512d 100644 --- a/crates/config/src/template.rs +++ b/crates/config/src/template.rs @@ -96,7 +96,7 @@ offered = ["local-llm", "github-copilot", "openai-codex", "openai", "anthropic", # All available providers: # "anthropic", "openai", "gemini", "groq", "xai", "deepseek", # "mistral", "openrouter", "cerebras", "minimax", "moonshot", -# "zai", "venice", "ollama", "local-llm", "openai-codex", +# "zai", "venice", "ollama", "local-llm", "gemini-oauth", "openai-codex", # "github-copilot", "kimi-code" # ── Anthropic (Claude) ──────────────────────────────────────── diff --git a/crates/config/src/validate.rs b/crates/config/src/validate.rs index cf57ff326..3476cd249 100644 --- a/crates/config/src/validate.rs +++ b/crates/config/src/validate.rs @@ -88,6 +88,7 @@ const KNOWN_PROVIDER_NAMES: &[&str] = &[ "anthropic", "openai", "gemini", + "gemini-oauth", "groq", "xai", "deepseek", @@ -99,6 +100,9 @@ const KNOWN_PROVIDER_NAMES: &[&str] = &[ "venice", "ollama", "lmstudio", + "openai-codex", + "github-copilot", + "kimi-code", ]; /// Static metadata keys allowed directly under `[providers]`. diff --git a/crates/oauth/src/defaults.rs b/crates/oauth/src/defaults.rs index f706b5214..da76d4642 100644 --- a/crates/oauth/src/defaults.rs +++ b/crates/oauth/src/defaults.rs @@ -46,6 +46,29 @@ fn builtin_defaults() -> HashMap { ], device_flow: false, }); + // Google Gemini uses Authorization Code + PKCE flow. + // Users authenticate with their Google account; API usage is billed to their account. + // The client_id is a public identifier (not secret) for the Moltis application. + m.insert("gemini-oauth".into(), OAuthConfig { + // TODO: Replace with actual Moltis client ID from Google Cloud Console + client_id: "MOLTIS_GEMINI_CLIENT_ID".into(), + auth_url: "https://accounts.google.com/o/oauth2/v2/auth".into(), + token_url: "https://oauth2.googleapis.com/token".into(), + redirect_uri: "http://localhost:1456/auth/callback".into(), + resource: None, + scopes: vec![ + // Scope for Gemini API access + "https://www.googleapis.com/auth/generative-language.retriever".into(), + "https://www.googleapis.com/auth/cloud-platform".into(), + ], + extra_auth_params: vec![ + // Request offline access to get a refresh token + ("access_type".into(), "offline".into()), + // Force consent screen to always show (ensures refresh token) + ("prompt".into(), "consent".into()), + ], + device_flow: false, + }); m } @@ -156,4 +179,17 @@ mod tests { let config = load_oauth_config("openai-codex").unwrap(); assert_eq!(callback_port(&config), 1455); } + + #[test] + fn load_gemini_oauth_config() { + let config = load_oauth_config("gemini-oauth").expect("should have gemini-oauth"); + assert!(!config.device_flow); + assert!(!config.redirect_uri.is_empty()); + assert_eq!( + config.auth_url, + "https://accounts.google.com/o/oauth2/v2/auth" + ); + assert_eq!(config.token_url, "https://oauth2.googleapis.com/token"); + assert_eq!(callback_port(&config), 1456); + } } diff --git a/crates/provider-setup/src/lib.rs b/crates/provider-setup/src/lib.rs index 749c29d63..8015e2b41 100644 --- a/crates/provider-setup/src/lib.rs +++ b/crates/provider-setup/src/lib.rs @@ -762,6 +762,15 @@ pub fn known_providers() -> Vec { requires_model: false, key_optional: false, }, + KnownProvider { + name: "gemini-oauth", + display_name: "Google Gemini (OAuth)", + auth_type: AuthType::Oauth, + env_key: None, + default_base_url: None, + requires_model: false, + key_optional: false, + }, KnownProvider { name: "groq", display_name: "Groq", @@ -3720,6 +3729,7 @@ mod tests { assert!(names.contains(&"venice"), "missing venice"); assert!(names.contains(&"ollama"), "missing ollama"); // OAuth providers + assert!(names.contains(&"gemini-oauth"), "missing gemini-oauth"); assert!(names.contains(&"github-copilot"), "missing github-copilot"); } @@ -3734,6 +3744,17 @@ mod tests { assert!(copilot.env_key.is_none()); } + #[test] + fn gemini_oauth_is_oauth_provider() { + let providers = known_providers(); + let gemini_oauth = providers + .iter() + .find(|p| p.name == "gemini-oauth") + .expect("gemini-oauth not in known_providers"); + assert_eq!(gemini_oauth.auth_type, AuthType::Oauth); + assert!(gemini_oauth.env_key.is_none()); + } + #[test] fn new_api_key_providers_have_correct_env_keys() { let expected = [ diff --git a/crates/providers/Cargo.toml b/crates/providers/Cargo.toml index 6742cdadc..3dda50632 100644 --- a/crates/providers/Cargo.toml +++ b/crates/providers/Cargo.toml @@ -12,6 +12,7 @@ local-llm-metal = ["llama-cpp-2/metal", "local-llm"] metrics = ["dep:moltis-metrics"] provider-async-openai = ["dep:async-openai"] provider-genai = ["dep:genai"] +provider-gemini-oauth = ["dep:moltis-oauth"] provider-github-copilot = ["dep:moltis-oauth"] provider-kimi-code = ["dep:moltis-oauth"] provider-openai-codex = ["dep:base64", "dep:moltis-oauth"] diff --git a/crates/providers/src/gemini_oauth.rs b/crates/providers/src/gemini_oauth.rs new file mode 100644 index 000000000..bef2dc63b --- /dev/null +++ b/crates/providers/src/gemini_oauth.rs @@ -0,0 +1,1080 @@ +//! Google Gemini OAuth provider. +//! +//! Authentication uses Authorization Code Flow with PKCE to obtain an access token, +//! which is then used to call the Gemini API. Users authenticate with their Google +//! account, and API usage is billed to their account (not to the application developer). +//! +//! The OAuth flow: +//! 1. Open browser to Google OAuth consent screen +//! 2. User authenticates with their Google account +//! 3. Browser redirects to local callback server with authorization code +//! 4. Exchange code for tokens using PKCE verifier +//! 5. Store tokens securely for future use + +use std::{pin::Pin, sync::mpsc, time::Duration}; + +use { + async_trait::async_trait, + futures::StreamExt, + moltis_agents::model::{ + ChatMessage, CompletionResponse, ContentPart, LlmProvider, StreamEvent, ToolCall, Usage, + UserContent, + }, + moltis_oauth::{ + CallbackServer, OAuthConfig, OAuthFlow, OAuthTokens, TokenStore, callback_port, + load_oauth_config, + }, + secrecy::ExposeSecret, + tokio_stream::Stream, + tracing::{debug, info, trace, warn}, +}; + +// ── Constants ──────────────────────────────────────────────────────────────── + +const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com"; +const PROVIDER_NAME: &str = "gemini-oauth"; + +/// Buffer before token expiry to trigger refresh (5 minutes). +const REFRESH_THRESHOLD_SECS: u64 = 300; + +/// Information about a Gemini model returned from the API. +#[derive(Debug, Clone, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GeminiModelInfo { + /// Full resource name (e.g., "models/gemini-2.0-flash") + pub name: String, + /// Human-readable display name + #[serde(default)] + pub display_name: String, + /// Maximum input tokens (context window) + #[serde(default)] + pub input_token_limit: u32, + /// Maximum output tokens + #[serde(default)] + pub output_token_limit: u32, + /// Supported generation methods (e.g., "generateContent", "streamGenerateContent") + #[serde(default)] + pub supported_generation_methods: Vec, +} + +impl GeminiModelInfo { + /// Extract the model ID from the full resource name. + #[must_use] + pub fn model_id(&self) -> &str { + self.name.strip_prefix("models/").unwrap_or(&self.name) + } + + /// Check if this model supports text generation. + #[must_use] + pub fn supports_generation(&self) -> bool { + self.supported_generation_methods + .iter() + .any(|m| m == "generateContent") + } +} + +// ── Provider ───────────────────────────────────────────────────────────────── + +pub struct GeminiOAuthProvider { + model: String, + client: reqwest::Client, + token_store: TokenStore, +} + +impl GeminiOAuthProvider { + pub fn new(model: String) -> Self { + Self { + model, + client: reqwest::Client::new(), + token_store: TokenStore::new(), + } + } + + /// Get the OAuth configuration for Gemini. + pub fn oauth_config() -> Option { + load_oauth_config(PROVIDER_NAME) + } + + /// Start the OAuth flow: returns the authorization URL to open in the browser. + /// Also returns the PKCE verifier and state for later token exchange. + pub fn start_auth_flow() -> Option { + let config = Self::oauth_config()?; + let flow = OAuthFlow::new(config.clone()); + let auth_request = flow.start().ok()?; + + Some(AuthFlowState { + auth_url: auth_request.url, + pkce_verifier: auth_request.pkce.verifier, + state: auth_request.state, + config, + }) + } + + /// Wait for the OAuth callback and exchange the code for tokens. + pub async fn complete_auth_flow(flow_state: &AuthFlowState) -> anyhow::Result { + let port = callback_port(&flow_state.config); + + // Wait for the callback with the authorization code + let code = + CallbackServer::wait_for_code(port, flow_state.state.clone(), "127.0.0.1").await?; + + // Exchange the code for tokens + let flow = OAuthFlow::new(flow_state.config.clone()); + let tokens = flow.exchange(&code, &flow_state.pkce_verifier).await?; + + Ok(tokens) + } + + /// Refresh the access token using the refresh token. + async fn refresh_access_token(&self, refresh_token: &str) -> anyhow::Result { + let config = Self::oauth_config() + .ok_or_else(|| anyhow::anyhow!("gemini-oauth configuration not found"))?; + + let flow = OAuthFlow::new(config); + Ok(flow.refresh(refresh_token).await?) + } + + /// List available models using stored OAuth credentials. + pub async fn list_available_models(&self) -> anyhow::Result> { + let token = self.get_valid_token().await?; + list_models_with_token(&token, GEMINI_API_BASE).await + } + + /// Get a valid access token, refreshing if needed. + async fn get_valid_token(&self) -> anyhow::Result { + let tokens = self.token_store.load(PROVIDER_NAME).ok_or_else(|| { + anyhow::anyhow!("not logged in to gemini-oauth — run OAuth flow first") + })?; + + // Check if token needs refresh + if let Some(expires_at) = tokens.expires_at { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_err(|err| anyhow::anyhow!("system clock before UNIX epoch: {err}"))? + .as_secs(); + + if now + REFRESH_THRESHOLD_SECS >= expires_at { + // Token expiring soon — refresh it + if let Some(ref refresh_token) = tokens.refresh_token { + let new_tokens = self + .refresh_access_token(refresh_token.expose_secret()) + .await?; + self.token_store.save(PROVIDER_NAME, &new_tokens)?; + return Ok(new_tokens.access_token.expose_secret().clone()); + } + anyhow::bail!("token expired and no refresh token available"); + } + } + + Ok(tokens.access_token.expose_secret().clone()) + } +} + +/// State needed to complete the OAuth flow after user authorization. +pub struct AuthFlowState { + pub auth_url: String, + pub pkce_verifier: String, + pub state: String, + config: OAuthConfig, +} + +/// Check if we have stored tokens for Google Gemini OAuth. +pub fn has_stored_tokens() -> bool { + TokenStore::new().load(PROVIDER_NAME).is_some() +} + +/// Save tokens after successful authentication. +pub fn save_tokens(tokens: &OAuthTokens) -> anyhow::Result<()> { + TokenStore::new().save(PROVIDER_NAME, tokens)?; + Ok(()) +} + +// ── Model Listing ──────────────────────────────────────────────────────────── + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +struct ListModelsResponse { + models: Vec, + #[serde(default)] + next_page_token: Option, +} + +/// List available Gemini models using OAuth authentication. +/// +/// Returns models that support text generation, sorted by name. +/// Requires stored OAuth tokens from a prior authentication. +pub async fn list_models_oauth() -> anyhow::Result> { + let store = TokenStore::new(); + let tokens = store + .load(PROVIDER_NAME) + .ok_or_else(|| anyhow::anyhow!("not logged in to gemini-oauth — run OAuth flow first"))?; + + // Check if token needs refresh + let access_token = if needs_token_refresh(&tokens) { + let config = load_oauth_config(PROVIDER_NAME) + .ok_or_else(|| anyhow::anyhow!("gemini-oauth configuration not found"))?; + + if let Some(ref refresh_token) = tokens.refresh_token { + let flow = OAuthFlow::new(config); + let new_tokens = flow.refresh(refresh_token.expose_secret()).await?; + store.save(PROVIDER_NAME, &new_tokens)?; + new_tokens.access_token.expose_secret().clone() + } else { + anyhow::bail!("token expired and no refresh token available"); + } + } else { + tokens.access_token.expose_secret().clone() + }; + + list_models_with_token(&access_token, GEMINI_API_BASE).await +} + +/// List available Gemini models with an OAuth access token and custom base URL. +pub async fn list_models_with_token( + access_token: &str, + base_url: &str, +) -> anyhow::Result> { + let client = reqwest::Client::new(); + let mut all_models = Vec::new(); + let mut page_token: Option = None; + + loop { + let mut url = format!("{}/v1beta/models", base_url); + if let Some(ref token) = page_token { + url.push_str(&format!("?pageToken={}", token)); + } + + let resp = client + .get(&url) + .header("Authorization", format!("Bearer {access_token}")) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + anyhow::bail!("Failed to list Gemini models: HTTP {status}: {body}"); + } + + let list_resp: ListModelsResponse = resp.json().await?; + all_models.extend(list_resp.models); + + match list_resp.next_page_token { + Some(token) if !token.is_empty() => page_token = Some(token), + _ => break, + } + } + + // Filter to models that support generation and sort by name + let mut models: Vec<_> = all_models + .into_iter() + .filter(|m| m.supports_generation()) + .collect(); + models.sort_by(|a, b| a.name.cmp(&b.name)); + + Ok(models) +} + +/// Check if the stored token needs refresh (within REFRESH_THRESHOLD_SECS of expiry). +fn needs_token_refresh(tokens: &OAuthTokens) -> bool { + if let Some(expires_at) = tokens.expires_at { + let Ok(now) = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) else { + return true; + }; + let now = now.as_secs(); + now + REFRESH_THRESHOLD_SECS >= expires_at + } else { + false + } +} + +/// Known Gemini models available via OAuth. +pub const GEMINI_OAUTH_MODELS: &[(&str, &str)] = &[ + ("gemini-2.5-pro-preview-06-05", "Gemini 2.5 Pro (OAuth)"), + ("gemini-2.5-flash-preview-05-20", "Gemini 2.5 Flash (OAuth)"), + ("gemini-2.0-flash", "Gemini 2.0 Flash (OAuth)"), + ("gemini-2.0-flash-lite", "Gemini 2.0 Flash Lite (OAuth)"), + ("gemini-1.5-pro", "Gemini 1.5 Pro (OAuth)"), + ("gemini-1.5-flash", "Gemini 1.5 Flash (OAuth)"), +]; + +fn info_to_discovered_model(info: GeminiModelInfo) -> super::DiscoveredModel { + let id = info.model_id().to_string(); + let display_name = if info.display_name.trim().is_empty() { + id.clone() + } else { + info.display_name + }; + super::DiscoveredModel::new(id, display_name) +} + +pub fn default_model_catalog() -> Vec { + GEMINI_OAUTH_MODELS + .iter() + .map(|(id, name)| super::DiscoveredModel::new(*id, *name)) + .collect() +} + +fn fetch_models_blocking() -> anyhow::Result> { + let (tx, rx) = mpsc::sync_channel(1); + std::thread::spawn(move || { + let result = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(anyhow::Error::from) + .and_then(|rt| { + rt.block_on(async { + let models = tokio::time::timeout(Duration::from_secs(8), list_models_oauth()) + .await + .map_err(|_| anyhow::anyhow!("gemini-oauth model discovery timed out"))??; + Ok(models + .into_iter() + .map(info_to_discovered_model) + .collect::>()) + }) + }); + let _ = tx.send(result); + }); + + rx.recv() + .map_err(|err| anyhow::anyhow!("gemini-oauth model discovery worker failed: {err}"))? +} + +pub fn live_models() -> anyhow::Result> { + let models = fetch_models_blocking()?; + info!( + model_count = models.len(), + "loaded gemini-oauth live models" + ); + Ok(models) +} + +pub fn available_models() -> Vec { + let fallback = default_model_catalog(); + let discovered = match live_models() { + Ok(models) => models, + Err(err) => { + let msg = err.to_string(); + if msg.contains("not logged in") || msg.contains("tokens not found") { + debug!(error = %err, "gemini-oauth not configured, using fallback catalog"); + } else { + warn!(error = %err, "failed to fetch gemini-oauth models, using fallback catalog"); + } + return fallback; + }, + }; + + super::merge_discovered_with_fallback_catalog(discovered, fallback) +} + +// ── Gemini API helpers ─────────────────────────────────────────────────────── + +/// Convert JSON Schema types (lowercase) to Gemini types (uppercase). +fn convert_json_schema_types(schema: &serde_json::Value) -> serde_json::Value { + match schema { + serde_json::Value::Object(obj) => { + let mut result = serde_json::Map::new(); + for (key, value) in obj { + if key == "type" { + if let Some(type_str) = value.as_str() { + result.insert( + key.clone(), + serde_json::Value::String(type_str.to_uppercase()), + ); + } else { + result.insert(key.clone(), value.clone()); + } + } else if key == "properties" { + if let serde_json::Value::Object(props) = value { + let converted_props: serde_json::Map = props + .iter() + .map(|(k, v)| (k.clone(), convert_json_schema_types(v))) + .collect(); + result.insert(key.clone(), serde_json::Value::Object(converted_props)); + } else { + result.insert(key.clone(), value.clone()); + } + } else if key == "items" { + result.insert(key.clone(), convert_json_schema_types(value)); + } else { + result.insert(key.clone(), value.clone()); + } + } + serde_json::Value::Object(result) + }, + serde_json::Value::Array(arr) => { + serde_json::Value::Array(arr.iter().map(convert_json_schema_types).collect()) + }, + _ => schema.clone(), + } +} + +/// Convert tool schemas to Gemini's functionDeclarations format. +fn to_gemini_tools(tools: &[serde_json::Value]) -> serde_json::Value { + let declarations: Vec = tools + .iter() + .map(|t| { + let params = convert_json_schema_types(&t["parameters"]); + serde_json::json!({ + "name": t["name"], + "description": t["description"], + "parameters": params, + }) + }) + .collect(); + + serde_json::json!({ "functionDeclarations": declarations }) +} + +/// Extract system instruction from messages. +fn extract_system_instruction(messages: &[ChatMessage]) -> (Option, Vec<&ChatMessage>) { + let mut system_text = None; + let mut remaining = Vec::new(); + + for msg in messages { + if let ChatMessage::System { content } = msg { + system_text = Some(content.clone()); + } else { + remaining.push(msg); + } + } + + (system_text, remaining) +} + +/// Convert messages to Gemini's content format. +fn to_gemini_messages(messages: &[&ChatMessage]) -> Vec { + messages + .iter() + .map(|msg| match msg { + ChatMessage::System { .. } => { + // System messages are handled separately via systemInstruction + serde_json::json!({ + "role": "user", + "parts": [{ "text": "" }], + }) + }, + ChatMessage::User { content } => { + let parts = match content { + UserContent::Text(text) => { + vec![serde_json::json!({ "text": text })] + }, + UserContent::Multimodal(parts) => parts + .iter() + .map(|p| match p { + ContentPart::Text(text) => { + serde_json::json!({ "text": text }) + }, + ContentPart::Image { media_type, data } => { + serde_json::json!({ + "inlineData": { + "mimeType": media_type, + "data": data, + } + }) + }, + }) + .collect(), + }; + serde_json::json!({ + "role": "user", + "parts": parts, + }) + }, + ChatMessage::Assistant { + content, + tool_calls, + } => { + let mut parts = Vec::new(); + + if let Some(text) = content + && !text.is_empty() + { + parts.push(serde_json::json!({ "text": text })); + } + + for tc in tool_calls { + parts.push(serde_json::json!({ + "functionCall": { + "name": &tc.name, + "args": &tc.arguments, + } + })); + } + + if parts.is_empty() { + parts.push(serde_json::json!({ "text": "" })); + } + + serde_json::json!({ + "role": "model", + "parts": parts, + }) + }, + ChatMessage::Tool { + tool_call_id, + content, + } => { + let response: serde_json::Value = serde_json::from_str(content) + .unwrap_or_else(|_| serde_json::json!({ "result": content })); + + serde_json::json!({ + "role": "user", + "parts": [{ + "functionResponse": { + "name": tool_call_id, + "response": response, + } + }], + }) + }, + }) + .collect() +} + +/// Parse tool calls from Gemini response parts. +fn parse_tool_calls(parts: &[serde_json::Value]) -> Vec { + parts + .iter() + .filter_map(|part| { + if let Some(fc) = part.get("functionCall") { + let name = fc["name"].as_str().unwrap_or("").to_string(); + let args = fc["args"].clone(); + Some(ToolCall { + id: name.clone(), + name, + arguments: args, + }) + } else { + None + } + }) + .collect() +} + +/// Extract text content from Gemini response parts. +fn extract_text(parts: &[serde_json::Value]) -> Option { + let texts: Vec<&str> = parts + .iter() + .filter_map(|part| part["text"].as_str()) + .collect(); + + if texts.is_empty() { + None + } else { + Some(texts.join("")) + } +} + +// ── LlmProvider impl ──────────────────────────────────────────────────────── + +#[async_trait] +impl LlmProvider for GeminiOAuthProvider { + fn name(&self) -> &str { + PROVIDER_NAME + } + + fn id(&self) -> &str { + &self.model + } + + fn supports_tools(&self) -> bool { + true + } + + fn context_window(&self) -> u32 { + super::context_window_for_model(&self.model) + } + + async fn complete( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + ) -> anyhow::Result { + let token = self.get_valid_token().await?; + + let (system_text, conv_messages) = extract_system_instruction(messages); + let gemini_messages = to_gemini_messages(&conv_messages); + + let mut body = serde_json::json!({ + "contents": gemini_messages, + "generationConfig": { + "maxOutputTokens": 8192, + }, + }); + + if let Some(ref sys) = system_text { + body["systemInstruction"] = serde_json::json!({ + "parts": [{ "text": sys }] + }); + } + + if !tools.is_empty() { + body["tools"] = serde_json::Value::Array(vec![to_gemini_tools(tools)]); + } + + debug!( + model = %self.model, + messages_count = gemini_messages.len(), + tools_count = tools.len(), + has_system = system_text.is_some(), + "gemini-oauth complete request" + ); + trace!(body = %serde_json::to_string(&body).unwrap_or_default(), "gemini-oauth request body"); + + let url = format!( + "{}/v1beta/models/{}:generateContent", + GEMINI_API_BASE, self.model + ); + + let http_resp = self + .client + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .header("content-type", "application/json") + .json(&body) + .send() + .await?; + + let status = http_resp.status(); + if !status.is_success() { + let body_text = http_resp.text().await.unwrap_or_default(); + warn!(status = %status, body = %body_text, "gemini-oauth API error"); + anyhow::bail!("Gemini OAuth API error HTTP {status}: {body_text}"); + } + + let resp = http_resp.json::().await?; + trace!(response = %resp, "gemini-oauth raw response"); + + let parts = resp["candidates"][0]["content"]["parts"] + .as_array() + .cloned() + .unwrap_or_default(); + + let text = extract_text(&parts); + let tool_calls = parse_tool_calls(&parts); + + let usage = Usage { + input_tokens: resp["usageMetadata"]["promptTokenCount"] + .as_u64() + .unwrap_or(0) as u32, + output_tokens: resp["usageMetadata"]["candidatesTokenCount"] + .as_u64() + .unwrap_or(0) as u32, + ..Default::default() + }; + + Ok(CompletionResponse { + text, + tool_calls, + usage, + }) + } + + #[allow(clippy::collapsible_if)] + fn stream( + &self, + messages: Vec, + ) -> Pin + Send + '_>> { + Box::pin(async_stream::stream! { + let token = match self.get_valid_token().await { + Ok(t) => t, + Err(e) => { + yield StreamEvent::Error(e.to_string()); + return; + } + }; + + let (system_text, conv_messages) = extract_system_instruction(&messages); + let gemini_messages = to_gemini_messages(&conv_messages); + + let mut body = serde_json::json!({ + "contents": gemini_messages, + "generationConfig": { + "maxOutputTokens": 8192, + }, + }); + + if let Some(ref sys) = system_text { + body["systemInstruction"] = serde_json::json!({ + "parts": [{ "text": sys }] + }); + } + + let url = format!( + "{}/v1beta/models/{}:streamGenerateContent?alt=sse", + GEMINI_API_BASE, self.model + ); + + let resp = match self + .client + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .header("content-type", "application/json") + .json(&body) + .send() + .await + { + Ok(r) => { + if let Err(e) = r.error_for_status_ref() { + let status = e.status().map(|s| s.as_u16()).unwrap_or(0); + let body_text = r.text().await.unwrap_or_default(); + yield StreamEvent::Error(format!("HTTP {status}: {body_text}")); + return; + } + r + } + Err(e) => { + yield StreamEvent::Error(e.to_string()); + return; + } + }; + + let mut byte_stream = resp.bytes_stream(); + let mut buf = String::new(); + let mut input_tokens: u32 = 0; + let mut output_tokens: u32 = 0; + + while let Some(chunk) = byte_stream.next().await { + let chunk = match chunk { + Ok(c) => c, + Err(e) => { + yield StreamEvent::Error(e.to_string()); + return; + } + }; + buf.push_str(&String::from_utf8_lossy(&chunk)); + + while let Some(pos) = buf.find("\n\n") { + let block = buf[..pos].to_string(); + buf = buf[pos + 2..].to_string(); + + for line in block.lines() { + let Some(data) = line.strip_prefix("data: ") else { + continue; + }; + + if let Ok(evt) = serde_json::from_str::(data) { + if let Some(usage) = evt.get("usageMetadata") { + if let Some(pt) = usage["promptTokenCount"].as_u64() { + input_tokens = pt as u32; + } + if let Some(ct) = usage["candidatesTokenCount"].as_u64() { + output_tokens = ct as u32; + } + } + + if let Some(parts) = evt["candidates"][0]["content"]["parts"].as_array() { + for part in parts { + if let Some(text) = part["text"].as_str() { + if !text.is_empty() { + yield StreamEvent::Delta(text.to_string()); + } + } + } + } + + if let Some(finish_reason) = evt["candidates"][0]["finishReason"].as_str() { + if finish_reason == "STOP" || finish_reason == "MAX_TOKENS" { + yield StreamEvent::Done(Usage { input_tokens, output_tokens, ..Default::default() }); + return; + } + } + } + } + } + } + + yield StreamEvent::Done(Usage { input_tokens, output_tokens, ..Default::default() }); + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn has_stored_tokens_returns_false_without_tokens() { + // Just verify it doesn't panic + let _ = has_stored_tokens(); + } + + #[test] + fn gemini_oauth_models_not_empty() { + assert!(!GEMINI_OAUTH_MODELS.is_empty()); + } + + #[test] + fn gemini_oauth_models_have_unique_ids() { + let mut ids: Vec<&str> = GEMINI_OAUTH_MODELS.iter().map(|(id, _)| *id).collect(); + ids.sort(); + ids.dedup(); + assert_eq!(ids.len(), GEMINI_OAUTH_MODELS.len()); + } + + #[test] + fn provider_name_and_id() { + let provider = GeminiOAuthProvider::new("gemini-2.0-flash".into()); + assert_eq!(provider.name(), "gemini-oauth"); + assert_eq!(provider.id(), "gemini-2.0-flash"); + assert!(provider.supports_tools()); + } + + #[test] + fn oauth_config_loads() { + // Should return Some since we have a default config + let config = GeminiOAuthProvider::oauth_config(); + assert!(config.is_some()); + let config = config.unwrap(); + assert!(!config.device_flow); + assert!(config.redirect_uri.contains("localhost")); + } + + #[test] + fn to_gemini_tools_converts_correctly() { + let tools = vec![serde_json::json!({ + "name": "test_tool", + "description": "A test tool", + "parameters": {"type": "object", "properties": {"x": {"type": "string"}}} + })]; + let converted = to_gemini_tools(&tools); + assert!(converted["functionDeclarations"].is_array()); + let decls = converted["functionDeclarations"].as_array().unwrap(); + assert_eq!(decls.len(), 1); + assert_eq!(decls[0]["name"], "test_tool"); + assert_eq!(decls[0]["parameters"]["type"], "OBJECT"); + } + + #[test] + fn convert_json_schema_types_works() { + let schema = serde_json::json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + } + }); + let converted = convert_json_schema_types(&schema); + assert_eq!(converted["type"], "OBJECT"); + assert_eq!(converted["properties"]["name"]["type"], "STRING"); + } + + #[test] + fn extract_system_instruction_works() { + let messages = vec![ + ChatMessage::system("You are helpful"), + ChatMessage::user("Hello"), + ]; + let (system, remaining) = extract_system_instruction(&messages); + assert_eq!(system, Some("You are helpful".to_string())); + assert_eq!(remaining.len(), 1); + } + + #[test] + fn to_gemini_messages_converts_user() { + let msg = ChatMessage::user("Hello"); + let messages = vec![&msg]; + let gemini = to_gemini_messages(&messages); + assert_eq!(gemini.len(), 1); + assert_eq!(gemini[0]["role"], "user"); + assert_eq!(gemini[0]["parts"][0]["text"], "Hello"); + } + + #[test] + fn to_gemini_messages_converts_assistant() { + let msg = ChatMessage::assistant("Hi"); + let messages = vec![&msg]; + let gemini = to_gemini_messages(&messages); + assert_eq!(gemini[0]["role"], "model"); + assert_eq!(gemini[0]["parts"][0]["text"], "Hi"); + } + + #[test] + fn parse_tool_calls_works() { + let parts = vec![serde_json::json!({ + "functionCall": { + "name": "get_weather", + "args": { "city": "SF" } + } + })]; + let calls = parse_tool_calls(&parts); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "get_weather"); + assert_eq!(calls[0].arguments["city"], "SF"); + } + + #[test] + fn extract_text_works() { + let parts = vec![ + serde_json::json!({ "text": "Hello " }), + serde_json::json!({ "text": "world" }), + ]; + assert_eq!(extract_text(&parts), Some("Hello world".to_string())); + } + + #[test] + fn extract_text_empty() { + let parts: Vec = vec![]; + assert_eq!(extract_text(&parts), None); + } + + #[test] + fn context_window_uses_lookup() { + let provider = GeminiOAuthProvider::new("gemini-2.0-flash".into()); + assert_eq!(provider.context_window(), 1_000_000); + } + + #[test] + fn gemini_model_info_model_id_strips_prefix() { + let info = GeminiModelInfo { + name: "models/gemini-2.0-flash".into(), + display_name: "Gemini 2.0 Flash".into(), + input_token_limit: 1_000_000, + output_token_limit: 8192, + supported_generation_methods: vec!["generateContent".into()], + }; + assert_eq!(info.model_id(), "gemini-2.0-flash"); + } + + #[test] + fn gemini_model_info_model_id_handles_no_prefix() { + let info = GeminiModelInfo { + name: "gemini-2.0-flash".into(), + display_name: "Gemini 2.0 Flash".into(), + input_token_limit: 0, + output_token_limit: 0, + supported_generation_methods: vec![], + }; + assert_eq!(info.model_id(), "gemini-2.0-flash"); + } + + #[test] + fn gemini_model_info_supports_generation() { + let info = GeminiModelInfo { + name: "models/gemini-2.0-flash".into(), + display_name: "".into(), + input_token_limit: 0, + output_token_limit: 0, + supported_generation_methods: vec!["generateContent".into(), "embedContent".into()], + }; + assert!(info.supports_generation()); + + let info_no_gen = GeminiModelInfo { + name: "models/text-embedding".into(), + display_name: "".into(), + input_token_limit: 0, + output_token_limit: 0, + supported_generation_methods: vec!["embedContent".into()], + }; + assert!(!info_no_gen.supports_generation()); + } + + #[test] + fn needs_token_refresh_returns_false_when_no_expiry() { + use secrecy::Secret; + let tokens = OAuthTokens { + access_token: Secret::new("test".into()), + refresh_token: None, + id_token: None, + account_id: None, + expires_at: None, + }; + assert!(!needs_token_refresh(&tokens)); + } + + #[test] + fn needs_token_refresh_returns_true_when_expired() { + use secrecy::Secret; + // Token that expired 10 minutes ago + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + let tokens = OAuthTokens { + access_token: Secret::new("test".into()), + refresh_token: None, + id_token: None, + account_id: None, + expires_at: Some(now - 600), + }; + assert!(needs_token_refresh(&tokens)); + } + + #[test] + fn needs_token_refresh_returns_true_within_threshold() { + use secrecy::Secret; + // Token expiring in 2 minutes (within 5-minute threshold) + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + let tokens = OAuthTokens { + access_token: Secret::new("test".into()), + refresh_token: None, + id_token: None, + account_id: None, + expires_at: Some(now + 120), + }; + assert!(needs_token_refresh(&tokens)); + } + + #[test] + fn needs_token_refresh_returns_false_when_fresh() { + use secrecy::Secret; + // Token expiring in 1 hour + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + let tokens = OAuthTokens { + access_token: Secret::new("test".into()), + refresh_token: None, + id_token: None, + account_id: None, + expires_at: Some(now + 3600), + }; + assert!(!needs_token_refresh(&tokens)); + } + + #[test] + fn list_models_response_deserializes() { + let json = r#"{ + "models": [ + { + "name": "models/gemini-2.0-flash", + "displayName": "Gemini 2.0 Flash", + "inputTokenLimit": 1000000, + "outputTokenLimit": 8192, + "supportedGenerationMethods": ["generateContent", "streamGenerateContent"] + } + ], + "nextPageToken": "abc123" + }"#; + let resp: ListModelsResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.models.len(), 1); + assert_eq!(resp.models[0].model_id(), "gemini-2.0-flash"); + assert_eq!(resp.models[0].display_name, "Gemini 2.0 Flash"); + assert_eq!(resp.models[0].input_token_limit, 1_000_000); + assert!(resp.models[0].supports_generation()); + assert_eq!(resp.next_page_token, Some("abc123".to_string())); + } + + #[test] + fn list_models_response_handles_missing_fields() { + let json = r#"{ + "models": [ + { + "name": "models/gemini-test" + } + ] + }"#; + let resp: ListModelsResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.models.len(), 1); + assert_eq!(resp.models[0].model_id(), "gemini-test"); + assert_eq!(resp.models[0].display_name, ""); + assert_eq!(resp.models[0].input_token_limit, 0); + assert_eq!(resp.models[0].output_token_limit, 0); + assert!(!resp.models[0].supports_generation()); + assert!(resp.next_page_token.is_none()); + } +} diff --git a/crates/providers/src/lib.rs b/crates/providers/src/lib.rs index 9f44f97fb..5e4ae0db3 100644 --- a/crates/providers/src/lib.rs +++ b/crates/providers/src/lib.rs @@ -23,6 +23,9 @@ pub mod github_copilot; #[cfg(feature = "provider-kimi-code")] pub mod kimi_code; +#[cfg(feature = "provider-gemini-oauth")] +pub mod gemini_oauth; + #[cfg(feature = "local-llm")] pub mod local_gguf; @@ -1267,6 +1270,11 @@ impl ProviderRegistry { reg.register_kimi_code_providers(config, env_overrides); } + #[cfg(feature = "provider-gemini-oauth")] + { + reg.register_gemini_oauth_providers(config); + } + // Local GGUF providers (no API key needed, model runs locally) #[cfg(feature = "local-llm")] { @@ -1520,6 +1528,42 @@ impl ProviderRegistry { } } + #[cfg(feature = "provider-gemini-oauth")] + fn register_gemini_oauth_providers(&mut self, config: &ProvidersConfig) { + if !config.is_enabled("gemini-oauth") { + return; + } + if !gemini_oauth::has_stored_tokens() { + return; + } + + let preferred = configured_models_for_provider(config, "gemini-oauth"); + let discovered = if should_fetch_models(config, "gemini-oauth") { + gemini_oauth::available_models() + } else { + gemini_oauth::default_model_catalog() + }; + let models = merge_preferred_and_discovered_models(preferred, discovered); + + for model in models { + let (model_id, display_name, created_at) = + (model.id, model.display_name, model.created_at); + if self.has_provider_model("gemini-oauth", &model_id) { + continue; + } + let provider = Arc::new(gemini_oauth::GeminiOAuthProvider::new(model_id.clone())); + self.register( + ModelInfo { + id: model_id, + provider: "gemini-oauth".into(), + display_name, + created_at, + }, + provider, + ); + } + } + #[cfg(feature = "local-llm")] fn register_local_gguf_providers(&mut self, config: &ProvidersConfig) { use std::path::PathBuf; @@ -2677,6 +2721,26 @@ mod tests { ); } + #[cfg(feature = "provider-gemini-oauth")] + #[test] + fn gemini_oauth_not_registered_without_tokens() { + let mut config = ProvidersConfig::default(); + config.providers.insert( + "gemini-oauth".into(), + moltis_config::schema::ProviderEntry { + enabled: true, + ..Default::default() + }, + ); + + let reg = ProviderRegistry::from_env_with_config(&config); + assert!( + !reg.list_models() + .iter() + .any(|m| m.provider == "gemini-oauth") + ); + } + #[test] fn openrouter_requires_model_in_config() { // OpenRouter has no default models — without configured models it registers nothing. diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index 62d191472..613a91219 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -18,6 +18,7 @@ # Features - [LLM Providers](providers.md) + - [Google Gemini](gemini.md) - [MCP Servers](mcp.md) - [Memory](memory.md) - [Moltis vs OpenClaw](memory-comparison.md) diff --git a/docs/src/gemini.md b/docs/src/gemini.md new file mode 100644 index 000000000..bea8aa256 --- /dev/null +++ b/docs/src/gemini.md @@ -0,0 +1,156 @@ +# Google Gemini Provider + +Moltis supports Google Gemini models through two authentication methods: + +1. **API Key** (`gemini`) - Direct API key authentication +2. **OAuth** (`gemini-oauth`) - Browser-based authentication with your Google account + +## API Key Provider + +The simplest way to use Gemini. Get an API key from [Google AI Studio](https://aistudio.google.com/apikey) and set it: + +```bash +export GEMINI_API_KEY=your_api_key_here +``` + +Or add it to your `moltis.toml`: + +```toml +[providers.gemini] +api_key = "your_api_key_here" +``` + +## OAuth Provider + +The OAuth provider allows users to authenticate with their Google account. **API usage is billed to the user's Google account**, not to the application developer. This is the recommended approach for distributed applications. + +### How It Works + +1. User initiates login in the Moltis UI +2. Browser opens to Google OAuth consent screen +3. User authenticates with their Google account +4. Browser redirects to local callback server (port 1456) +5. Moltis exchanges the authorization code for tokens using PKCE +6. Tokens are stored securely for future use + +### Technical Details + +- **Flow**: Authorization Code with PKCE (no client secret required) +- **Scopes**: `generative-language.retriever`, `cloud-platform` +- **Token refresh**: Automatic with 5-minute buffer before expiry +- **Storage**: Tokens stored in `~/.moltis/oauth_tokens.json` + +### For Application Developers + +To enable Gemini OAuth in your Moltis deployment, you need to create a Google Cloud OAuth client and update the client ID in the codebase. + +#### Step 1: Create a Google Cloud Project + +1. Go to [Google Cloud Console](https://console.cloud.google.com/) +2. Create a new project or select an existing one +3. Enable the **Generative Language API**: + - Go to APIs & Services > Library + - Search for "Generative Language API" + - Click Enable + +#### Step 2: Configure OAuth Consent Screen + +1. Go to APIs & Services > OAuth consent screen +2. Select **External** user type (or Internal if using Google Workspace) +3. Fill in the required fields: + - App name: `Moltis` (or your app name) + - User support email: your email + - Developer contact: your email +4. Add scopes: + - `https://www.googleapis.com/auth/generative-language.retriever` + - `https://www.googleapis.com/auth/cloud-platform` +5. Add test users if in testing mode + +#### Step 3: Create OAuth Credentials + +1. Go to APIs & Services > Credentials +2. Click **Create Credentials** > **OAuth client ID** +3. Select **Desktop app** as the application type +4. Name it (e.g., "Moltis Desktop") +5. Click Create +6. Copy the **Client ID** (you don't need the client secret for PKCE) + +#### Step 4: Update the Client ID + +Replace the placeholder in `crates/oauth/src/defaults.rs`: + +```rust +m.insert("gemini-oauth".into(), OAuthConfig { + client_id: "YOUR_CLIENT_ID_HERE.apps.googleusercontent.com".into(), + // ... rest of config +}); +``` + +### Security Notes + +- The client ID is **not a secret** - it's safe to embed in distributed applications +- PKCE (Proof Key for Code Exchange) prevents authorization code interception attacks +- No client secret is needed because PKCE provides equivalent security +- Tokens are stored locally on the user's machine +- API usage and billing is tied to the user's Google account + +## Supported Models + +Both providers support the same models: + +| Model ID | Description | +|----------|-------------| +| `gemini-2.5-pro-preview-06-05` | Gemini 2.5 Pro (latest) | +| `gemini-2.5-flash-preview-05-20` | Gemini 2.5 Flash (latest) | +| `gemini-2.0-flash` | Gemini 2.0 Flash | +| `gemini-2.0-flash-lite` | Gemini 2.0 Flash Lite | +| `gemini-1.5-pro` | Gemini 1.5 Pro | +| `gemini-1.5-flash` | Gemini 1.5 Flash | + +All models support: +- 1M token context window +- Tool/function calling +- Streaming responses +- System instructions + +## Configuration + +### Selecting a specific model + +```toml +[providers.gemini] +model = "gemini-2.5-pro-preview-06-05" + +[providers.gemini-oauth] +model = "gemini-2.0-flash" +``` + +### Disabling a provider + +```toml +[providers.gemini] +enabled = false +``` + +## Troubleshooting + +### "not logged in to gemini-oauth" + +The OAuth flow hasn't been completed. Click the login button in the Moltis UI to authenticate. + +### "token expired and no refresh token available" + +The stored tokens are invalid. Clear them and re-authenticate: +- Delete `~/.moltis/oauth_tokens.json` or the gemini-oauth entry within it +- Re-authenticate through the UI + +### OAuth callback timeout + +If the browser doesn't redirect properly: +1. Check that port 1456 is not blocked by a firewall +2. Ensure no other application is using port 1456 +3. Try the authentication flow again + +### API quota errors + +Gemini has usage quotas. Check your quota in the [Google Cloud Console](https://console.cloud.google.com/apis/api/generativelanguage.googleapis.com/quotas). diff --git a/docs/src/providers.md b/docs/src/providers.md index 3e3297eb2..8d67ea74e 100644 --- a/docs/src/providers.md +++ b/docs/src/providers.md @@ -27,6 +27,7 @@ Configure providers through the web UI or directly in configuration files. | Provider | Config Name | Notes | |----------|-------------|-------| +| **Google Gemini (OAuth)** | `gemini-oauth` | Browser OAuth flow with Google account | | **OpenAI Codex** | `openai-codex` | OAuth flow via web UI | | **GitHub Copilot** | `github-copilot` | Requires active Copilot subscription | @@ -120,6 +121,16 @@ models = ["gemini-2.5-flash-preview-05-20", "gemini-2.0-flash"] Gemini supports native tool calling, vision/multimodal inputs, streaming, and automatic model discovery. +### Google Gemini (OAuth) + +Gemini OAuth uses browser-based OAuth authentication with your Google account. + +1. Go to **Settings** → **Providers** → **Google Gemini (OAuth)**. +2. Click **Connect**. +3. Complete the Google OAuth flow. + +See [Google Gemini](gemini.md) for full OAuth setup details and troubleshooting. + ### Anthropic 1. Get an API key from [console.anthropic.com](https://console.anthropic.com/).