diff --git a/src-tauri/src/antigravity/endpoints.rs b/src-tauri/src/antigravity/endpoints.rs deleted file mode 100644 index 8784e70..0000000 --- a/src-tauri/src/antigravity/endpoints.rs +++ /dev/null @@ -1,35 +0,0 @@ -use std::env; - -pub(crate) const BASE_URL_DAILY: &str = "https://daily-cloudcode-pa.googleapis.com"; -pub(crate) const BASE_URL_SANDBOX: &str = "https://daily-cloudcode-pa.sandbox.googleapis.com"; -pub(crate) const BASE_URL_PROD: &str = "https://cloudcode-pa.googleapis.com"; - -// Align with CLIProxyAPIPlus: prefer daily, then sandbox. Prod is intentionally excluded. -pub(crate) const BASE_URLS: [&str; 2] = [BASE_URL_DAILY, BASE_URL_SANDBOX]; - -const ANTIGRAVITY_VERSION: &str = "1.104.0"; - -pub(crate) fn default_user_agent() -> String { - let os = match env::consts::OS { - "macos" => "darwin", - other => other, - }; - let arch = env::consts::ARCH; - format!("antigravity/{ANTIGRAVITY_VERSION} {os}/{arch}") -} - -pub(crate) fn build_base_url_list(primary: &str) -> Vec { - let mut urls = Vec::new(); - let primary = primary.trim_end_matches('/'); - if !primary.is_empty() { - urls.push(primary.to_string()); - } - for base in BASE_URLS { - let base = base.trim_end_matches('/'); - if urls.iter().any(|value| value == base) { - continue; - } - urls.push(base.to_string()); - } - urls -} diff --git a/src-tauri/src/antigravity/ide.rs b/src-tauri/src/antigravity/ide.rs deleted file mode 100644 index 0de5c8d..0000000 --- a/src-tauri/src/antigravity/ide.rs +++ /dev/null @@ -1,302 +0,0 @@ -use std::path::{Path, PathBuf}; -use std::process::Command; - -use serde_json::Value; -use sysinfo::ProcessesToUpdate; -use tauri::{AppHandle, Manager}; - -use super::ide_db; -use super::protobuf; -use super::store::AntigravityAccountStore; -use super::types::{AntigravityAccountSummary, AntigravityIdeStatus, AntigravityTokenRecord}; - -#[derive(Clone)] -pub(crate) struct AntigravityIdeConfig { - pub(crate) ide_db_path: Option, - pub(crate) app_paths: Vec, - pub(crate) process_names: Vec, -} - -#[cfg(target_os = "macos")] -const DEFAULT_PROCESS_NAMES: [&str; 2] = ["com.google.antigravity", "com.todesktop.230313mzl4w4u92"]; - -#[cfg(not(target_os = "macos"))] -const DEFAULT_PROCESS_NAMES: [&str; 0] = []; - -#[cfg(target_os = "macos")] -fn default_app_paths(home: &Path) -> Vec { - vec![ - PathBuf::from("/Applications/Antigravity.app"), - home.join("Applications").join("Antigravity.app"), - ] -} - -#[cfg(not(target_os = "macos"))] -fn default_app_paths(_home: &Path) -> Vec { - Vec::new() -} - -#[cfg(target_os = "macos")] -fn default_db_path(home: &Path) -> PathBuf { - home.join("Library") - .join("Application Support") - .join("Antigravity") - .join("User") - .join("globalStorage") - .join("state.vscdb") -} - -#[cfg(not(target_os = "macos"))] -fn default_db_path(_home: &Path) -> PathBuf { - PathBuf::new() -} - -pub(crate) async fn import_from_ide( - app: &AppHandle, - store: &AntigravityAccountStore, - path_override: Option, -) -> Result, String> { - let config = resolve_ide_config(app).await?; - let db_path = resolve_db_path(path_override, &config)?; - let state = ide_db::read_item(&db_path, "jetskiStateSync.agentManagerInitState") - .await? - .ok_or_else(|| "Antigravity IDE state not found.".to_string())?; - let mut record = match protobuf::extract_token_record(&state)? { - Some(record) => record, - None => return Err("Failed to extract Antigravity token from IDE.".to_string()), - }; - record.email = read_auth_email(&db_path).await?; - record.source = Some("ide".to_string()); - let summary = store.save_new_account(record).await?; - Ok(vec![summary]) -} - -pub(crate) async fn switch_ide_account( - app: &AppHandle, - store: &AntigravityAccountStore, - account_id: &str, - path_override: Option, -) -> Result { - let config = resolve_ide_config(app).await?; - let db_path = resolve_db_path(path_override, &config)?; - let record = store.get_account_record(account_id).await?; - ensure_ide_closed(&config).await?; - ide_db::delete_wal_shm(&db_path).await?; - let backup = ide_db::backup_db(&db_path).await?; - let result = apply_account_to_db(&db_path, &record).await; - if let Err(err) = result { - let _ = ide_db::restore_db(&db_path, &backup).await; - let _ = ide_db::cleanup_backup(&backup).await; - return Err(err); - } - let _ = ide_db::cleanup_backup(&backup).await; - restart_ide(&config).await?; - ide_status(app, Some(&config)).await -} - -pub(crate) async fn ide_status( - app: &AppHandle, - cached: Option<&AntigravityIdeConfig>, -) -> Result { - let config = match cached { - Some(config) => config.clone(), - None => resolve_ide_config(app).await?, - }; - let database_available = config - .ide_db_path - .as_ref() - .map(|path| path.exists()) - .unwrap_or(false); - let active_email = if database_available { - let db_path = config.ide_db_path.as_ref().expect("db path"); - read_auth_email(db_path).await? - } else { - None - }; - let ide_running = is_ide_running(&config).await; - Ok(AntigravityIdeStatus { - database_available, - ide_running, - active_email, - }) -} - -async fn apply_account_to_db(path: &Path, record: &AntigravityTokenRecord) -> Result<(), String> { - let state = ide_db::read_item(path, "jetskiStateSync.agentManagerInitState") - .await? - .ok_or_else(|| "Antigravity IDE state not found.".to_string())?; - let injected = protobuf::inject_token_record(&state, record)?; - ide_db::write_item(path, "jetskiStateSync.agentManagerInitState", &injected).await?; - ide_db::write_item(path, "antigravityOnboarding", "true").await?; - if let Some(email) = record.email.as_deref() { - let payload = serde_json::json!({ "email": email }); - let payload = serde_json::to_string(&payload).unwrap_or_default(); - let _ = ide_db::write_item(path, "antigravityAuthStatus", &payload).await; - } - Ok(()) -} - -async fn read_auth_email(path: &Path) -> Result, String> { - let raw = ide_db::read_item(path, "antigravityAuthStatus").await?; - let Some(raw) = raw else { - return Ok(None); - }; - let Ok(value) = serde_json::from_str::(&raw) else { - return Ok(None); - }; - if let Some(email) = value.get("email").and_then(Value::as_str) { - let trimmed = email.trim(); - if !trimmed.is_empty() { - return Ok(Some(trimmed.to_string())); - } - } - Ok(None) -} - -async fn resolve_ide_config(app: &AppHandle) -> Result { - let config = crate::proxy::config::read_config(app.clone()).await?.config; - let home = resolve_home_dir(app)?; - let ide_db_path = config - .antigravity_ide_db_path - .as_deref() - .map(str::trim) - .filter(|value| !value.is_empty()) - .map(PathBuf::from) - .or_else(|| { - let default = default_db_path(&home); - if default.as_os_str().is_empty() { - None - } else { - Some(default) - } - }); - let app_paths = if !config.antigravity_app_paths.is_empty() { - config - .antigravity_app_paths - .iter() - .map(|value| PathBuf::from(value)) - .collect() - } else { - default_app_paths(&home) - }; - let process_names = if !config.antigravity_process_names.is_empty() { - config - .antigravity_process_names - .iter() - .map(|value| value.trim().to_string()) - .filter(|value| !value.is_empty()) - .collect() - } else { - DEFAULT_PROCESS_NAMES.iter().map(|value| value.to_string()).collect() - }; - Ok(AntigravityIdeConfig { - ide_db_path, - app_paths, - process_names, - }) -} - -fn resolve_db_path( - override_path: Option, - config: &AntigravityIdeConfig, -) -> Result { - if let Some(path) = override_path { - return Ok(path); - } - config - .ide_db_path - .clone() - .ok_or_else(|| "Antigravity IDE database path is not configured.".to_string()) -} - -async fn is_ide_running(config: &AntigravityIdeConfig) -> bool { - if config.process_names.is_empty() { - return false; - } - let targets: Vec = config - .process_names - .iter() - .map(|value| value.to_lowercase()) - .collect(); - super::warmup::run_blocking(move || { - let mut system = sysinfo::System::new_all(); - system.refresh_processes(ProcessesToUpdate::All, true); - system.processes().values().any(|process| { - let name = process.name().to_string_lossy().to_lowercase(); - targets.iter().any(|target| name.contains(target)) - }) - }) - .await - .unwrap_or(false) -} - -async fn ensure_ide_closed(config: &AntigravityIdeConfig) -> Result<(), String> { - if config.process_names.is_empty() { - return Ok(()); - } - let targets = config.process_names.clone(); - super::warmup::run_blocking(move || { - let mut system = sysinfo::System::new_all(); - system.refresh_processes(ProcessesToUpdate::All, true); - for process in system.processes().values() { - let name = process.name().to_string_lossy().to_lowercase(); - if targets.iter().any(|target| name.contains(&target.to_lowercase())) { - let _ = process.kill(); - } - } - }) - .await - .map_err(|_| "Failed to terminate Antigravity IDE.".to_string())?; - Ok(()) -} - -async fn restart_ide(config: &AntigravityIdeConfig) -> Result<(), String> { - if config.app_paths.is_empty() { - return Ok(()); - } - for path in &config.app_paths { - if !path.exists() { - continue; - } - #[cfg(target_os = "macos")] - { - let result = Command::new("open").arg(path).spawn(); - if result.is_ok() { - return Ok(()); - } - } - #[cfg(target_os = "windows")] - { - let result = Command::new("cmd") - .args(["/C", "start", ""]) - .arg(path) - .spawn(); - if result.is_ok() { - return Ok(()); - } - } - #[cfg(target_os = "linux")] - { - let result = Command::new("xdg-open").arg(path).spawn(); - if result.is_ok() { - return Ok(()); - } - } - } - Ok(()) -} - -fn resolve_home_dir(app: &AppHandle) -> Result { - if let Ok(dir) = app.path().home_dir() { - return Ok(dir); - } - if let Some(dir) = std::env::var_os("HOME").map(PathBuf::from) { - return Ok(dir); - } - if cfg!(target_os = "windows") { - if let Some(dir) = std::env::var_os("USERPROFILE").map(PathBuf::from) { - return Ok(dir); - } - } - Err("Failed to resolve user home directory.".to_string()) -} diff --git a/src-tauri/src/antigravity/ide_db.rs b/src-tauri/src/antigravity/ide_db.rs deleted file mode 100644 index adcd2c9..0000000 --- a/src-tauri/src/antigravity/ide_db.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::path::{Path, PathBuf}; -use std::time::Duration; - -use sqlx::sqlite::SqliteConnectOptions; -use sqlx::{Connection, Row, SqliteConnection}; -use time::OffsetDateTime; - -const WAL_SUFFIX: &str = "-wal"; -const SHM_SUFFIX: &str = "-shm"; - -pub(crate) async fn read_item(path: &Path, key: &str) -> Result, String> { - let mut conn = open_connection(path, true).await?; - let row = sqlx::query("SELECT value FROM ItemTable WHERE key = ?") - .bind(key) - .fetch_optional(&mut conn) - .await - .map_err(|err| format!("Failed to read Antigravity state: {err}"))?; - let Some(row) = row else { - return Ok(None); - }; - let value: Vec = row.try_get("value").unwrap_or_default(); - Ok(Some(String::from_utf8_lossy(&value).to_string())) -} - -pub(crate) async fn write_item(path: &Path, key: &str, value: &str) -> Result<(), String> { - let mut conn = open_connection(path, false).await?; - let mut tx = conn - .begin() - .await - .map_err(|err| format!("Failed to start Antigravity DB transaction: {err}"))?; - sqlx::query("INSERT INTO ItemTable(key, value) VALUES(?, ?) ON CONFLICT(key) DO UPDATE SET value=excluded.value") - .bind(key) - .bind(value) - .execute(&mut *tx) - .await - .map_err(|err| format!("Failed to write Antigravity state: {err}"))?; - tx.commit() - .await - .map_err(|err| format!("Failed to commit Antigravity state: {err}"))?; - Ok(()) -} - -pub(crate) async fn delete_wal_shm(path: &Path) -> Result<(), String> { - let wal = path_with_suffix(path, WAL_SUFFIX); - let shm = path_with_suffix(path, SHM_SUFFIX); - if tokio::fs::try_exists(&wal).await.unwrap_or(false) { - tokio::fs::remove_file(&wal) - .await - .map_err(|err| format!("Failed to remove WAL: {err}"))?; - } - if tokio::fs::try_exists(&shm).await.unwrap_or(false) { - tokio::fs::remove_file(&shm) - .await - .map_err(|err| format!("Failed to remove SHM: {err}"))?; - } - Ok(()) -} - -pub(crate) async fn backup_db(path: &Path) -> Result { - let timestamp = OffsetDateTime::now_utc().unix_timestamp(); - let backup = path.with_extension(format!("vscdb.bak-{timestamp}")); - tokio::fs::copy(path, &backup) - .await - .map_err(|err| format!("Failed to backup Antigravity DB: {err}"))?; - Ok(backup) -} - -pub(crate) async fn restore_db(original: &Path, backup: &Path) -> Result<(), String> { - if tokio::fs::try_exists(backup).await.unwrap_or(false) { - tokio::fs::copy(backup, original) - .await - .map_err(|err| format!("Failed to restore Antigravity DB: {err}"))?; - } - Ok(()) -} - -pub(crate) async fn cleanup_backup(path: &Path) -> Result<(), String> { - if tokio::fs::try_exists(path).await.unwrap_or(false) { - tokio::fs::remove_file(path) - .await - .map_err(|err| format!("Failed to cleanup backup: {err}"))?; - } - Ok(()) -} - -async fn open_connection(path: &Path, read_only: bool) -> Result { - if !tokio::fs::try_exists(path).await.unwrap_or(false) { - return Err("Antigravity IDE database not found.".to_string()); - } - let options = SqliteConnectOptions::new() - .filename(path) - .read_only(read_only) - .create_if_missing(false) - .busy_timeout(Duration::from_secs(3)); - SqliteConnection::connect_with(&options) - .await - .map_err(|err| format!("Failed to open Antigravity DB: {err}")) -} - -fn path_with_suffix(path: &Path, suffix: &str) -> PathBuf { - let file_name = path.file_name().and_then(|name| name.to_str()).unwrap_or(""); - let new_name = format!("{file_name}{suffix}"); - path.with_file_name(new_name) -} diff --git a/src-tauri/src/antigravity/login.rs b/src-tauri/src/antigravity/login.rs deleted file mode 100644 index a9b593a..0000000 --- a/src-tauri/src/antigravity/login.rs +++ /dev/null @@ -1,295 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use time::OffsetDateTime; -use tokio::sync::RwLock; - -use crate::app_proxy::AppProxyState; -use crate::oauth_util::{expires_at_from_seconds, generate_state}; - -use super::oauth::AntigravityOAuthClient; -use super::project; -use super::store::AntigravityAccountStore; -use super::types::{ - AntigravityAccountSummary, AntigravityLoginPollResponse, AntigravityLoginStartResponse, - AntigravityLoginStatus, AntigravityTokenRecord, -}; - -const AUTH_CODE_TIMEOUT: Duration = Duration::from_secs(300); -const POLL_INTERVAL_SECONDS: u64 = 2; -const CALLBACK_PORT: u16 = 51121; - -#[derive(Clone)] -pub(crate) struct AntigravityLoginManager { - store: Arc, - sessions: Arc>>, - app_proxy: AppProxyState, -} - -#[derive(Clone)] -struct LoginSession { - status: AntigravityLoginStatus, - error: Option, - account: Option, - expires_at: Option, -} - -impl AntigravityLoginManager { - pub(crate) fn new(store: Arc, app_proxy: AppProxyState) -> Self { - Self { - store, - sessions: Arc::new(RwLock::new(HashMap::new())), - app_proxy, - } - } - - pub(crate) async fn start_login(&self) -> Result { - let state = generate_state("antigravity")?; - let expires_at = Some(OffsetDateTime::now_utc() + time::Duration::seconds(300)); - self.insert_session(&state, expires_at).await; - let callback = start_auth_code_callback(state.clone()).await?; - let login_url = AntigravityOAuthClient::build_authorize_url(&callback.redirect_uri, &state); - let manager = self.clone(); - let state_for_task = state.clone(); - tauri::async_runtime::spawn(async move { - run_auth_code_login(manager, state_for_task, callback).await; - }); - Ok(AntigravityLoginStartResponse { - state, - login_url, - interval_seconds: POLL_INTERVAL_SECONDS, - expires_at: Some(expires_at_from_seconds(AUTH_CODE_TIMEOUT.as_secs() as i64)), - }) - } - - pub(crate) async fn poll_login( - &self, - state: &str, - ) -> Result { - let mut guard = self.sessions.write().await; - let session = guard - .get_mut(state) - .ok_or_else(|| "Login session not found.".to_string())?; - if session.status != AntigravityLoginStatus::Success - && session.status != AntigravityLoginStatus::Error - && session - .expires_at - .map(|deadline| OffsetDateTime::now_utc() > deadline) - .unwrap_or(false) - { - session.status = AntigravityLoginStatus::Error; - session.error = Some("Login expired.".to_string()); - } - Ok(AntigravityLoginPollResponse { - state: state.to_string(), - status: session.status.clone(), - error: session.error.clone(), - account: session.account.clone(), - }) - } - - pub(crate) async fn logout(&self, account_id: &str) -> Result<(), String> { - self.store.delete_account(account_id).await - } - - async fn insert_session(&self, state: &str, expires_at: Option) { - let session = LoginSession { - status: AntigravityLoginStatus::Waiting, - error: None, - account: None, - expires_at, - }; - let mut guard = self.sessions.write().await; - guard.insert(state.to_string(), session); - } - - async fn complete_session(&self, state: &str, account: AntigravityAccountSummary) { - let mut guard = self.sessions.write().await; - if let Some(session) = guard.get_mut(state) { - session.status = AntigravityLoginStatus::Success; - session.error = None; - session.account = Some(account); - } - } - - async fn fail_session(&self, state: &str, message: String) { - let mut guard = self.sessions.write().await; - if let Some(session) = guard.get_mut(state) { - session.status = AntigravityLoginStatus::Error; - session.error = Some(message); - } - } - - async fn app_proxy_url(&self) -> Option { - self.app_proxy.read().await.clone() - } -} - -struct AuthCodeCallback { - redirect_uri: String, - receiver: tokio::sync::mpsc::Receiver, - shutdown: Option>, -} - -#[derive(Clone)] -struct AuthCodeResult { - code: Option, - state: Option, - error: Option, -} - -async fn start_auth_code_callback(state: String) -> Result { - let (tx, rx) = tokio::sync::mpsc::channel::(1); - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); - let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{CALLBACK_PORT}")) - .await - .map_err(|err| format!("Failed to start callback server: {err}"))?; - let redirect_uri = format!("http://localhost:{CALLBACK_PORT}/oauth-callback"); - let router = axum::Router::new().route( - "/oauth-callback", - axum::routing::get(move |query: axum::extract::Query>| { - let expected_state = state.clone(); - let tx = tx.clone(); - async move { - let code = query.get("code").cloned(); - let state = query.get("state").cloned(); - let error = query.get("error").cloned(); - let has_error = error.is_some(); - let state_matches = state.as_deref() == Some(&expected_state); - let _ = tx.send(AuthCodeResult { code, state, error }).await; - let body = if has_error || !state_matches { - "Login failed. You can close this window." - } else { - "Login successful. You can close this window." - }; - axum::response::Html(body) - } - }), - ); - tauri::async_runtime::spawn(async move { - let _ = axum::serve(listener, router) - .with_graceful_shutdown(async move { - let _ = shutdown_rx.await; - }) - .await; - }); - Ok(AuthCodeCallback { - redirect_uri, - receiver: rx, - shutdown: Some(shutdown_tx), - }) -} - -async fn run_auth_code_login( - manager: AntigravityLoginManager, - state: String, - mut callback: AuthCodeCallback, -) { - let redirect_uri = callback.redirect_uri.clone(); - let callback_result = match wait_for_auth_code(&mut callback).await { - Ok(result) => result, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let code = match extract_auth_code(&state, callback_result) { - Ok(code) => code, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let proxy_url = manager.app_proxy_url().await; - let client = AntigravityOAuthClient::new(proxy_url.clone()); - let token = match client.exchange_code(&code, &redirect_uri).await { - Ok(token) => token, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let email = match client.fetch_user_email(&token.access_token).await { - Ok(email) => email, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - - let project_id = match project::load_code_assist(&token.access_token, proxy_url.as_deref()).await - { - Ok(info) => { - let mut project_id = info.project_id.clone(); - if project_id.is_none() { - if let Some(tier_id) = info.plan_type.as_deref() { - match project::onboard_user( - &token.access_token, - proxy_url.as_deref(), - tier_id, - ) - .await - { - Ok(Some(value)) => project_id = Some(value), - Ok(None) => {} - Err(err) => { - tracing::warn!(error = %err, "antigravity onboardUser failed"); - } - } - } - } - project_id - } - Err(err) => { - tracing::warn!(error = %err, "antigravity loadCodeAssist failed"); - None - } - }; - - let record = AntigravityTokenRecord { - access_token: token.access_token.clone(), - refresh_token: token.refresh_token.clone(), - expired: Some(expires_at_from_seconds(token.expires_in)), - expires_in: Some(token.expires_in), - timestamp: Some(OffsetDateTime::now_utc().unix_timestamp() * 1000), - email: email.clone(), - token_type: token.token_type.clone(), - project_id, - source: Some("oauth".to_string()), - }; - let summary = match manager.store.save_new_account(record).await { - Ok(summary) => summary, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - manager.complete_session(&state, summary).await; -} - -async fn wait_for_auth_code(callback: &mut AuthCodeCallback) -> Result { - let timeout = tokio::time::sleep(AUTH_CODE_TIMEOUT); - tokio::pin!(timeout); - tokio::select! { - _ = &mut timeout => Err("Login timed out.".to_string()), - result = callback.receiver.recv() => { - let _ = callback.shutdown.take().map(|sender| sender.send(())); - result.ok_or_else(|| "Login failed.".to_string()) - } - } -} - -fn extract_auth_code(state: &str, result: AuthCodeResult) -> Result { - if result.error.is_some() { - return Err("Login failed.".to_string()); - } - if result.state.as_deref() != Some(state) { - return Err("Login failed: state mismatch.".to_string()); - } - let code = result.code.unwrap_or_default(); - if code.trim().is_empty() { - return Err("Login failed: code missing.".to_string()); - } - Ok(code) -} diff --git a/src-tauri/src/antigravity/oauth.rs b/src-tauri/src/antigravity/oauth.rs deleted file mode 100644 index 4f793ff..0000000 --- a/src-tauri/src/antigravity/oauth.rs +++ /dev/null @@ -1,141 +0,0 @@ -use reqwest::header::CONTENT_TYPE; -use serde::Deserialize; -use std::collections::HashMap; -use std::time::Duration; - -use crate::oauth_util::build_reqwest_client; - -const AUTH_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth"; -const TOKEN_URL: &str = "https://oauth2.googleapis.com/token"; -const USERINFO_URL: &str = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json"; - -const CLIENT_ID: &str = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"; -const CLIENT_SECRET: &str = "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"; - -const DEFAULT_TIMEOUT_SECS: u64 = 20; - -const SCOPES: [&str; 5] = [ - "https://www.googleapis.com/auth/cloud-platform", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - "https://www.googleapis.com/auth/cclog", - "https://www.googleapis.com/auth/experimentsandconfigs", -]; - -#[derive(Clone, Debug, Deserialize)] -pub(crate) struct AntigravityTokenResponse { - pub(crate) access_token: String, - pub(crate) refresh_token: Option, - pub(crate) expires_in: i64, - pub(crate) token_type: Option, -} - -#[derive(Clone, Debug, Deserialize)] -struct UserInfoResponse { - email: Option, -} - -pub(crate) struct AntigravityOAuthClient { - proxy_url: Option, -} - -impl AntigravityOAuthClient { - pub(crate) fn new(proxy_url: Option) -> Self { - Self { proxy_url } - } - - pub(crate) fn build_authorize_url(redirect_uri: &str, state: &str) -> String { - let mut params = HashMap::new(); - params.insert("access_type", "offline"); - params.insert("client_id", CLIENT_ID); - params.insert("prompt", "consent"); - params.insert("redirect_uri", redirect_uri); - params.insert("response_type", "code"); - let scope = SCOPES.join(" "); - params.insert("scope", scope.as_str()); - params.insert("state", state); - let query = serde_urlencoded::to_string(params).unwrap_or_default(); - format!("{AUTH_URL}?{query}") - } - - pub(crate) async fn exchange_code( - &self, - code: &str, - redirect_uri: &str, - ) -> Result { - let client = build_reqwest_client(self.proxy_url.as_deref(), Duration::from_secs(DEFAULT_TIMEOUT_SECS))?; - let mut params = HashMap::new(); - params.insert("code", code); - params.insert("client_id", CLIENT_ID); - params.insert("client_secret", CLIENT_SECRET); - params.insert("redirect_uri", redirect_uri); - params.insert("grant_type", "authorization_code"); - let response = client - .post(TOKEN_URL) - .header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .form(¶ms) - .send() - .await - .map_err(|err| format!("Token exchange failed: {err}"))?; - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - return Err(format!("Token exchange failed: {status} {body}")); - } - response - .json::() - .await - .map_err(|err| format!("Failed to parse token response: {err}")) - } - - pub(crate) async fn refresh_token( - &self, - refresh_token: &str, - ) -> Result { - let client = build_reqwest_client(self.proxy_url.as_deref(), Duration::from_secs(DEFAULT_TIMEOUT_SECS))?; - let mut params = HashMap::new(); - params.insert("client_id", CLIENT_ID); - params.insert("client_secret", CLIENT_SECRET); - params.insert("refresh_token", refresh_token); - params.insert("grant_type", "refresh_token"); - let response = client - .post(TOKEN_URL) - .header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .form(¶ms) - .send() - .await - .map_err(|err| format!("Token refresh failed: {err}"))?; - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - return Err(format!("Token refresh failed: {status} {body}")); - } - response - .json::() - .await - .map_err(|err| format!("Failed to parse refresh response: {err}")) - } - - pub(crate) async fn fetch_user_email(&self, access_token: &str) -> Result, String> { - if access_token.trim().is_empty() { - return Ok(None); - } - let client = build_reqwest_client(self.proxy_url.as_deref(), Duration::from_secs(DEFAULT_TIMEOUT_SECS))?; - let response = client - .get(USERINFO_URL) - .bearer_auth(access_token) - .send() - .await - .map_err(|err| format!("Failed to fetch user info: {err}"))?; - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - return Err(format!("User info request failed: {status} {body}")); - } - let payload = response - .json::() - .await - .map_err(|err| format!("Failed to parse user info: {err}"))?; - Ok(payload.email.map(|value| value.trim().to_string()).filter(|value| !value.is_empty())) - } -} diff --git a/src-tauri/src/antigravity/project.rs b/src-tauri/src/antigravity/project.rs deleted file mode 100644 index 7e38af3..0000000 --- a/src-tauri/src/antigravity/project.rs +++ /dev/null @@ -1,227 +0,0 @@ -use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT}; -use serde_json::Value; -use std::time::Duration; - -use super::endpoints; -use crate::oauth_util::build_reqwest_client; - -const LOAD_CODE_ASSIST_PATH: &str = "/v1internal:loadCodeAssist"; -const ONBOARD_USER_PATH: &str = "/v1internal:onboardUser"; - -const API_USER_AGENT: &str = "google-api-nodejs-client/9.15.1"; -const API_CLIENT: &str = "google-cloud-sdk vscode_cloudshelleditor/0.1"; -const CLIENT_METADATA: &str = - r#"{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}"#; - -const MAX_ONBOARD_ATTEMPTS: usize = 5; -const ONBOARD_POLL_DELAY_SECS: u64 = 2; - -#[derive(Clone, Default)] -pub(crate) struct AntigravityProjectInfo { - pub(crate) project_id: Option, - pub(crate) plan_type: Option, -} - -pub(crate) async fn load_code_assist( - access_token: &str, - proxy_url: Option<&str>, -) -> Result { - let client = build_reqwest_client(proxy_url, Duration::from_secs(20))?; - load_code_assist_with_client(&client, access_token).await -} - -pub(crate) async fn load_code_assist_with_client( - client: &reqwest::Client, - access_token: &str, -) -> Result { - let payload = serde_json::json!({ - "metadata": { - "ideType": "ANTIGRAVITY", - "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI" - } - }); - let mut last_error: Option = None; - for base in endpoints::BASE_URLS { - let url = format!("{}{}", base, LOAD_CODE_ASSIST_PATH); - let response = client - .post(url) - .header(AUTHORIZATION, format!("Bearer {access_token}")) - .header(USER_AGENT, API_USER_AGENT) - .header("X-Goog-Api-Client", API_CLIENT) - .header("Client-Metadata", CLIENT_METADATA) - .header(CONTENT_TYPE, "application/json") - .json(&payload) - .send() - .await; - let response = match response { - Ok(response) => response, - Err(err) => { - last_error = Some(format!("loadCodeAssist failed: {err}")); - continue; - } - }; - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - let message = format!("loadCodeAssist failed: {status} {body}"); - if should_retry_status(status) { - last_error = Some(message); - continue; - } - return Err(message); - } - let value: Value = response - .json() - .await - .map_err(|err| format!("loadCodeAssist parse failed: {err}"))?; - return Ok(AntigravityProjectInfo { - project_id: extract_project_id(&value), - plan_type: extract_plan_type(&value), - }); - } - Err(last_error.unwrap_or_else(|| "loadCodeAssist failed.".to_string())) -} - -pub(crate) async fn onboard_user( - access_token: &str, - proxy_url: Option<&str>, - tier_id: &str, -) -> Result, String> { - let client = build_reqwest_client(proxy_url, Duration::from_secs(30))?; - onboard_user_with_client(&client, access_token, tier_id).await -} - -pub(crate) async fn onboard_user_with_client( - client: &reqwest::Client, - access_token: &str, - tier_id: &str, -) -> Result, String> { - if tier_id.trim().is_empty() { - return Ok(None); - } - let payload = serde_json::json!({ - "tierId": tier_id, - "metadata": { - "ideType": "ANTIGRAVITY", - "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI" - } - }); - for _ in 0..MAX_ONBOARD_ATTEMPTS { - let mut last_error: Option = None; - for base in endpoints::BASE_URLS { - let url = format!("{}{}", base, ONBOARD_USER_PATH); - let response = client - .post(url) - .header(AUTHORIZATION, format!("Bearer {access_token}")) - .header(USER_AGENT, API_USER_AGENT) - .header("X-Goog-Api-Client", API_CLIENT) - .header("Client-Metadata", CLIENT_METADATA) - .header(CONTENT_TYPE, "application/json") - .json(&payload) - .send() - .await; - let response = match response { - Ok(response) => response, - Err(err) => { - last_error = Some(format!("onboardUser failed: {err}")); - continue; - } - }; - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - let message = format!("onboardUser failed: {status} {body}"); - if should_retry_status(status) { - last_error = Some(message); - continue; - } - return Err(message); - } - let value: Value = response - .json() - .await - .map_err(|err| format!("onboardUser parse failed: {err}"))?; - if value.get("done").and_then(Value::as_bool).unwrap_or(false) { - return Ok(extract_onboard_project_id(&value)); - } - } - if let Some(message) = last_error { - tracing::warn!(error = %message, "antigravity onboardUser retrying"); - } - tokio::time::sleep(Duration::from_secs(ONBOARD_POLL_DELAY_SECS)).await; - } - Ok(None) -} - -pub(crate) fn extract_project_id(value: &Value) -> Option { - if let Some(project) = value.get("cloudaicompanionProject") { - if let Some(text) = project.as_str() { - let trimmed = text.trim(); - if !trimmed.is_empty() { - return Some(trimmed.to_string()); - } - } - if let Some(obj) = project.as_object() { - if let Some(text) = obj.get("id").and_then(Value::as_str) { - let trimmed = text.trim(); - if !trimmed.is_empty() { - return Some(trimmed.to_string()); - } - } - } - } - None -} - -pub(crate) fn extract_plan_type(value: &Value) -> Option { - let tiers = value.get("allowedTiers")?.as_array()?; - for tier in tiers { - let obj = tier.as_object()?; - let is_default = obj - .get("isDefault") - .and_then(Value::as_bool) - .unwrap_or(false); - if !is_default { - continue; - } - if let Some(id) = obj.get("id").and_then(Value::as_str) { - let trimmed = id.trim(); - if !trimmed.is_empty() { - return Some(trimmed.to_string()); - } - } - } - None -} - -fn extract_onboard_project_id(value: &Value) -> Option { - let response = value.get("response")?; - match response.get("cloudaicompanionProject") { - Some(Value::String(text)) => { - let trimmed = text.trim(); - if trimmed.is_empty() { - None - } else { - Some(trimmed.to_string()) - } - } - Some(Value::Object(obj)) => obj - .get("id") - .and_then(Value::as_str) - .map(str::trim) - .filter(|value| !value.is_empty()) - .map(str::to_string), - _ => None, - } -} - -fn should_retry_status(status: reqwest::StatusCode) -> bool { - status == reqwest::StatusCode::TOO_MANY_REQUESTS || status.is_server_error() -} - -// 单元测试拆到独立文件,使用 `#[path]` 以保持 `.test.rs` 命名约定。 -#[cfg(test)] -#[path = "project.test.rs"] -mod tests; diff --git a/src-tauri/src/antigravity/project.test.rs b/src-tauri/src/antigravity/project.test.rs deleted file mode 100644 index c557bf7..0000000 --- a/src-tauri/src/antigravity/project.test.rs +++ /dev/null @@ -1,22 +0,0 @@ -use super::{extract_plan_type, extract_project_id}; -use serde_json::json; - -#[test] -fn extract_project_id_supports_string_and_object() { - let value = json!({ "cloudaicompanionProject": "project-123" }); - assert_eq!(extract_project_id(&value), Some("project-123".to_string())); - - let value = json!({ "cloudaicompanionProject": { "id": "project-456" } }); - assert_eq!(extract_project_id(&value), Some("project-456".to_string())); -} - -#[test] -fn extract_plan_type_picks_default_allowed_tier() { - let value = json!({ - "allowedTiers": [ - { "id": "FREE", "isDefault": true }, - { "id": "PRO", "isDefault": false } - ] - }); - assert_eq!(extract_plan_type(&value), Some("FREE".to_string())); -} diff --git a/src-tauri/src/antigravity/protobuf.rs b/src-tauri/src/antigravity/protobuf.rs deleted file mode 100644 index 608d4f4..0000000 --- a/src-tauri/src/antigravity/protobuf.rs +++ /dev/null @@ -1,281 +0,0 @@ -use base64::engine::general_purpose::STANDARD; -use base64::Engine; -use time::format_description::well_known::Rfc3339; -use time::OffsetDateTime; - -use super::types::AntigravityTokenRecord; - -const FIELD_OAUTH: u64 = 6; -const FIELD_ACCESS_TOKEN: u64 = 1; -const FIELD_TOKEN_TYPE: u64 = 2; -const FIELD_REFRESH_TOKEN: u64 = 3; -const FIELD_EXPIRES_AT: u64 = 4; -const FIELD_TIMESTAMP_SECONDS: u64 = 1; - -pub(crate) fn extract_token_record( - base64_state: &str, -) -> Result, String> { - let bytes = decode_base64(base64_state)?; - let mut pos = 0usize; - while pos < bytes.len() { - let tag_start = pos; - let tag = read_varint(&bytes, &mut pos).ok_or_else(|| "Invalid protobuf tag.".to_string())?; - let field_number = tag >> 3; - let wire_type = (tag & 0x07) as u8; - let field_end = skip_field(&bytes, pos, wire_type)?; - if field_number == FIELD_OAUTH && wire_type == 2 { - let length = read_varint(&bytes, &mut pos) - .ok_or_else(|| "Invalid protobuf length.".to_string())? as usize; - let end = pos + length; - if end > bytes.len() { - return Err("Invalid protobuf length.".to_string()); - } - let record = parse_oauth_message(&bytes[pos..end])?; - if record.is_some() { - return Ok(record); - } - pos = end; - } else { - pos = field_end; - } - if pos == tag_start { - break; - } - } - Ok(None) -} - -pub(crate) fn inject_token_record( - base64_state: &str, - record: &AntigravityTokenRecord, -) -> Result { - let bytes = decode_base64(base64_state)?; - let mut output = Vec::with_capacity(bytes.len() + 128); - let mut pos = 0usize; - while pos < bytes.len() { - let field_start = pos; - let tag = read_varint(&bytes, &mut pos).ok_or_else(|| "Invalid protobuf tag.".to_string())?; - let field_number = tag >> 3; - let wire_type = (tag & 0x07) as u8; - let field_end = skip_field(&bytes, pos, wire_type)?; - if field_number != FIELD_OAUTH { - output.extend_from_slice(&bytes[field_start..field_end]); - } - pos = field_end; - } - - let oauth_payload = build_oauth_message(record)?; - let mut field_bytes = Vec::with_capacity(1 + oauth_payload.len() + 10); - field_bytes.extend(encode_varint((FIELD_OAUTH << 3) | 2)); - field_bytes.extend(encode_varint(oauth_payload.len() as u64)); - field_bytes.extend(oauth_payload); - output.extend(field_bytes); - - Ok(STANDARD.encode(output)) -} - -fn parse_oauth_message(data: &[u8]) -> Result, String> { - let mut pos = 0usize; - let mut access_token: Option = None; - let mut refresh_token: Option = None; - let mut token_type: Option = None; - let mut expires_seconds: Option = None; - - while pos < data.len() { - let tag = read_varint(data, &mut pos).ok_or_else(|| "Invalid oauth tag.".to_string())?; - let field_number = tag >> 3; - let wire_type = (tag & 0x07) as u8; - match field_number { - FIELD_ACCESS_TOKEN if wire_type == 2 => { - let text = read_length_delimited_string(data, &mut pos)?; - if !text.trim().is_empty() { - access_token = Some(text); - } - } - FIELD_TOKEN_TYPE if wire_type == 2 => { - let text = read_length_delimited_string(data, &mut pos)?; - if !text.trim().is_empty() { - token_type = Some(text); - } - } - FIELD_REFRESH_TOKEN if wire_type == 2 => { - let text = read_length_delimited_string(data, &mut pos)?; - if !text.trim().is_empty() { - refresh_token = Some(text); - } - } - FIELD_EXPIRES_AT if wire_type == 2 => { - let length = read_varint(data, &mut pos) - .ok_or_else(|| "Invalid expiry length.".to_string())? as usize; - let end = pos + length; - if end > data.len() { - return Err("Invalid expiry length.".to_string()); - } - expires_seconds = parse_timestamp_seconds(&data[pos..end])?; - pos = end; - } - _ => { - pos = skip_field(data, pos, wire_type)?; - } - } - } - - let access_token = match access_token { - Some(value) => value, - None => return Ok(None), - }; - let now = OffsetDateTime::now_utc(); - let expires_at = expires_seconds - .and_then(|seconds| OffsetDateTime::from_unix_timestamp(seconds).ok()); - let expired = expires_at - .and_then(|value| value.format(&Rfc3339).ok()) - .filter(|value| !value.is_empty()); - let expires_in = expires_at.map(|value| (value - now).whole_seconds()); - - Ok(Some(AntigravityTokenRecord { - access_token, - refresh_token, - expired, - expires_in, - timestamp: Some(now.unix_timestamp() * 1000), - email: None, - token_type, - project_id: None, - source: Some("ide".to_string()), - })) -} - -fn build_oauth_message(record: &AntigravityTokenRecord) -> Result, String> { - let mut output = Vec::new(); - push_length_delimited(&mut output, FIELD_ACCESS_TOKEN, &record.access_token)?; - if let Some(token_type) = record.token_type.as_deref().filter(|value| !value.trim().is_empty()) { - push_length_delimited(&mut output, FIELD_TOKEN_TYPE, token_type)?; - } - if let Some(refresh) = record - .refresh_token - .as_deref() - .filter(|value| !value.trim().is_empty()) - { - push_length_delimited(&mut output, FIELD_REFRESH_TOKEN, refresh)?; - } - if let Some(expires_at) = record.expires_at() { - let seconds = expires_at.unix_timestamp(); - let mut ts = Vec::new(); - ts.extend(encode_varint((FIELD_TIMESTAMP_SECONDS << 3) | 0)); - ts.extend(encode_varint(seconds as u64)); - output.extend(encode_varint((FIELD_EXPIRES_AT << 3) | 2)); - output.extend(encode_varint(ts.len() as u64)); - output.extend(ts); - } - Ok(output) -} - -fn decode_base64(value: &str) -> Result, String> { - STANDARD - .decode(value.trim()) - .map_err(|_| "Invalid base64 state payload.".to_string()) -} - -fn read_varint(bytes: &[u8], pos: &mut usize) -> Option { - let mut shift = 0u32; - let mut output = 0u64; - while *pos < bytes.len() { - let byte = bytes[*pos]; - *pos += 1; - output |= ((byte & 0x7f) as u64) << shift; - if byte & 0x80 == 0 { - return Some(output); - } - shift += 7; - if shift > 63 { - return None; - } - } - None -} - -fn encode_varint(mut value: u64) -> Vec { - let mut bytes = Vec::new(); - loop { - let mut byte = (value & 0x7f) as u8; - value >>= 7; - if value != 0 { - byte |= 0x80; - bytes.push(byte); - } else { - bytes.push(byte); - break; - } - } - bytes -} - -fn skip_field(bytes: &[u8], pos: usize, wire_type: u8) -> Result { - let mut cursor = pos; - match wire_type { - 0 => { - let _ = read_varint(bytes, &mut cursor) - .ok_or_else(|| "Invalid protobuf varint.".to_string())?; - Ok(cursor) - } - 1 => Ok(cursor + 8), - 2 => { - let length = read_varint(bytes, &mut cursor) - .ok_or_else(|| "Invalid protobuf length.".to_string())? as usize; - let end = cursor + length; - if end > bytes.len() { - return Err("Invalid protobuf length.".to_string()); - } - Ok(end) - } - 5 => Ok(cursor + 4), - _ => Err("Unsupported protobuf wire type.".to_string()), - } -} - -fn read_length_delimited_string(bytes: &[u8], pos: &mut usize) -> Result { - let length = read_varint(bytes, pos).ok_or_else(|| "Invalid string length.".to_string())? as usize; - let end = *pos + length; - if end > bytes.len() { - return Err("Invalid string length.".to_string()); - } - let value = String::from_utf8_lossy(&bytes[*pos..end]).to_string(); - *pos = end; - Ok(value) -} - -fn parse_timestamp_seconds(data: &[u8]) -> Result, String> { - let mut pos = 0usize; - while pos < data.len() { - let tag = read_varint(data, &mut pos).ok_or_else(|| "Invalid timestamp tag.".to_string())?; - let field_number = tag >> 3; - let wire_type = (tag & 0x07) as u8; - if field_number == FIELD_TIMESTAMP_SECONDS && wire_type == 0 { - let seconds = read_varint(data, &mut pos) - .ok_or_else(|| "Invalid timestamp varint.".to_string())?; - return Ok(Some(seconds as i64)); - } - pos = skip_field(data, pos, wire_type)?; - } - Ok(None) -} - -fn push_length_delimited( - out: &mut Vec, - field_number: u64, - value: &str, -) -> Result<(), String> { - let trimmed = value.trim(); - if trimmed.is_empty() { - return Ok(()); - } - out.extend(encode_varint((field_number << 3) | 2)); - out.extend(encode_varint(trimmed.len() as u64)); - out.extend(trimmed.as_bytes()); - Ok(()) -} - -// 单元测试拆到独立文件,使用 `#[path]` 以保持 `.test.rs` 命名约定。 -#[cfg(test)] -#[path = "protobuf.test.rs"] -mod tests; diff --git a/src-tauri/src/antigravity/protobuf.test.rs b/src-tauri/src/antigravity/protobuf.test.rs deleted file mode 100644 index 6e6eaa0..0000000 --- a/src-tauri/src/antigravity/protobuf.test.rs +++ /dev/null @@ -1,37 +0,0 @@ -use super::extract_token_record; -use super::inject_token_record; -use super::AntigravityTokenRecord; -use time::format_description::well_known::Rfc3339; -use time::OffsetDateTime; - -#[test] -fn extract_empty_returns_none() { - let result = extract_token_record("").expect("empty base64 should be valid"); - assert!(result.is_none()); -} - -#[test] -fn inject_and_extract_roundtrip() { - let expires_at = OffsetDateTime::from_unix_timestamp(1_700_000_000).expect("timestamp"); - let expires_at_text = expires_at.format(&Rfc3339).expect("format"); - let record = AntigravityTokenRecord { - access_token: "ya29.test-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expired: Some(expires_at_text.clone()), - expires_in: None, - timestamp: None, - email: None, - token_type: Some("Bearer".to_string()), - project_id: None, - source: None, - }; - - let encoded = inject_token_record("", &record).expect("inject"); - let extracted = extract_token_record(&encoded).expect("extract").expect("record"); - - assert_eq!(extracted.access_token, record.access_token); - assert_eq!(extracted.refresh_token, record.refresh_token); - assert_eq!(extracted.token_type, record.token_type); - assert_eq!(extracted.expired, Some(expires_at_text)); - assert_eq!(extracted.source.as_deref(), Some("ide")); -} diff --git a/src-tauri/src/antigravity/quota.rs b/src-tauri/src/antigravity/quota.rs deleted file mode 100644 index 804c746..0000000 --- a/src-tauri/src/antigravity/quota.rs +++ /dev/null @@ -1,195 +0,0 @@ -use serde_json::Value; -use std::time::Duration; - -use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT}; - -use crate::oauth_util::build_reqwest_client; - -use super::endpoints; -use super::project; -use super::store::AntigravityAccountStore; -use super::types::{AntigravityAccountSummary, AntigravityQuotaItem, AntigravityQuotaSummary}; - -const FETCH_MODELS_PATH: &str = "/v1internal:fetchAvailableModels"; - -pub(crate) async fn fetch_quotas( - store: &AntigravityAccountStore, -) -> Result, String> { - let accounts = store.list_accounts().await?; - let proxy_url = store.app_proxy_url().await; - let mut results = Vec::with_capacity(accounts.len()); - for account in accounts { - match fetch_account_quota(store, &account, proxy_url.as_deref()).await { - Ok(summary) => results.push(summary), - Err(err) => results.push(AntigravityQuotaSummary { - account_id: account.account_id.clone(), - plan_type: None, - quotas: Vec::new(), - error: Some(err), - }), - } - } - Ok(results) -} - -async fn fetch_account_quota( - store: &AntigravityAccountStore, - account: &AntigravityAccountSummary, - proxy_url: Option<&str>, -) -> Result { - let record = store.get_account_record(&account.account_id).await?; - let client = build_reqwest_client(proxy_url, Duration::from_secs(20))?; - let mut project_id = record.project_id.clone(); - let mut plan_type: Option = None; - let mut load_error: Option = None; - match project::load_code_assist_with_client(&client, &record.access_token).await { - Ok(info) => { - plan_type = info.plan_type.clone(); - if let Some(value) = info.project_id.clone() { - project_id = Some(value.clone()); - let _ = store.update_project_id(&account.account_id, value).await; - } else if let Some(tier_id) = info.plan_type.as_deref() { - match project::onboard_user_with_client(&client, &record.access_token, tier_id).await { - Ok(Some(value)) => { - project_id = Some(value.clone()); - let _ = store.update_project_id(&account.account_id, value).await; - } - Ok(None) => {} - Err(err) => load_error = Some(err), - } - } - } - Err(err) => load_error = Some(err), - } - let quotas = match fetch_available_models( - &client, - &record.access_token, - project_id.as_deref(), - ) - .await - { - Ok(quotas) => quotas, - Err(err) => { - if let Some(load_error) = load_error { - return Err(format!("{load_error}; {err}")); - } - return Err(err); - } - }; - if let Some(load_error) = load_error { - tracing::warn!(error = %load_error, "antigravity loadCodeAssist failed for quota"); - } - Ok(AntigravityQuotaSummary { - account_id: account.account_id.clone(), - plan_type, - quotas, - error: None, - }) -} - -async fn fetch_available_models( - client: &reqwest::Client, - access_token: &str, - project_id: Option<&str>, -) -> Result, String> { - let user_agent = endpoints::default_user_agent(); - let payload = if let Some(project_id) = project_id.filter(|value| !value.trim().is_empty()) { - serde_json::json!({ "project": project_id }) - } else { - serde_json::json!({}) - }; - let mut last_error: Option = None; - for base in endpoints::BASE_URLS { - let url = format!("{}{}", base, FETCH_MODELS_PATH); - let response = client - .post(url) - .header(AUTHORIZATION, format!("Bearer {access_token}")) - .header(USER_AGENT, user_agent.as_str()) - .header(CONTENT_TYPE, "application/json") - .json(&payload) - .send() - .await; - let response = match response { - Ok(response) => response, - Err(err) => { - last_error = Some(format!("fetchAvailableModels failed: {err}")); - continue; - } - }; - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - let message = format!("fetchAvailableModels failed: {status} {body}"); - if should_retry_status(status) { - last_error = Some(message); - continue; - } - return Err(message); - } - let value: Value = response - .json() - .await - .map_err(|err| format!("fetchAvailableModels parse failed: {err}"))?; - return Ok(extract_quota_items(&value)); - } - Err(last_error.unwrap_or_else(|| "fetchAvailableModels failed.".to_string())) -} - -fn extract_quota_items(value: &Value) -> Vec { - let mut items = Vec::new(); - let Some(models) = value.get("models") else { - return items; - }; - match models { - Value::Array(list) => { - for model in list { - let Some(name) = model.get("name").and_then(Value::as_str) else { - continue; - }; - if let Some(item) = build_quota_item(name, model.get("quotaInfo")) { - items.push(item); - } - } - } - Value::Object(map) => { - for (name, info) in map { - if let Some(item) = build_quota_item(name, info.get("quotaInfo")) { - items.push(item); - } - } - } - _ => {} - } - items -} - -fn build_quota_item(name: &str, quota_info: Option<&Value>) -> Option { - let name_lower = name.to_lowercase(); - if !name_lower.contains("gemini") && !name_lower.contains("claude") { - return None; - } - let quota = quota_info?.as_object()?; - let remaining_fraction = quota - .get("remainingFraction") - .and_then(Value::as_f64) - .unwrap_or(0.0); - let percentage = (remaining_fraction * 100.0).clamp(0.0, 100.0); - let reset_at = quota - .get("resetTime") - .and_then(Value::as_str) - .map(|value| value.to_string()); - Some(AntigravityQuotaItem { - name: name.to_string(), - percentage, - reset_at, - }) -} - -fn should_retry_status(status: reqwest::StatusCode) -> bool { - status == reqwest::StatusCode::TOO_MANY_REQUESTS || status.is_server_error() -} - -// 单元测试拆到独立文件,使用 `#[path]` 以保持 `.test.rs` 命名约定。 -#[cfg(test)] -#[path = "quota.test.rs"] -mod tests; diff --git a/src-tauri/src/antigravity/quota.test.rs b/src-tauri/src/antigravity/quota.test.rs deleted file mode 100644 index 8aa69cb..0000000 --- a/src-tauri/src/antigravity/quota.test.rs +++ /dev/null @@ -1,32 +0,0 @@ -use super::extract_quota_items; -use serde_json::json; - -#[test] -fn extract_quota_items_from_array_models() { - let value = json!({ - "models": [ - { - "name": "gemini-3-pro", - "quotaInfo": { "remainingFraction": 0.5, "resetTime": "2026-01-01T00:00:00Z" } - } - ] - }); - let items = extract_quota_items(&value); - assert_eq!(items.len(), 1); - assert_eq!(items[0].name, "gemini-3-pro"); - assert_eq!(items[0].percentage, 50.0); -} - -#[test] -fn extract_quota_items_from_map_models() { - let value = json!({ - "models": { - "claude-opus-4": { "quotaInfo": { "remainingFraction": 0.25 } }, - "text-embedding": { "quotaInfo": { "remainingFraction": 0.9 } } - } - }); - let items = extract_quota_items(&value); - assert_eq!(items.len(), 1); - assert_eq!(items[0].name, "claude-opus-4"); - assert_eq!(items[0].percentage, 25.0); -} diff --git a/src-tauri/src/antigravity/store.rs b/src-tauri/src/antigravity/store.rs deleted file mode 100644 index 8f72fe9..0000000 --- a/src-tauri/src/antigravity/store.rs +++ /dev/null @@ -1,265 +0,0 @@ -use std::collections::HashMap; -use std::path::PathBuf; - -use tauri::AppHandle; -use time::OffsetDateTime; -use tokio::sync::RwLock; - -use crate::app_proxy::AppProxyState; -use crate::oauth_util::{expires_at_from_seconds, sanitize_id_part}; -use crate::proxy::config::config_dir_path; - -use super::oauth::AntigravityOAuthClient; -use super::types::{ - AntigravityAccountSummary, AntigravityAccountStatus, AntigravityTokenRecord, -}; - -const ANTIGRAVITY_AUTH_DIR_NAME: &str = "antigravity-auth"; - -pub(crate) struct AntigravityAccountStore { - dir: PathBuf, - cache: RwLock>, - app_proxy: AppProxyState, -} - -impl AntigravityAccountStore { - pub(crate) fn new(app: &AppHandle, app_proxy: AppProxyState) -> Result { - let dir = config_dir_path(app)?.join(ANTIGRAVITY_AUTH_DIR_NAME); - Ok(Self { - dir, - cache: RwLock::new(HashMap::new()), - app_proxy, - }) - } - - pub(crate) async fn list_accounts(&self) -> Result, String> { - self.refresh_cache().await?; - let cache = self.cache.read().await; - let mut items: Vec = cache - .iter() - .map(|(account_id, record)| AntigravityAccountSummary { - account_id: account_id.clone(), - email: record.email.clone(), - expires_at: record.expires_at().map(|value| { - value - .format(&time::format_description::well_known::Rfc3339) - .unwrap_or_else(|_| record.expired.clone().unwrap_or_default()) - }), - status: record.status(), - source: record.source.clone(), - }) - .collect(); - items.sort_by(|left, right| left.account_id.cmp(&right.account_id)); - Ok(items) - } - - pub(crate) async fn get_account_record( - &self, - account_id: &str, - ) -> Result { - let record = self.load_account(account_id).await?; - self.refresh_if_needed(account_id, record).await - } - - pub(crate) async fn save_new_account( - &self, - record: AntigravityTokenRecord, - ) -> Result { - let id_part_source = record - .email - .as_deref() - .or(record.source.as_deref()) - .unwrap_or_default(); - let mut id_part = sanitize_id_part(id_part_source); - if id_part.is_empty() { - id_part = format!("{}", OffsetDateTime::now_utc().unix_timestamp()); - } - let account_id = self.unique_account_id(&id_part).await?; - self.save_record(account_id, record).await - } - - pub(crate) async fn save_record( - &self, - account_id: String, - record: AntigravityTokenRecord, - ) -> Result { - self.ensure_dir().await?; - let path = self.account_path(&account_id); - let payload = serde_json::to_string_pretty(&record) - .map_err(|err| format!("Failed to serialize token record: {err}"))?; - tokio::fs::write(&path, payload) - .await - .map_err(|err| format!("Failed to write token record: {err}"))?; - let mut cache = self.cache.write().await; - cache.insert(account_id.clone(), record.clone()); - Ok(AntigravityAccountSummary { - account_id, - email: record.email.clone(), - expires_at: record.expires_at().map(|value| { - value - .format(&time::format_description::well_known::Rfc3339) - .unwrap_or_else(|_| record.expired.clone().unwrap_or_default()) - }), - status: record.status(), - source: record.source.clone(), - }) - } - - pub(crate) async fn delete_account(&self, account_id: &str) -> Result<(), String> { - let path = self.account_path(account_id); - if tokio::fs::try_exists(&path).await.unwrap_or(false) { - tokio::fs::remove_file(&path) - .await - .map_err(|err| format!("Failed to delete token record: {err}"))?; - } - let mut cache = self.cache.write().await; - cache.remove(account_id); - Ok(()) - } - - pub(crate) async fn update_project_id( - &self, - account_id: &str, - project_id: String, - ) -> Result<(), String> { - let mut record = self.get_account_record(account_id).await?; - record.project_id = Some(project_id); - let _ = self.save_record(account_id.to_string(), record).await?; - Ok(()) - } - - async fn refresh_if_needed( - &self, - account_id: &str, - record: AntigravityTokenRecord, - ) -> Result { - if !record.is_expired() { - return Ok(record); - } - self.refresh_record(account_id, record).await - } - - async fn refresh_record( - &self, - account_id: &str, - record: AntigravityTokenRecord, - ) -> Result { - let refresh_token = record - .refresh_token - .as_deref() - .filter(|value| !value.trim().is_empty()) - .ok_or_else(|| "Antigravity refresh token is missing.".to_string())?; - let proxy_url = self.app_proxy_url().await; - let client = AntigravityOAuthClient::new(proxy_url); - let response = client.refresh_token(refresh_token).await?; - let refreshed = AntigravityTokenRecord { - access_token: response.access_token, - refresh_token: response - .refresh_token - .filter(|value| !value.trim().is_empty()) - .or(record.refresh_token.clone()), - expired: Some(expires_at_from_seconds(response.expires_in)), - expires_in: Some(response.expires_in), - timestamp: Some(OffsetDateTime::now_utc().unix_timestamp() * 1000), - email: record.email.clone(), - token_type: response.token_type.or(record.token_type.clone()), - project_id: record.project_id.clone(), - source: record.source.clone().or_else(|| Some("oauth".to_string())), - }; - let summary = self - .save_record(account_id.to_string(), refreshed.clone()) - .await?; - if matches!(summary.status, AntigravityAccountStatus::Expired) { - return Err("Antigravity token refresh failed.".to_string()); - } - Ok(refreshed) - } - - async fn load_account(&self, account_id: &str) -> Result { - if let Some(record) = self.cache.read().await.get(account_id).cloned() { - return Ok(record); - } - self.refresh_cache().await?; - self.cache - .read() - .await - .get(account_id) - .cloned() - .ok_or_else(|| format!("Antigravity account not found: {account_id}")) - } - - pub(crate) async fn app_proxy_url(&self) -> Option { - self.app_proxy.read().await.clone() - } - - async fn refresh_cache(&self) -> Result<(), String> { - let mut cache = HashMap::new(); - let dir = self.dir.clone(); - let mut entries = match tokio::fs::read_dir(&dir).await { - Ok(entries) => entries, - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - let mut guard = self.cache.write().await; - guard.clear(); - return Ok(()); - } - Err(err) => return Err(format!("Failed to read Antigravity auth directory: {err}")), - }; - - while let Some(entry) = entries - .next_entry() - .await - .map_err(|err| format!("Failed to read Antigravity auth entry: {err}"))? - { - let path = entry.path(); - if path.extension().and_then(|ext| ext.to_str()) != Some("json") { - continue; - } - let file_name = match path.file_name().and_then(|name| name.to_str()) { - Some(name) => name.to_string(), - None => continue, - }; - let contents = match tokio::fs::read_to_string(&path).await { - Ok(contents) => contents, - Err(_) => continue, - }; - let record: AntigravityTokenRecord = match serde_json::from_str(&contents) { - Ok(record) => record, - Err(_) => continue, - }; - cache.insert(file_name, record); - } - - let mut guard = self.cache.write().await; - *guard = cache; - Ok(()) - } - - async fn ensure_dir(&self) -> Result<(), String> { - tokio::fs::create_dir_all(&self.dir) - .await - .map_err(|err| format!("Failed to create Antigravity auth dir: {err}")) - } - - async fn unique_account_id(&self, id_part: &str) -> Result { - self.ensure_dir().await?; - let mut suffix = 0u32; - loop { - let candidate = if suffix == 0 { - format!("antigravity-{id_part}.json") - } else { - format!("antigravity-{id_part}-{suffix}.json") - }; - if !tokio::fs::try_exists(self.account_path(&candidate)) - .await - .unwrap_or(false) - { - return Ok(candidate); - } - suffix += 1; - } - } - - fn account_path(&self, account_id: &str) -> PathBuf { - self.dir.join(account_id) - } -} diff --git a/src-tauri/src/antigravity/types.rs b/src-tauri/src/antigravity/types.rs deleted file mode 100644 index 4d2174e..0000000 --- a/src-tauri/src/antigravity/types.rs +++ /dev/null @@ -1,137 +0,0 @@ -use serde::{Deserialize, Serialize}; -use time::format_description::well_known::Rfc3339; -use time::OffsetDateTime; - -#[derive(Clone, Serialize, Deserialize)] -pub(crate) struct AntigravityTokenRecord { - pub(crate) access_token: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub(crate) refresh_token: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub(crate) expired: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub(crate) expires_in: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub(crate) timestamp: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub(crate) email: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub(crate) token_type: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub(crate) project_id: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub(crate) source: Option, -} - -impl AntigravityTokenRecord { - pub(crate) fn expires_at(&self) -> Option { - if let Some(expired) = self.expired.as_deref() { - if let Ok(value) = OffsetDateTime::parse(expired.trim(), &Rfc3339) { - return Some(value); - } - } - let expires_in = self.expires_in.unwrap_or_default(); - let timestamp = self.timestamp.unwrap_or_default(); - if expires_in <= 0 || timestamp <= 0 { - return None; - } - let expires_at = (timestamp / 1000) + expires_in; - OffsetDateTime::from_unix_timestamp(expires_at).ok() - } - - pub(crate) fn is_expired(&self) -> bool { - let Some(expires_at) = self.expires_at() else { - return true; - }; - OffsetDateTime::now_utc() >= expires_at - } - - pub(crate) fn status(&self) -> AntigravityAccountStatus { - if self.is_expired() { - AntigravityAccountStatus::Expired - } else { - AntigravityAccountStatus::Active - } - } -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "snake_case")] -pub(crate) enum AntigravityAccountStatus { - Active, - Expired, -} - -#[derive(Clone, Serialize)] -pub(crate) struct AntigravityAccountSummary { - pub(crate) account_id: String, - pub(crate) email: Option, - pub(crate) expires_at: Option, - pub(crate) status: AntigravityAccountStatus, - pub(crate) source: Option, -} - -#[derive(Clone, Serialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub(crate) enum AntigravityLoginStatus { - Waiting, - Success, - Error, -} - -#[derive(Clone, Serialize)] -pub(crate) struct AntigravityLoginStartResponse { - pub(crate) state: String, - pub(crate) login_url: String, - pub(crate) interval_seconds: u64, - pub(crate) expires_at: Option, -} - -#[derive(Clone, Serialize)] -pub(crate) struct AntigravityLoginPollResponse { - pub(crate) state: String, - pub(crate) status: AntigravityLoginStatus, - pub(crate) error: Option, - pub(crate) account: Option, -} - -#[derive(Clone, Serialize)] -pub(crate) struct AntigravityQuotaItem { - pub(crate) name: String, - pub(crate) percentage: f64, - pub(crate) reset_at: Option, -} - -#[derive(Clone, Serialize)] -pub(crate) struct AntigravityQuotaSummary { - pub(crate) account_id: String, - pub(crate) plan_type: Option, - pub(crate) quotas: Vec, - pub(crate) error: Option, -} - -#[derive(Clone, Serialize)] -pub(crate) struct AntigravityIdeStatus { - pub(crate) database_available: bool, - pub(crate) ide_running: bool, - pub(crate) active_email: Option, -} - -#[derive(Clone, Serialize, Deserialize)] -pub(crate) struct AntigravityWarmupSchedule { - pub(crate) account_id: String, - pub(crate) model: String, - pub(crate) interval_minutes: u64, - pub(crate) next_run_at: Option, - #[serde(default)] - pub(crate) enabled: bool, -} - -#[derive(Clone, Serialize)] -pub(crate) struct AntigravityWarmupScheduleSummary { - pub(crate) account_id: String, - pub(crate) model: String, - pub(crate) interval_minutes: u64, - pub(crate) next_run_at: Option, - pub(crate) enabled: bool, -} diff --git a/src-tauri/src/antigravity/warmup.rs b/src-tauri/src/antigravity/warmup.rs deleted file mode 100644 index 0173c93..0000000 --- a/src-tauri/src/antigravity/warmup.rs +++ /dev/null @@ -1,262 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, USER_AGENT}; -use time::format_description::well_known::Rfc3339; -use time::OffsetDateTime; -use tokio::sync::{Mutex, RwLock}; -use tokio::task::JoinHandle; - -use crate::app_proxy::AppProxyState; -use crate::oauth_util::build_reqwest_client; - -use super::endpoints; -use super::project; -use super::store::AntigravityAccountStore; -use super::types::{AntigravityWarmupSchedule, AntigravityWarmupScheduleSummary}; - -const GENERATE_PATH: &str = "/v1internal:generateContent"; -const STREAM_PATH: &str = "/v1internal:streamGenerateContent"; - -#[derive(Clone)] -pub(crate) struct AntigravityWarmupScheduler { - store: Arc, - app_proxy: AppProxyState, - schedules: Arc>>, - runner: Arc>>>, -} - -impl AntigravityWarmupScheduler { - pub(crate) fn new(store: Arc, app_proxy: AppProxyState) -> Self { - Self { - store, - app_proxy, - schedules: Arc::new(RwLock::new(HashMap::new())), - runner: Arc::new(Mutex::new(None)), - } - } - - pub(crate) async fn start(&self) { - let mut guard = self.runner.lock().await; - if guard.is_some() { - return; - } - let scheduler = self.clone(); - let handle = tokio::spawn(async move { - scheduler.run_loop().await; - }); - *guard = Some(handle); - } - - pub(crate) async fn list_schedules(&self) -> Vec { - let guard = self.schedules.read().await; - let mut items: Vec = guard - .values() - .map(|item| AntigravityWarmupScheduleSummary { - account_id: item.account_id.clone(), - model: item.model.clone(), - interval_minutes: item.interval_minutes, - next_run_at: item.next_run_at.clone(), - enabled: item.enabled, - }) - .collect(); - items.sort_by(|left, right| left.account_id.cmp(&right.account_id)); - items - } - - pub(crate) async fn set_schedule( - &self, - account_id: String, - model: String, - interval_minutes: u64, - enabled: bool, - ) -> Result { - if account_id.trim().is_empty() || model.trim().is_empty() { - return Err("Account ID and model are required.".to_string()); - } - let interval_minutes = interval_minutes.max(1); - let mut guard = self.schedules.write().await; - let key = schedule_key(&account_id, &model); - let next_run_at = if enabled { - Some(format_next_run(interval_minutes)) - } else { - None - }; - let schedule = AntigravityWarmupSchedule { - account_id: account_id.clone(), - model: model.clone(), - interval_minutes, - next_run_at: next_run_at.clone(), - enabled, - }; - guard.insert(key, schedule.clone()); - Ok(AntigravityWarmupScheduleSummary { - account_id, - model, - interval_minutes, - next_run_at, - enabled, - }) - } - - pub(crate) async fn toggle_schedule( - &self, - account_id: String, - model: String, - enabled: bool, - ) -> Result<(), String> { - let mut guard = self.schedules.write().await; - let key = schedule_key(&account_id, &model); - let Some(schedule) = guard.get_mut(&key) else { - return Err("Warmup schedule not found.".to_string()); - }; - schedule.enabled = enabled; - schedule.next_run_at = if enabled { - Some(format_next_run(schedule.interval_minutes)) - } else { - None - }; - Ok(()) - } - - pub(crate) async fn run_warmup( - &self, - account_id: &str, - model: &str, - stream: bool, - ) -> Result<(), String> { - let record = self.store.get_account_record(account_id).await?; - let proxy_url = self.app_proxy.read().await.clone(); - let client = build_reqwest_client(proxy_url.as_deref(), Duration::from_secs(20))?; - let mut project_id = record.project_id.clone(); - if project_id.is_none() { - if let Ok(info) = project::load_code_assist(&record.access_token, proxy_url.as_deref()).await { - if let Some(value) = info.project_id.clone() { - let _ = self.store.update_project_id(account_id, value.clone()).await; - project_id = Some(value); - } else if let Some(tier_id) = info.plan_type.as_deref() { - if let Ok(Some(value)) = - project::onboard_user(&record.access_token, proxy_url.as_deref(), tier_id).await - { - let _ = self.store.update_project_id(account_id, value.clone()).await; - project_id = Some(value); - } - } - } - } - let user_agent = endpoints::default_user_agent(); - let payload = build_warmup_payload(model, project_id.as_deref(), &user_agent); - let path = if stream { STREAM_PATH } else { GENERATE_PATH }; - let mut last_error: Option = None; - for base in endpoints::BASE_URLS { - let url = format!("{}{}", base, path); - let response = client - .post(url) - .header(AUTHORIZATION, format!("Bearer {}", record.access_token)) - .header(USER_AGENT, user_agent.as_str()) - .header(CONTENT_TYPE, "application/json") - .json(&payload) - .send() - .await; - let response = match response { - Ok(response) => response, - Err(err) => { - last_error = Some(format!("Warmup request failed: {err}")); - continue; - } - }; - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - let message = format!("Warmup failed: {status} {body}"); - if status == reqwest::StatusCode::TOO_MANY_REQUESTS || status.is_server_error() { - last_error = Some(message); - continue; - } - return Err(message); - } - return Ok(()); - } - Err(last_error.unwrap_or_else(|| "Warmup failed.".to_string())) - } - - async fn run_loop(&self) { - loop { - let due = self.collect_due().await; - for item in due { - let _ = self - .run_warmup(&item.account_id, &item.model, false) - .await; - self.bump_schedule(&item.account_id, &item.model).await; - } - tokio::time::sleep(Duration::from_secs(30)).await; - } - } - - async fn collect_due(&self) -> Vec { - let now = OffsetDateTime::now_utc(); - let guard = self.schedules.read().await; - guard - .values() - .filter_map(|schedule| { - if !schedule.enabled { - return None; - } - let next_run = schedule - .next_run_at - .as_deref() - .and_then(|value| OffsetDateTime::parse(value, &Rfc3339).ok()); - if next_run.is_some_and(|value| value > now) { - return None; - } - Some(schedule.clone()) - }) - .collect() - } - - async fn bump_schedule(&self, account_id: &str, model: &str) { - let mut guard = self.schedules.write().await; - let key = schedule_key(account_id, model); - if let Some(schedule) = guard.get_mut(&key) { - schedule.next_run_at = Some(format_next_run(schedule.interval_minutes)); - } - } -} - -fn format_next_run(interval_minutes: u64) -> String { - let next = OffsetDateTime::now_utc() + time::Duration::minutes(interval_minutes as i64); - next.format(&Rfc3339).unwrap_or_else(|_| next.unix_timestamp().to_string()) -} - -fn schedule_key(account_id: &str, model: &str) -> String { - format!("{}::{}", account_id.trim(), model.trim()) -} - -fn build_warmup_payload(model: &str, project_id: Option<&str>, user_agent: &str) -> serde_json::Value { - let project = project_id.unwrap_or_default(); - let request_id = format!("agent-{}", OffsetDateTime::now_utc().unix_timestamp()); - serde_json::json!({ - "project": project, - "request": { - "contents": [ - { "role": "user", "parts": [{ "text": "ping" }] } - ], - "generationConfig": { "maxOutputTokens": 1 }, - "toolConfig": { "functionCallingConfig": { "mode": "NONE" } }, - "sessionId": format!("-{}", OffsetDateTime::now_utc().unix_timestamp()) - }, - "model": model, - "requestId": request_id, - "userAgent": user_agent, - "requestType": "agent" - }) -} - -pub(crate) async fn run_blocking(task: F) -> Result -where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, -{ - tokio::task::spawn_blocking(task).await -} diff --git a/src-tauri/src/codex/login.rs b/src-tauri/src/codex/login.rs deleted file mode 100644 index 71190bf..0000000 --- a/src-tauri/src/codex/login.rs +++ /dev/null @@ -1,265 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use time::OffsetDateTime; -use tokio::sync::RwLock; - -use crate::app_proxy::AppProxyState; -use crate::oauth_util::{expires_at_from_seconds, generate_pkce, generate_state, now_rfc3339}; - -use super::oauth::CodexOAuthClient; -use super::store::CodexAccountStore; -use super::types::{ - CodexAccountSummary, - CodexLoginPollResponse, - CodexLoginStartResponse, - CodexLoginStatus, - CodexTokenRecord, -}; - -const AUTH_CODE_TIMEOUT: Duration = Duration::from_secs(600); -const POLL_INTERVAL_SECONDS: u64 = 2; -const CODEX_CALLBACK_PORT: u16 = 1455; - -#[derive(Clone)] -pub(crate) struct CodexLoginManager { - store: Arc, - sessions: Arc>>, - app_proxy: AppProxyState, -} - -#[derive(Clone)] -struct LoginSession { - status: CodexLoginStatus, - error: Option, - account: Option, - expires_at: Option, -} - -impl CodexLoginManager { - pub(crate) fn new(store: Arc, app_proxy: AppProxyState) -> Self { - Self { - store, - sessions: Arc::new(RwLock::new(HashMap::new())), - app_proxy, - } - } - - pub(crate) async fn start_login(&self) -> Result { - let state = generate_state("codex")?; - let expires_at = Some(OffsetDateTime::now_utc() + time::Duration::seconds(600)); - self.insert_session(&state, expires_at).await; - let (code_verifier, code_challenge) = generate_pkce()?; - let callback = start_auth_code_callback(state.clone()).await?; - let login_url = CodexOAuthClient::build_authorize_url( - &callback.redirect_uri, - &state, - &code_challenge, - ); - let manager = self.clone(); - let state_for_task = state.clone(); - tauri::async_runtime::spawn(async move { - run_auth_code_login(manager, state_for_task, code_verifier, callback).await; - }); - Ok(CodexLoginStartResponse { - state, - login_url, - interval_seconds: POLL_INTERVAL_SECONDS, - expires_at: Some(expires_at_from_seconds(AUTH_CODE_TIMEOUT.as_secs() as i64)), - }) - } - - pub(crate) async fn poll_login(&self, state: &str) -> Result { - let mut guard = self.sessions.write().await; - let session = guard - .get_mut(state) - .ok_or_else(|| "Login session not found.".to_string())?; - if session.status != CodexLoginStatus::Success - && session.status != CodexLoginStatus::Error - && session - .expires_at - .map(|deadline| OffsetDateTime::now_utc() > deadline) - .unwrap_or(false) - { - session.status = CodexLoginStatus::Error; - session.error = Some("Login expired.".to_string()); - } - Ok(CodexLoginPollResponse { - state: state.to_string(), - status: session.status.clone(), - error: session.error.clone(), - account: session.account.clone(), - }) - } - - pub(crate) async fn logout(&self, account_id: &str) -> Result<(), String> { - self.store.delete_account(account_id).await - } - - async fn insert_session(&self, state: &str, expires_at: Option) { - let session = LoginSession { - status: CodexLoginStatus::Waiting, - error: None, - account: None, - expires_at, - }; - let mut guard = self.sessions.write().await; - guard.insert(state.to_string(), session); - } - - async fn complete_session(&self, state: &str, account: CodexAccountSummary) { - let mut guard = self.sessions.write().await; - if let Some(session) = guard.get_mut(state) { - session.status = CodexLoginStatus::Success; - session.error = None; - session.account = Some(account); - } - } - - async fn fail_session(&self, state: &str, message: String) { - let mut guard = self.sessions.write().await; - if let Some(session) = guard.get_mut(state) { - session.status = CodexLoginStatus::Error; - session.error = Some(message); - } - } - - async fn app_proxy_url(&self) -> Option { - self.app_proxy.read().await.clone() - } -} - -struct AuthCodeCallback { - redirect_uri: String, - receiver: tokio::sync::mpsc::Receiver, - shutdown: Option>, -} - -#[derive(Clone)] -struct AuthCodeResult { - code: Option, - state: Option, - error: Option, -} - -async fn start_auth_code_callback(state: String) -> Result { - let (tx, rx) = tokio::sync::mpsc::channel::(1); - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); - let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{CODEX_CALLBACK_PORT}")) - .await - .map_err(|err| format!("Failed to start callback server: {err}"))?; - let redirect_uri = format!("http://localhost:{CODEX_CALLBACK_PORT}/auth/callback"); - let router = axum::Router::new().route( - "/auth/callback", - axum::routing::get(move |query: axum::extract::Query>| { - let expected_state = state.clone(); - let tx = tx.clone(); - async move { - let code = query.get("code").cloned(); - let state = query.get("state").cloned(); - let error = query.get("error").cloned(); - let has_error = error.is_some(); - let state_matches = state.as_deref() == Some(&expected_state); - let _ = tx.send(AuthCodeResult { code, state, error }).await; - let body = if has_error || !state_matches { - "Login failed. You can close this window." - } else { - "Login successful. You can close this window." - }; - axum::response::Html(body) - } - }), - ); - tauri::async_runtime::spawn(async move { - let _ = axum::serve(listener, router) - .with_graceful_shutdown(async move { - let _ = shutdown_rx.await; - }) - .await; - }); - Ok(AuthCodeCallback { - redirect_uri, - receiver: rx, - shutdown: Some(shutdown_tx), - }) -} - -async fn run_auth_code_login( - manager: CodexLoginManager, - state: String, - code_verifier: String, - mut callback: AuthCodeCallback, -) { - let redirect_uri = callback.redirect_uri.clone(); - let callback_result = match wait_for_auth_code(&mut callback).await { - Ok(result) => result, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let code = match extract_auth_code(&state, callback_result) { - Ok(code) => code, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let proxy_url = manager.app_proxy_url().await; - let client = match CodexOAuthClient::new(proxy_url.as_deref()) { - Ok(client) => client, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let token = match client - .exchange_code(&code, &code_verifier, &redirect_uri) - .await - { - Ok(token) => token, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let record = CodexTokenRecord { - access_token: token.access_token, - refresh_token: token.refresh_token, - id_token: token.id_token, - account_id: None, - email: None, - expires_at: expires_at_from_seconds(token.expires_in), - last_refresh: Some(now_rfc3339()), - }; - match manager.store.save_new_account(record).await { - Ok(account) => manager.complete_session(&state, account).await, - Err(err) => manager.fail_session(&state, err).await, - } -} - -async fn wait_for_auth_code(callback: &mut AuthCodeCallback) -> Result { - let shutdown = callback.shutdown.take(); - let result = tokio::time::timeout(AUTH_CODE_TIMEOUT, callback.receiver.recv()).await; - if let Some(shutdown) = shutdown { - let _ = shutdown.send(()); - } - match result { - Ok(Some(callback)) => Ok(callback), - Ok(None) => Err("Authorization callback closed.".to_string()), - Err(_) => Err("Authorization timed out.".to_string()), - } -} - -fn extract_auth_code(state: &str, callback_result: AuthCodeResult) -> Result { - if let Some(err) = callback_result.error { - return Err(err); - } - if callback_result.state.as_deref() != Some(state) { - return Err("OAuth state mismatch.".to_string()); - } - callback_result - .code - .ok_or_else(|| "Authorization code missing.".to_string()) -} diff --git a/src-tauri/src/codex/oauth.rs b/src-tauri/src/codex/oauth.rs deleted file mode 100644 index 3aa011c..0000000 --- a/src-tauri/src/codex/oauth.rs +++ /dev/null @@ -1,140 +0,0 @@ -use reqwest::Client; -use serde::{Deserialize, Serialize}; - -use crate::oauth_util::build_reqwest_client; - -const OPENAI_AUTH_URL: &str = "https://auth.openai.com/oauth/authorize"; -const OPENAI_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; -const OPENAI_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; - -#[derive(Clone)] -pub(crate) struct CodexOAuthClient { - http: Client, -} - -impl CodexOAuthClient { - pub(crate) fn new(proxy_url: Option<&str>) -> Result { - let http = build_reqwest_client(proxy_url, std::time::Duration::from_secs(30)) - .map_err(|err| format!("Failed to build Codex OAuth client: {err}"))?; - Ok(Self { http }) - } - - pub(crate) fn build_authorize_url( - redirect_uri: &str, - state: &str, - code_challenge: &str, - ) -> String { - let query = url::form_urlencoded::Serializer::new(String::new()) - .append_pair("client_id", OPENAI_CLIENT_ID) - .append_pair("response_type", "code") - .append_pair("redirect_uri", redirect_uri) - .append_pair("scope", "openid email profile offline_access") - .append_pair("state", state) - .append_pair("code_challenge", code_challenge) - .append_pair("code_challenge_method", "S256") - .append_pair("prompt", "login") - .append_pair("id_token_add_organizations", "true") - .append_pair("codex_cli_simplified_flow", "true") - .finish(); - format!("{OPENAI_AUTH_URL}?{query}") - } - - pub(crate) async fn exchange_code( - &self, - code: &str, - code_verifier: &str, - redirect_uri: &str, - ) -> Result { - let payload = TokenExchangeRequest { - grant_type: "authorization_code".to_string(), - client_id: OPENAI_CLIENT_ID.to_string(), - code: code.to_string(), - redirect_uri: redirect_uri.to_string(), - code_verifier: code_verifier.to_string(), - refresh_token: None, - scope: None, - }; - self.post_form(payload).await - } - - pub(crate) async fn refresh_token(&self, refresh_token: &str) -> Result { - let payload = TokenExchangeRequest { - grant_type: "refresh_token".to_string(), - client_id: OPENAI_CLIENT_ID.to_string(), - code: String::new(), - redirect_uri: String::new(), - code_verifier: String::new(), - refresh_token: Some(refresh_token.to_string()), - scope: Some("openid profile email".to_string()), - }; - self.post_form(payload).await - } - - async fn post_form(&self, payload: TokenExchangeRequest) -> Result { - let body = { - let mut form = url::form_urlencoded::Serializer::new(String::new()); - form.append_pair("grant_type", &payload.grant_type) - .append_pair("client_id", &payload.client_id); - if !payload.code.is_empty() { - form.append_pair("code", &payload.code); - } - if !payload.redirect_uri.is_empty() { - form.append_pair("redirect_uri", &payload.redirect_uri); - } - if !payload.code_verifier.is_empty() { - form.append_pair("code_verifier", &payload.code_verifier); - } - if let Some(refresh_token) = payload.refresh_token.as_deref() { - form.append_pair("refresh_token", refresh_token); - } - if let Some(scope) = payload.scope.as_deref() { - form.append_pair("scope", scope); - } - form.finish() - }; - - let response = self - .http - .post(OPENAI_TOKEN_URL) - .header("Content-Type", "application/x-www-form-urlencoded") - .header("Accept", "application/json") - .body(body) - .send() - .await - .map_err(|err| format!("Codex OAuth request failed: {err}"))?; - let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(|err| format!("Failed to read Codex OAuth response: {err}"))?; - if !status.is_success() { - let body = String::from_utf8_lossy(&bytes); - return Err(format!( - "Codex OAuth request failed (status {}): {}", - status.as_u16(), - body - )); - } - serde_json::from_slice(&bytes) - .map_err(|err| format!("Failed to parse Codex OAuth response: {err}")) - } -} - -#[derive(Serialize)] -struct TokenExchangeRequest { - grant_type: String, - client_id: String, - code: String, - redirect_uri: String, - code_verifier: String, - refresh_token: Option, - scope: Option, -} - -#[derive(Clone, Deserialize)] -pub(crate) struct CodexTokenResponse { - pub(crate) access_token: String, - pub(crate) refresh_token: String, - pub(crate) id_token: String, - pub(crate) expires_in: i64, -} diff --git a/src-tauri/src/codex/quota.rs b/src-tauri/src/codex/quota.rs deleted file mode 100644 index d41bf5b..0000000 --- a/src-tauri/src/codex/quota.rs +++ /dev/null @@ -1,297 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::error::Error as StdError; -use std::time::Duration; -use time::format_description::well_known::Rfc3339; -use time::OffsetDateTime; - -use reqwest::{Client, Proxy}; - -use crate::oauth_util::build_reqwest_client; - -use super::store::CodexAccountStore; -use super::types::CodexAccountSummary; - -const CODEX_USAGE_ENDPOINT: &str = "https://chatgpt.com/backend-api/wham/usage"; -// Match Codex CLI UA to avoid edge filtering on some proxies. -const CODEX_USER_AGENT: &str = "codex_cli_rs/0.50.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"; - -#[derive(Clone, Serialize)] -pub(crate) struct CodexQuotaItem { - pub(crate) name: String, - pub(crate) percentage: f64, - pub(crate) used: Option, - pub(crate) limit: Option, - pub(crate) reset_at: Option, -} - -#[derive(Clone, Serialize)] -pub(crate) struct CodexQuotaSummary { - pub(crate) account_id: String, - pub(crate) plan_type: Option, - pub(crate) quotas: Vec, - pub(crate) error: Option, -} - -pub(crate) async fn fetch_quotas( - store: &CodexAccountStore, -) -> Result, String> { - let accounts = store.list_accounts().await?; - let proxy_url = store.app_proxy_url().await; - let mut results = Vec::with_capacity(accounts.len()); - for account in accounts { - match fetch_account_quota(store, &account, proxy_url.as_deref()).await { - Ok(summary) => results.push(summary), - Err(err) => results.push(CodexQuotaSummary { - account_id: account.account_id.clone(), - plan_type: None, - quotas: Vec::new(), - error: Some(err), - }), - } - } - Ok(results) -} - -async fn fetch_account_quota( - store: &CodexAccountStore, - account: &CodexAccountSummary, - proxy_url: Option<&str>, -) -> Result { - let record = store.get_account_record(&account.account_id).await?; - let response = request_usage(&record.access_token, record.account_id.as_deref(), proxy_url).await?; - Ok(map_usage_response(account, response)) -} - -async fn request_usage( - access_token: &str, - chatgpt_account_id: Option<&str>, - proxy_url: Option<&str>, -) -> Result { - let attempts = build_usage_attempts(proxy_url); - let mut send_errors = Vec::new(); - - for attempt in attempts { - match request_usage_once(access_token, chatgpt_account_id, &attempt).await { - Ok(response) => return Ok(response), - Err(UsageRequestError::Send(err)) => { - send_errors.push(format!( - "{}: {}", - attempt.label, - format_reqwest_error(&err) - )); - } - Err(err) => { - return Err(format!( - "Codex usage request failed: {}", - format_usage_error(err) - )); - } - } - } - - let detail = if send_errors.is_empty() { - "unknown error".to_string() - } else { - send_errors.join(" | ") - }; - Err(format!("Codex usage request failed: {detail}")) -} - -fn map_usage_response( - account: &CodexAccountSummary, - response: CodexUsageResponse, -) -> CodexQuotaSummary { - let mut quotas = Vec::new(); - if let Some(rate_limit) = response.rate_limit { - if let Some(item) = build_window_quota("codex-session", rate_limit.primary_window) { - quotas.push(item); - } - if let Some(item) = build_window_quota("codex-weekly", rate_limit.secondary_window) { - quotas.push(item); - } - } - - CodexQuotaSummary { - account_id: account.account_id.clone(), - plan_type: response.plan_type, - quotas, - error: None, - } -} - -fn build_window_quota(name: &str, window: Option) -> Option { - let window = window?; - let used_percent = window.used_percent?; - let percentage = (100.0 - used_percent).clamp(0.0, 100.0); - Some(CodexQuotaItem { - name: name.to_string(), - percentage, - used: None, - limit: None, - reset_at: window.reset_at.and_then(reset_at_from_seconds), - }) -} - -fn reset_at_from_seconds(seconds: i64) -> Option { - let value = OffsetDateTime::from_unix_timestamp(seconds).ok()?; - Some(value.format(&Rfc3339).unwrap_or_else(|_| seconds.to_string())) -} - -async fn request_usage_once( - access_token: &str, - chatgpt_account_id: Option<&str>, - attempt: &UsageAttempt, -) -> Result { - let http = build_usage_client(attempt.proxy_url.as_deref(), attempt.http1_only) - .map_err(UsageRequestError::Build)?; - let mut request = http - .get(CODEX_USAGE_ENDPOINT) - .header("Authorization", format!("Bearer {access_token}")) - .header("Accept", "application/json") - .header("User-Agent", CODEX_USER_AGENT); - if let Some(account_id) = chatgpt_account_id.filter(|value| !value.trim().is_empty()) { - request = request.header("ChatGPT-Account-Id", account_id); - } - let response = request.send().await.map_err(UsageRequestError::Send)?; - let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(|err| UsageRequestError::Decode(format!("Failed to read response: {err}")))?; - if !status.is_success() { - let body = String::from_utf8_lossy(&bytes); - return Err(UsageRequestError::Status(status.as_u16(), body.to_string())); - } - serde_json::from_slice(&bytes) - .map_err(|err| UsageRequestError::Decode(format!("Invalid response: {err}"))) -} - -fn build_usage_client(proxy_url: Option<&str>, http1_only: bool) -> Result { - if !http1_only { - return build_reqwest_client(proxy_url, Duration::from_secs(30)) - .map_err(|err| format!("Failed to build Codex usage client: {err}")); - } - - let mut builder = Client::builder().timeout(Duration::from_secs(30)); - let proxy_url = proxy_url.map(str::trim).filter(|value| !value.is_empty()); - if let Some(proxy_url) = proxy_url { - let proxy = Proxy::all(proxy_url) - .map_err(|_| "app_proxy_url is not a valid URL.".to_string())?; - builder = builder.proxy(proxy); - } - builder - .http1_only() - .build() - .map_err(|err| format!("Failed to build Codex usage client: {err}")) -} - -fn build_usage_attempts(proxy_url: Option<&str>) -> Vec { - let mut attempts = Vec::new(); - attempts.push(UsageAttempt { - label: "primary", - proxy_url: proxy_url.map(|value| value.to_string()), - http1_only: false, - }); - - if let Some(proxy_url) = proxy_url { - if let Some(upgraded) = upgrade_socks5(proxy_url) { - attempts.push(UsageAttempt { - label: "socks5h", - proxy_url: Some(upgraded), - http1_only: false, - }); - } - attempts.push(UsageAttempt { - label: "http1", - proxy_url: Some(proxy_url.to_string()), - http1_only: true, - }); - } - - attempts -} - -fn upgrade_socks5(proxy_url: &str) -> Option { - let value = proxy_url.trim(); - if value.starts_with("socks5h://") { - return None; - } - if value.starts_with("socks5://") { - return Some(value.replacen("socks5://", "socks5h://", 1)); - } - None -} - -fn format_usage_error(err: UsageRequestError) -> String { - match err { - UsageRequestError::Build(message) => message, - UsageRequestError::Send(err) => format_reqwest_error(&err), - UsageRequestError::Status(status, body) => { - format!("status {status}: {body}") - } - UsageRequestError::Decode(message) => message, - } -} - -fn format_reqwest_error(err: &reqwest::Error) -> String { - let mut details = vec![err.to_string()]; - let mut flags = Vec::new(); - if err.is_timeout() { - flags.push("timeout"); - } - if err.is_connect() { - flags.push("connect"); - } - if err.is_request() { - flags.push("request"); - } - if err.is_builder() { - flags.push("builder"); - } - if !flags.is_empty() { - details.push(format!("flags=[{}]", flags.join(","))); - } - - let mut source = err.source(); - let mut depth = 0; - while let Some(cause) = source { - if depth >= 4 { - break; - } - details.push(format!("cause: {cause}")); - source = cause.source(); - depth += 1; - } - details.join(" | ") -} - -struct UsageAttempt { - label: &'static str, - proxy_url: Option, - http1_only: bool, -} - -enum UsageRequestError { - Build(String), - Send(reqwest::Error), - Status(u16, String), - Decode(String), -} - -#[derive(Deserialize)] -struct CodexUsageResponse { - plan_type: Option, - rate_limit: Option, -} - -#[derive(Deserialize)] -struct CodexRateLimit { - primary_window: Option, - secondary_window: Option, -} - -#[derive(Deserialize)] -struct CodexRateWindow { - used_percent: Option, - reset_at: Option, -} diff --git a/src-tauri/src/codex/store.rs b/src-tauri/src/codex/store.rs deleted file mode 100644 index af919be..0000000 --- a/src-tauri/src/codex/store.rs +++ /dev/null @@ -1,265 +0,0 @@ -use std::collections::HashMap; -use std::path::PathBuf; - -use tauri::AppHandle; -use time::OffsetDateTime; -use tokio::sync::RwLock; - -use crate::app_proxy::AppProxyState; -use crate::oauth_util::{ - expires_at_from_seconds, - extract_chatgpt_account_id_from_jwt, - extract_email_from_jwt, - now_rfc3339, - sanitize_id_part, -}; -use crate::proxy::config::config_dir_path; - -use super::oauth::CodexOAuthClient; -use super::types::{CodexAccountStatus, CodexAccountSummary, CodexTokenRecord}; - -const CODEX_AUTH_DIR_NAME: &str = "codex-auth"; - -pub(crate) struct CodexAccountStore { - dir: PathBuf, - cache: RwLock>, - app_proxy: AppProxyState, -} - -impl CodexAccountStore { - pub(crate) fn new(app: &AppHandle, app_proxy: AppProxyState) -> Result { - let dir = config_dir_path(app)?.join(CODEX_AUTH_DIR_NAME); - Ok(Self { - dir, - cache: RwLock::new(HashMap::new()), - app_proxy, - }) - } - - pub(crate) async fn list_accounts(&self) -> Result, String> { - self.refresh_cache().await?; - let cache = self.cache.read().await; - let mut items: Vec = cache - .iter() - .map(|(account_id, record)| CodexAccountSummary { - account_id: account_id.clone(), - email: record.email.clone(), - expires_at: record.expires_at().map(|value| { - value - .format(&time::format_description::well_known::Rfc3339) - .unwrap_or_else(|_| record.expires_at.clone()) - }), - status: record.status(), - }) - .collect(); - items.sort_by(|left, right| left.account_id.cmp(&right.account_id)); - Ok(items) - } - - pub(crate) async fn get_account_record( - &self, - account_id: &str, - ) -> Result { - let record = self.load_account(account_id).await?; - self.refresh_if_needed(account_id, record).await - } - - pub(crate) async fn save_record( - &self, - account_id: String, - record: CodexTokenRecord, - ) -> Result { - self.ensure_dir().await?; - let path = self.account_path(&account_id); - let payload = serde_json::to_string_pretty(&record) - .map_err(|err| format!("Failed to serialize token record: {err}"))?; - tokio::fs::write(&path, payload) - .await - .map_err(|err| format!("Failed to write token record: {err}"))?; - let mut cache = self.cache.write().await; - cache.insert(account_id.clone(), record.clone()); - Ok(CodexAccountSummary { - account_id, - email: record.email.clone(), - expires_at: record.expires_at().map(|value| { - value - .format(&time::format_description::well_known::Rfc3339) - .unwrap_or_else(|_| record.expires_at.clone()) - }), - status: record.status(), - }) - } - - pub(crate) async fn save_new_account( - &self, - mut record: CodexTokenRecord, - ) -> Result { - fill_record_from_jwt(&mut record); - let id_part_source = record - .email - .as_deref() - .or(record.account_id.as_deref()) - .unwrap_or_default(); - let mut id_part = sanitize_id_part(id_part_source); - if id_part.is_empty() { - id_part = format!("{}", OffsetDateTime::now_utc().unix_timestamp()); - } - let account_id = self.unique_account_id(&id_part).await?; - self.save_record(account_id, record).await - } - - pub(crate) async fn delete_account(&self, account_id: &str) -> Result<(), String> { - let path = self.account_path(account_id); - if tokio::fs::try_exists(&path).await.unwrap_or(false) { - tokio::fs::remove_file(&path) - .await - .map_err(|err| format!("Failed to delete token record: {err}"))?; - } - let mut cache = self.cache.write().await; - cache.remove(account_id); - Ok(()) - } - - async fn refresh_if_needed( - &self, - account_id: &str, - record: CodexTokenRecord, - ) -> Result { - if !record.is_expired() { - return Ok(record); - } - self.refresh_record(account_id, record).await - } - - async fn refresh_record( - &self, - account_id: &str, - record: CodexTokenRecord, - ) -> Result { - let proxy_url = self.app_proxy_url().await; - let client = CodexOAuthClient::new(proxy_url.as_deref())?; - let response = client.refresh_token(&record.refresh_token).await?; - let mut refreshed = CodexTokenRecord { - access_token: response.access_token, - refresh_token: if response.refresh_token.trim().is_empty() { - record.refresh_token.clone() - } else { - response.refresh_token - }, - id_token: if response.id_token.trim().is_empty() { - record.id_token.clone() - } else { - response.id_token - }, - account_id: record.account_id.clone(), - email: record.email.clone(), - expires_at: expires_at_from_seconds(response.expires_in), - last_refresh: Some(now_rfc3339()), - }; - fill_record_from_jwt(&mut refreshed); - let summary = self - .save_record(account_id.to_string(), refreshed.clone()) - .await?; - if matches!(summary.status, CodexAccountStatus::Expired) { - return Err("Codex token refresh failed.".to_string()); - } - Ok(refreshed) - } - - async fn load_account(&self, account_id: &str) -> Result { - if let Some(record) = self.cache.read().await.get(account_id).cloned() { - return Ok(record); - } - self.refresh_cache().await?; - self.cache - .read() - .await - .get(account_id) - .cloned() - .ok_or_else(|| format!("Codex account not found: {account_id}")) - } - - pub(crate) async fn app_proxy_url(&self) -> Option { - self.app_proxy.read().await.clone() - } - - async fn refresh_cache(&self) -> Result<(), String> { - let mut cache = HashMap::new(); - let dir = self.dir.clone(); - let mut entries = match tokio::fs::read_dir(&dir).await { - Ok(entries) => entries, - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - let mut guard = self.cache.write().await; - guard.clear(); - return Ok(()); - } - Err(err) => return Err(format!("Failed to read Codex auth directory: {err}")), - }; - - while let Some(entry) = entries - .next_entry() - .await - .map_err(|err| format!("Failed to read Codex auth entry: {err}"))? - { - let path = entry.path(); - if path.extension().and_then(|ext| ext.to_str()) != Some("json") { - continue; - } - let file_name = match path.file_name().and_then(|name| name.to_str()) { - Some(name) => name.to_string(), - None => continue, - }; - let contents = match tokio::fs::read_to_string(&path).await { - Ok(contents) => contents, - Err(_) => continue, - }; - let record: CodexTokenRecord = match serde_json::from_str(&contents) { - Ok(record) => record, - Err(_) => continue, - }; - cache.insert(file_name, record); - } - - let mut guard = self.cache.write().await; - *guard = cache; - Ok(()) - } - - async fn ensure_dir(&self) -> Result<(), String> { - tokio::fs::create_dir_all(&self.dir) - .await - .map_err(|err| format!("Failed to create Codex auth dir: {err}")) - } - - async fn unique_account_id(&self, id_part: &str) -> Result { - self.ensure_dir().await?; - let mut suffix = 0u32; - loop { - let candidate = if suffix == 0 { - format!("codex-{id_part}.json") - } else { - format!("codex-{id_part}-{suffix}.json") - }; - if !tokio::fs::try_exists(self.account_path(&candidate)) - .await - .unwrap_or(false) - { - return Ok(candidate); - } - suffix += 1; - } - } - - fn account_path(&self, account_id: &str) -> PathBuf { - self.dir.join(account_id) - } -} - -fn fill_record_from_jwt(record: &mut CodexTokenRecord) { - if record.account_id.is_none() { - record.account_id = extract_chatgpt_account_id_from_jwt(&record.id_token); - } - if record.email.is_none() { - record.email = extract_email_from_jwt(&record.id_token); - } -} diff --git a/src-tauri/src/codex/types.rs b/src-tauri/src/codex/types.rs deleted file mode 100644 index c3e9991..0000000 --- a/src-tauri/src/codex/types.rs +++ /dev/null @@ -1,78 +0,0 @@ -use serde::{Deserialize, Serialize}; -use time::format_description::well_known::Rfc3339; -use time::OffsetDateTime; - -#[derive(Clone, Serialize, Deserialize)] -pub(crate) struct CodexTokenRecord { - pub(crate) access_token: String, - pub(crate) refresh_token: String, - pub(crate) id_token: String, - pub(crate) account_id: Option, - pub(crate) email: Option, - pub(crate) expires_at: String, - pub(crate) last_refresh: Option, -} - -impl CodexTokenRecord { - pub(crate) fn expires_at(&self) -> Option { - let value = self.expires_at.trim(); - if value.is_empty() { - return None; - } - OffsetDateTime::parse(value, &Rfc3339).ok() - } - - pub(crate) fn is_expired(&self) -> bool { - let Some(expires_at) = self.expires_at() else { - return true; - }; - OffsetDateTime::now_utc() >= expires_at - } - - pub(crate) fn status(&self) -> CodexAccountStatus { - if self.is_expired() { - CodexAccountStatus::Expired - } else { - CodexAccountStatus::Active - } - } -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "snake_case")] -pub(crate) enum CodexAccountStatus { - Active, - Expired, -} - -#[derive(Clone, Serialize)] -pub(crate) struct CodexAccountSummary { - pub(crate) account_id: String, - pub(crate) email: Option, - pub(crate) expires_at: Option, - pub(crate) status: CodexAccountStatus, -} - -#[derive(Clone, Serialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub(crate) enum CodexLoginStatus { - Waiting, - Success, - Error, -} - -#[derive(Clone, Serialize)] -pub(crate) struct CodexLoginStartResponse { - pub(crate) state: String, - pub(crate) login_url: String, - pub(crate) interval_seconds: u64, - pub(crate) expires_at: Option, -} - -#[derive(Clone, Serialize)] -pub(crate) struct CodexLoginPollResponse { - pub(crate) state: String, - pub(crate) status: CodexLoginStatus, - pub(crate) error: Option, - pub(crate) account: Option, -} diff --git a/src-tauri/src/kiro/callback.rs b/src-tauri/src/kiro/callback.rs deleted file mode 100644 index 9cf930a..0000000 --- a/src-tauri/src/kiro/callback.rs +++ /dev/null @@ -1,62 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::path::{Path, PathBuf}; - -const CALLBACK_FILE_PREFIX: &str = ".oauth-kiro-"; -const CALLBACK_FILE_SUFFIX: &str = ".oauth"; - -#[derive(Clone, Serialize, Deserialize)] -pub(crate) struct OAuthCallbackPayload { - pub(crate) code: Option, - pub(crate) state: Option, - pub(crate) error: Option, -} - -pub(crate) fn parse_callback_url(url: &str) -> Result { - let parsed = url::Url::parse(url).map_err(|err| format!("Invalid callback URL: {err}"))?; - let mut payload = OAuthCallbackPayload { - code: None, - state: None, - error: None, - }; - for (key, value) in parsed.query_pairs() { - match key.as_ref() { - "code" => payload.code = Some(value.to_string()), - "state" => payload.state = Some(value.to_string()), - "error" => payload.error = Some(value.to_string()), - _ => {} - } - } - Ok(payload) -} - -pub(crate) fn callback_file_path(dir: &Path, state: &str) -> PathBuf { - dir.join(format!("{CALLBACK_FILE_PREFIX}{state}{CALLBACK_FILE_SUFFIX}")) -} - -pub(crate) async fn write_callback_file( - dir: &Path, - payload: &OAuthCallbackPayload, -) -> Result { - let state = payload - .state - .as_deref() - .ok_or_else(|| "Missing state in callback payload.".to_string())?; - let path = callback_file_path(dir, state); - tokio::fs::create_dir_all(dir) - .await - .map_err(|err| format!("Failed to create callback dir: {err}"))?; - let content = serde_json::to_string(payload) - .map_err(|err| format!("Failed to serialize callback payload: {err}"))?; - tokio::fs::write(&path, content) - .await - .map_err(|err| format!("Failed to write callback file: {err}"))?; - Ok(path) -} - -pub(crate) async fn read_callback_file(path: &Path) -> Result { - let content = tokio::fs::read_to_string(path) - .await - .map_err(|err| format!("Failed to read callback file: {err}"))?; - serde_json::from_str(&content) - .map_err(|err| format!("Failed to parse callback file: {err}")) -} diff --git a/src-tauri/src/kiro/login.rs b/src-tauri/src/kiro/login.rs deleted file mode 100644 index aca9532..0000000 --- a/src-tauri/src/kiro/login.rs +++ /dev/null @@ -1,523 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; - -use time::OffsetDateTime; -use tokio::sync::RwLock; - -use super::callback::{callback_file_path, parse_callback_url, read_callback_file, write_callback_file}; -use super::oauth::{build_login_url, KiroOAuthClient}; -use super::sso_oidc::{build_auth_code_url, CreateTokenResponse, RegisterClientResponse, SsoOidcClient, StartDeviceAuthResponse, TokenPollError}; -use super::store::KiroAccountStore; -use super::types::{KiroAccountSummary, KiroLoginMethod, KiroLoginPollResponse, KiroLoginStartResponse, KiroLoginStatus, KiroTokenRecord}; -use super::util::{expires_at_from_seconds, generate_pkce, generate_state, now_rfc3339}; -use crate::app_proxy::AppProxyState; - -const SOCIAL_CALLBACK_TIMEOUT: Duration = Duration::from_secs(300); -const AUTH_CODE_TIMEOUT: Duration = Duration::from_secs(600); -const KIRO_REDIRECT_URI: &str = "kiro://kiro.kiroAgent/authenticate-success"; - -#[derive(Clone)] -pub(crate) struct KiroLoginManager { - store: Arc, - sessions: Arc>>, - app_proxy: AppProxyState, -} - -#[derive(Clone)] -struct LoginSession { - status: KiroLoginStatus, - error: Option, - account: Option, - expires_at: Option, -} - -impl KiroLoginManager { - pub(crate) fn new(store: Arc, app_proxy: AppProxyState) -> Self { - Self { - store, - sessions: Arc::new(RwLock::new(HashMap::new())), - app_proxy, - } - } - - pub(crate) async fn start_login( - &self, - method: KiroLoginMethod, - ) -> Result { - let state = generate_state("kiro")?; - let expires_at = Some(OffsetDateTime::now_utc() + time::Duration::seconds(600)); - self.insert_session(&state, expires_at).await; - match method { - KiroLoginMethod::Aws => self.start_device_code_login(state, method).await, - KiroLoginMethod::AwsAuthcode => self.start_auth_code_login(state, method).await, - KiroLoginMethod::Google => self.start_social_login(state, method).await, - } - } - - pub(crate) async fn poll_login(&self, state: &str) -> Result { - let mut guard = self.sessions.write().await; - let session = guard - .get_mut(state) - .ok_or_else(|| "Login session not found.".to_string())?; - if session.status != KiroLoginStatus::Success - && session.status != KiroLoginStatus::Error - && session - .expires_at - .map(|deadline| OffsetDateTime::now_utc() > deadline) - .unwrap_or(false) - { - session.status = KiroLoginStatus::Error; - session.error = Some("Login expired.".to_string()); - } - Ok(KiroLoginPollResponse { - state: state.to_string(), - status: session.status.clone(), - error: session.error.clone(), - account: session.account.clone(), - }) - } - - pub(crate) async fn logout(&self, account_id: &str) -> Result<(), String> { - self.store.delete_account(account_id).await - } - - pub(crate) async fn handle_callback_url(&self, url: &str) -> Result<(), String> { - let payload = parse_callback_url(url)?; - write_callback_file(self.store.dir(), &payload).await?; - Ok(()) - } - - async fn start_device_code_login( - &self, - state: String, - method: KiroLoginMethod, - ) -> Result { - let proxy_url = self.app_proxy_url().await; - let client = SsoOidcClient::new(proxy_url.as_deref())?; - let reg = client.register_client().await?; - let auth = client - .start_device_authorization(®.client_id, ®.client_secret) - .await?; - let manager = self.clone(); - let state_for_task = state.clone(); - let verification_uri = auth.verification_uri.clone(); - let verification_uri_complete = auth.verification_uri_complete.clone(); - let user_code = auth.user_code.clone(); - let interval_seconds = auth.interval as u64; - let expires_at = expires_at_from_seconds(auth.expires_in); - tauri::async_runtime::spawn(async move { - run_device_code_login(manager, state_for_task, reg, auth).await; - }); - Ok(KiroLoginStartResponse { - state, - method, - login_url: None, - verification_uri: Some(verification_uri), - verification_uri_complete: Some(verification_uri_complete), - user_code: Some(user_code), - interval_seconds: Some(interval_seconds), - expires_at: Some(expires_at), - }) - } - - async fn start_auth_code_login( - &self, - state: String, - method: KiroLoginMethod, - ) -> Result { - let (code_verifier, code_challenge) = generate_pkce()?; - let callback = start_auth_code_callback(state.clone()).await?; - let proxy_url = self.app_proxy_url().await; - let client = SsoOidcClient::new(proxy_url.as_deref())?; - let reg = client - .register_client_for_auth_code(&callback.redirect_uri) - .await?; - let login_url = build_auth_code_url( - ®.client_id, - &callback.redirect_uri, - &state, - &code_challenge, - ); - let manager = self.clone(); - let state_for_task = state.clone(); - tauri::async_runtime::spawn(async move { - run_auth_code_login(manager, state_for_task, reg, code_verifier, callback).await; - }); - Ok(KiroLoginStartResponse { - state, - method, - login_url: Some(login_url), - verification_uri: None, - verification_uri_complete: None, - user_code: None, - interval_seconds: None, - expires_at: Some(expires_at_from_seconds(AUTH_CODE_TIMEOUT.as_secs() as i64)), - }) - } - - async fn start_social_login( - &self, - state: String, - method: KiroLoginMethod, - ) -> Result { - let provider = match method { - KiroLoginMethod::Google => "Google", - _ => "", - }; - let (code_verifier, code_challenge) = generate_pkce()?; - let login_url = build_login_url(provider, KIRO_REDIRECT_URI, &code_challenge, &state); - let manager = self.clone(); - let state_for_task = state.clone(); - tauri::async_runtime::spawn(async move { - run_social_login(manager, state_for_task, provider.to_string(), code_verifier).await; - }); - Ok(KiroLoginStartResponse { - state, - method, - login_url: Some(login_url), - verification_uri: None, - verification_uri_complete: None, - user_code: None, - interval_seconds: None, - expires_at: Some(expires_at_from_seconds(SOCIAL_CALLBACK_TIMEOUT.as_secs() as i64)), - }) - } - - async fn insert_session( - &self, - state: &str, - expires_at: Option, - ) { - let session = LoginSession { - status: KiroLoginStatus::Waiting, - error: None, - account: None, - expires_at, - }; - let mut guard = self.sessions.write().await; - guard.insert(state.to_string(), session); - } - - async fn complete_session(&self, state: &str, account: KiroAccountSummary) { - let mut guard = self.sessions.write().await; - if let Some(session) = guard.get_mut(state) { - session.status = KiroLoginStatus::Success; - session.error = None; - session.account = Some(account); - } - } - - async fn fail_session(&self, state: &str, message: String) { - let mut guard = self.sessions.write().await; - if let Some(session) = guard.get_mut(state) { - session.status = KiroLoginStatus::Error; - session.error = Some(message); - } - } - - async fn app_proxy_url(&self) -> Option { - self.app_proxy.read().await.clone() - } -} - -struct AuthCodeCallback { - redirect_uri: String, - receiver: tokio::sync::mpsc::Receiver, - shutdown: Option>, -} - -#[derive(Clone)] -struct AuthCodeResult { - code: Option, - state: Option, - error: Option, -} - -async fn start_auth_code_callback(state: String) -> Result { - let (tx, rx) = tokio::sync::mpsc::channel::(1); - let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>(); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .map_err(|err| format!("Failed to start callback server: {err}"))?; - let port = listener - .local_addr() - .map_err(|err| format!("Failed to read callback port: {err}"))? - .port(); - let redirect_uri = format!("http://127.0.0.1:{port}/oauth/callback"); - let router = axum::Router::new().route( - "/oauth/callback", - axum::routing::get(move |query: axum::extract::Query>| { - let expected_state = state.clone(); - let tx = tx.clone(); - async move { - let code = query.get("code").cloned(); - let state = query.get("state").cloned(); - let error = query.get("error").cloned(); - let has_error = error.is_some(); - let state_matches = state.as_deref() == Some(&expected_state); - let _ = tx.send(AuthCodeResult { code, state, error }).await; - let body = if has_error || !state_matches { - "Login failed. You can close this window." - } else { - "Login successful. You can close this window." - }; - axum::response::Html(body) - } - }), - ); - tauri::async_runtime::spawn(async move { - let _ = axum::serve(listener, router) - .with_graceful_shutdown(async move { - let _ = shutdown_rx.await; - }) - .await; - }); - Ok(AuthCodeCallback { - redirect_uri, - receiver: rx, - shutdown: Some(shutdown_tx), - }) -} - -async fn run_device_code_login( - manager: KiroLoginManager, - state: String, - reg: RegisterClientResponse, - auth: StartDeviceAuthResponse, -) { - let mut interval = Duration::from_secs(auth.interval.max(1) as u64); - let deadline = OffsetDateTime::now_utc() - + time::Duration::seconds(auth.expires_in.max(1)); - let proxy_url = manager.app_proxy_url().await; - let client = match SsoOidcClient::new(proxy_url.as_deref()) { - Ok(client) => client, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - - while OffsetDateTime::now_utc() < deadline { - tokio::time::sleep(interval).await; - match client - .create_token_device_code(®.client_id, ®.client_secret, &auth.device_code) - .await - { - Ok(token) => { - handle_builder_success(manager, state, token, reg).await; - return; - } - Err(TokenPollError::Pending) => continue, - Err(TokenPollError::SlowDown) => { - interval += Duration::from_secs(5); - continue; - } - Err(TokenPollError::Other(err)) => { - manager.fail_session(&state, err).await; - return; - } - } - } - manager.fail_session(&state, "Authorization timed out.".to_string()).await; -} - -async fn run_auth_code_login( - manager: KiroLoginManager, - state: String, - reg: RegisterClientResponse, - code_verifier: String, - mut callback: AuthCodeCallback, -) { - let redirect_uri = callback.redirect_uri.clone(); - let callback_result = match wait_for_auth_code(&mut callback).await { - Ok(result) => result, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let code = match extract_auth_code(&state, callback_result) { - Ok(code) => code, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let proxy_url = manager.app_proxy_url().await; - let client = match SsoOidcClient::new(proxy_url.as_deref()) { - Ok(client) => client, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let token = match client - .create_token_auth_code( - ®.client_id, - ®.client_secret, - &code, - &code_verifier, - &redirect_uri, - ) - .await - { - Ok(token) => token, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - handle_builder_success(manager, state, token, reg).await; -} - -async fn wait_for_auth_code(callback: &mut AuthCodeCallback) -> Result { - let shutdown = callback.shutdown.take(); - let result = tokio::time::timeout(AUTH_CODE_TIMEOUT, callback.receiver.recv()).await; - if let Some(shutdown) = shutdown { - let _ = shutdown.send(()); - } - match result { - Ok(Some(callback)) => Ok(callback), - Ok(None) => Err("Authorization callback closed.".to_string()), - Err(_) => Err("Authorization timed out.".to_string()), - } -} - -fn extract_auth_code(state: &str, callback_result: AuthCodeResult) -> Result { - if let Some(err) = callback_result.error { - return Err(err); - } - if callback_result.state.as_deref() != Some(state) { - return Err("OAuth state mismatch.".to_string()); - } - callback_result - .code - .ok_or_else(|| "Authorization code missing.".to_string()) -} - -async fn run_social_login( - manager: KiroLoginManager, - state: String, - provider: String, - code_verifier: String, -) { - let callback_path = callback_file_path(manager.store.dir(), &state); - let deadline = OffsetDateTime::now_utc() - + time::Duration::seconds(SOCIAL_CALLBACK_TIMEOUT.as_secs() as i64); - loop { - if OffsetDateTime::now_utc() > deadline { - manager - .fail_session(&state, "OAuth flow timed out.".to_string()) - .await; - return; - } - if tokio::fs::try_exists(&callback_path).await.unwrap_or(false) { - match read_callback_file(&callback_path).await { - Ok(payload) => { - let _ = tokio::fs::remove_file(&callback_path).await; - if let Some(err) = payload.error { - manager.fail_session(&state, err).await; - return; - } - if payload.state.as_deref() != Some(&state) { - manager - .fail_session(&state, "OAuth state mismatch.".to_string()) - .await; - return; - } - let Some(code) = payload.code else { - manager - .fail_session(&state, "Authorization code missing.".to_string()) - .await; - return; - }; - handle_social_success(manager, state, provider.clone(), code, code_verifier) - .await; - return; - } - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - } - } - tokio::time::sleep(Duration::from_millis(500)).await; - } -} - -async fn handle_builder_success( - manager: KiroLoginManager, - state: String, - token: CreateTokenResponse, - reg: RegisterClientResponse, -) { - let proxy_url = manager.app_proxy_url().await; - let profile_arn = match SsoOidcClient::new(proxy_url.as_deref()) { - Ok(client) => client.fetch_profile_arn(&token.access_token).await, - Err(_) => None, - }; - let record = KiroTokenRecord { - access_token: token.access_token, - refresh_token: token.refresh_token, - profile_arn, - expires_at: expires_at_from_seconds(token.expires_in), - auth_method: "builder-id".to_string(), - provider: "AWS".to_string(), - client_id: Some(reg.client_id), - client_secret: Some(reg.client_secret), - email: None, - last_refresh: Some(now_rfc3339()), - start_url: None, - region: None, - }; - match manager.store.save_new_account(record).await { - Ok(account) => manager.complete_session(&state, account).await, - Err(err) => manager.fail_session(&state, err).await, - } -} - -async fn handle_social_success( - manager: KiroLoginManager, - state: String, - provider: String, - code: String, - code_verifier: String, -) { - let proxy_url = manager.app_proxy_url().await; - let client = match KiroOAuthClient::new(proxy_url.as_deref()) { - Ok(client) => client, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let token = match client - .exchange_code(&code, &code_verifier, KIRO_REDIRECT_URI) - .await - { - Ok(token) => token, - Err(err) => { - manager.fail_session(&state, err).await; - return; - } - }; - let record = KiroTokenRecord { - access_token: token.access_token, - refresh_token: token.refresh_token, - profile_arn: token.profile_arn, - expires_at: expires_at_from_seconds(token.expires_in), - auth_method: "social".to_string(), - provider, - client_id: None, - client_secret: None, - email: None, - last_refresh: Some(now_rfc3339()), - start_url: None, - region: None, - }; - match manager.store.save_new_account(record).await { - Ok(account) => manager.complete_session(&state, account).await, - Err(err) => manager.fail_session(&state, err).await, - } -} diff --git a/src-tauri/src/kiro/oauth.rs b/src-tauri/src/kiro/oauth.rs deleted file mode 100644 index 603fd0e..0000000 --- a/src-tauri/src/kiro/oauth.rs +++ /dev/null @@ -1,135 +0,0 @@ -use reqwest::Client; -use serde::{Deserialize, Serialize}; - -use crate::oauth_util::build_reqwest_client; - -use super::types::KiroTokenRecord; -use super::util::{expires_at_from_seconds, now_rfc3339}; - -const KIRO_AUTH_ENDPOINT: &str = "https://prod.us-east-1.auth.desktop.kiro.dev"; -const KIRO_USER_AGENT: &str = "token-proxy/1.0.0"; - -#[derive(Clone)] -pub(crate) struct KiroOAuthClient { - http: Client, -} - -impl KiroOAuthClient { - pub(crate) fn new(proxy_url: Option<&str>) -> Result { - let http = build_reqwest_client(proxy_url, std::time::Duration::from_secs(30)) - .map_err(|err| format!("Failed to build Kiro OAuth client: {err}"))?; - Ok(Self { http }) - } - - pub(crate) async fn exchange_code( - &self, - code: &str, - code_verifier: &str, - redirect_uri: &str, - ) -> Result { - let payload = CreateTokenRequest { - code: code.to_string(), - code_verifier: code_verifier.to_string(), - redirect_uri: redirect_uri.to_string(), - }; - self.post_json("/oauth/token", &payload).await - } - - pub(crate) async fn refresh_token(&self, refresh_token: &str) -> Result { - let payload = RefreshTokenRequest { - refresh_token: refresh_token.to_string(), - }; - self.post_json("/refreshToken", &payload).await - } - - async fn post_json Deserialize<'de>>( - &self, - path: &str, - payload: &TReq, - ) -> Result { - let url = format!("{KIRO_AUTH_ENDPOINT}{path}"); - let response = self - .http - .post(url) - .header("Content-Type", "application/json") - .header("User-Agent", KIRO_USER_AGENT) - .json(payload) - .send() - .await - .map_err(|err| format!("Kiro OAuth request failed: {err}"))?; - let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(|err| format!("Failed to read Kiro OAuth response: {err}"))?; - if !status.is_success() { - return Err(format!( - "Kiro OAuth request failed (status {})", - status.as_u16() - )); - } - serde_json::from_slice(&bytes) - .map_err(|err| format!("Failed to parse Kiro OAuth response: {err}")) - } -} - -pub(crate) fn build_login_url( - provider: &str, - redirect_uri: &str, - code_challenge: &str, - state: &str, -) -> String { - let query = url::form_urlencoded::Serializer::new(String::new()) - .append_pair("idp", provider) - .append_pair("redirect_uri", redirect_uri) - .append_pair("code_challenge", code_challenge) - .append_pair("code_challenge_method", "S256") - .append_pair("state", state) - .append_pair("prompt", "select_account") - .finish(); - format!("{KIRO_AUTH_ENDPOINT}/login?{query}") -} - -pub(crate) async fn refresh_social_token( - record: &KiroTokenRecord, - proxy_url: Option<&str>, -) -> Result { - let client = KiroOAuthClient::new(proxy_url)?; - let response = client.refresh_token(&record.refresh_token).await?; - Ok(KiroTokenRecord { - access_token: response.access_token, - refresh_token: response.refresh_token, - profile_arn: response.profile_arn, - expires_at: expires_at_from_seconds(response.expires_in), - auth_method: "social".to_string(), - provider: record.provider.clone(), - client_id: record.client_id.clone(), - client_secret: record.client_secret.clone(), - email: record.email.clone(), - last_refresh: Some(now_rfc3339()), - start_url: record.start_url.clone(), - region: record.region.clone(), - }) -} - -#[derive(Serialize)] -struct CreateTokenRequest { - code: String, - code_verifier: String, - redirect_uri: String, -} - -#[derive(Serialize)] -struct RefreshTokenRequest { - #[serde(rename = "refreshToken")] - refresh_token: String, -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroTokenResponse { - pub(crate) access_token: String, - pub(crate) refresh_token: String, - pub(crate) profile_arn: Option, - pub(crate) expires_in: i64, -} diff --git a/src-tauri/src/kiro/quota.rs b/src-tauri/src/kiro/quota.rs deleted file mode 100644 index 54bb88f..0000000 --- a/src-tauri/src/kiro/quota.rs +++ /dev/null @@ -1,235 +0,0 @@ -use serde::Serialize; -use serde_json::{Map, Value}; - -use crate::oauth_util::build_reqwest_client; - -use super::store::KiroAccountStore; -use super::types::KiroAccountSummary; - -const KIRO_USAGE_ENDPOINT: &str = "https://codewhisperer.us-east-1.amazonaws.com"; -const KIRO_USAGE_TARGET: &str = "AmazonCodeWhispererService.GetUsageLimits"; -const KIRO_USAGE_ORIGIN: &str = "AI_EDITOR"; -const KIRO_USAGE_RESOURCE_TYPE: &str = "AGENTIC_REQUEST"; -const KIRO_CONTENT_TYPE: &str = "application/x-amz-json-1.0"; -const KIRO_ACCEPT: &str = "application/json"; - -#[derive(Clone, Serialize)] -pub(crate) struct KiroQuotaItem { - pub(crate) name: String, - pub(crate) percentage: f64, - pub(crate) used: Option, - pub(crate) limit: Option, - pub(crate) reset_at: Option, - pub(crate) is_trial: bool, -} - -#[derive(Clone, Serialize)] -pub(crate) struct KiroQuotaSummary { - pub(crate) account_id: String, - pub(crate) provider: String, - pub(crate) plan_type: Option, - pub(crate) quotas: Vec, - pub(crate) error: Option, -} - -pub(crate) async fn fetch_quotas( - store: &KiroAccountStore, -) -> Result, String> { - let accounts = store.list_accounts().await?; - let proxy_url = store.app_proxy_url().await; - let mut results = Vec::with_capacity(accounts.len()); - for account in accounts { - match fetch_account_quota(store, &account, proxy_url.as_deref()).await { - Ok(summary) => results.push(summary), - Err(err) => results.push(KiroQuotaSummary { - account_id: account.account_id.clone(), - provider: account.provider.clone(), - plan_type: None, - quotas: Vec::new(), - error: Some(err), - }), - } - } - Ok(results) -} - -async fn fetch_account_quota( - store: &KiroAccountStore, - account: &KiroAccountSummary, - proxy_url: Option<&str>, -) -> Result { - let record = store.get_account_record(&account.account_id).await?; - let profile_arn = record - .profile_arn - .as_deref() - .ok_or_else(|| "Missing Kiro profile ARN.".to_string())?; - let response = request_usage_limits(&record.access_token, profile_arn, proxy_url).await?; - Ok(map_usage_response(account, &response)) -} - -async fn request_usage_limits( - access_token: &str, - profile_arn: &str, - proxy_url: Option<&str>, -) -> Result { - let http = build_reqwest_client(proxy_url, std::time::Duration::from_secs(30)) - .map_err(|err| format!("Failed to build Kiro usage client: {err}"))?; - let payload = serde_json::json!({ - "origin": KIRO_USAGE_ORIGIN, - "profileArn": profile_arn, - "resourceType": KIRO_USAGE_RESOURCE_TYPE, - }); - let response = http - .post(KIRO_USAGE_ENDPOINT) - .header("Authorization", format!("Bearer {access_token}")) - .header("Content-Type", KIRO_CONTENT_TYPE) - .header("x-amz-target", KIRO_USAGE_TARGET) - .header("Accept", KIRO_ACCEPT) - .json(&payload) - .send() - .await - .map_err(|err| format!("Kiro usage request failed: {err}"))?; - let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(|err| format!("Failed to read Kiro usage response: {err}"))?; - if !status.is_success() { - let body = String::from_utf8_lossy(&bytes); - return Err(format!( - "Kiro usage request failed (status {}): {}", - status.as_u16(), - body - )); - } - serde_json::from_slice(&bytes).map_err(|err| format!("Invalid Kiro usage response: {err}")) -} - -fn map_usage_response(account: &KiroAccountSummary, value: &Value) -> KiroQuotaSummary { - let plan_type = value - .get("subscriptionInfo") - .and_then(Value::as_object) - .and_then(|info| get_string(info, "subscriptionTitle")); - let reset_at = extract_reset_at(value); - let quotas = value - .get("usageBreakdownList") - .and_then(Value::as_array) - .map(|list| build_quota_items(list, reset_at.as_deref())) - .unwrap_or_default(); - - KiroQuotaSummary { - account_id: account.account_id.clone(), - provider: account.provider.clone(), - plan_type, - quotas, - error: None, - } -} - -fn build_quota_items(items: &[Value], reset_at: Option<&str>) -> Vec { - let mut quotas = Vec::new(); - for item in items { - let Some(obj) = item.as_object() else { - continue; - }; - let display_name = get_string(obj, "displayName") - .or_else(|| get_string(obj, "resourceType")) - .unwrap_or_else(|| "Usage".to_string()); - let (base_used, base_limit) = extract_usage_values(obj); - let trial_info = obj.get("freeTrialInfo").and_then(Value::as_object); - let trial_active = trial_info - .and_then(|info| get_string(info, "freeTrialStatus")) - .map(|status| status.eq_ignore_ascii_case("ACTIVE")) - .unwrap_or(false); - - if trial_active { - if let Some(info) = trial_info { - let (trial_used, trial_limit) = extract_usage_values(info); - let trial_reset = get_string(info, "freeTrialExpiry"); - if let Some(item) = build_quota_item( - format!("Bonus {display_name}"), - trial_used, - trial_limit, - trial_reset, - true, - ) { - quotas.push(item); - } - } - } - - let base_name = if trial_active { - format!("{display_name} (Base)") - } else { - display_name - }; - let base_reset = reset_at.map(|val| val.to_string()); - if let Some(item) = build_quota_item(base_name, base_used, base_limit, base_reset, false) { - quotas.push(item); - } - } - quotas -} - -fn build_quota_item( - name: String, - used: Option, - limit: Option, - reset_at: Option, - is_trial: bool, -) -> Option { - if used.is_none() && limit.is_none() { - return None; - } - let percentage = calc_percentage(used, limit); - Some(KiroQuotaItem { - name, - percentage, - used, - limit, - reset_at, - is_trial, - }) -} - -fn extract_usage_values(obj: &Map) -> (Option, Option) { - let used = get_f64(obj, "currentUsageWithPrecision").or_else(|| get_f64(obj, "currentUsage")); - let limit = get_f64(obj, "usageLimitWithPrecision").or_else(|| get_f64(obj, "usageLimit")); - (used, limit) -} - -fn get_string(obj: &Map, key: &str) -> Option { - obj.get(key).and_then(Value::as_str).map(|val| val.to_string()) -} - -fn get_f64(obj: &Map, key: &str) -> Option { - obj.get(key).and_then(as_f64) -} - -fn as_f64(value: &Value) -> Option { - match value { - Value::Number(num) => num.as_f64(), - Value::String(val) => val.parse::().ok(), - _ => None, - } -} - -fn calc_percentage(used: Option, limit: Option) -> f64 { - let (Some(used), Some(limit)) = (used, limit) else { - return 0.0; - }; - if limit <= 0.0 { - return 0.0; - } - let remaining = (limit - used) / limit * 100.0; - remaining.clamp(0.0, 100.0) -} - -fn extract_reset_at(value: &Value) -> Option { - let reset = value.get("nextDateReset")?; - match reset { - Value::String(val) => Some(val.to_string()), - Value::Number(val) => Some(val.to_string()), - _ => None, - } -} diff --git a/src-tauri/src/kiro/sso_oidc.rs b/src-tauri/src/kiro/sso_oidc.rs deleted file mode 100644 index bd3b0a5..0000000 --- a/src-tauri/src/kiro/sso_oidc.rs +++ /dev/null @@ -1,522 +0,0 @@ -use reqwest::Client; -use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; -use url::form_urlencoded; - -use crate::oauth_util::build_reqwest_client; - -use super::types::KiroTokenRecord; -use super::util::{expires_at_from_seconds, now_rfc3339}; - -const SSO_OIDC_ENDPOINT: &str = "https://oidc.us-east-1.amazonaws.com"; -const BUILDER_ID_START_URL: &str = "https://view.awsapps.com/start"; -const KIRO_USER_AGENT: &str = "KiroIDE"; -const DEFAULT_IDC_REGION: &str = "us-east-1"; -const IDC_AMZ_USER_AGENT: &str = - "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"; -const IDC_USER_AGENT: &str = "node"; -const CODEWHISPERER_ENDPOINT: &str = "https://codewhisperer.us-east-1.amazonaws.com"; -const CODEWHISPERER_CONTENT_TYPE: &str = "application/x-amz-json-1.0"; -const CODEWHISPERER_ACCEPT: &str = "application/json"; -const CW_TARGET_LIST_PROFILES: &str = "AmazonCodeWhispererService.ListProfiles"; -const CW_TARGET_LIST_CUSTOMIZATIONS: &str = "AmazonCodeWhispererService.ListAvailableCustomizations"; -const DEFAULT_SCOPES: [&str; 5] = [ - "codewhisperer:completions", - "codewhisperer:analysis", - "codewhisperer:conversations", - "codewhisperer:transformations", - "codewhisperer:taskassist", -]; -const AUTH_CODE_SCOPES: [&str; 3] = [ - "codewhisperer:completions", - "codewhisperer:analysis", - "codewhisperer:conversations", -]; - -#[derive(Clone)] -pub(crate) struct SsoOidcClient { - http: Client, -} - -impl SsoOidcClient { - pub(crate) fn new(proxy_url: Option<&str>) -> Result { - let http = build_reqwest_client(proxy_url, std::time::Duration::from_secs(30)) - .map_err(|err| format!("Failed to build OIDC client: {err}"))?; - Ok(Self { http }) - } - - pub(crate) async fn register_client(&self) -> Result { - let payload = RegisterClientRequest { - client_name: "Kiro IDE".to_string(), - client_type: "public".to_string(), - scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(), - grant_types: vec![ - "urn:ietf:params:oauth:grant-type:device_code".to_string(), - "refresh_token".to_string(), - ], - redirect_uris: None, - issuer_url: None, - }; - self.post_json("/client/register", &payload).await - } - - pub(crate) async fn register_client_for_auth_code( - &self, - redirect_uri: &str, - ) -> Result { - let payload = RegisterClientRequest { - client_name: "Kiro IDE".to_string(), - client_type: "public".to_string(), - scopes: DEFAULT_SCOPES.iter().map(|s| s.to_string()).collect(), - grant_types: vec!["authorization_code".to_string(), "refresh_token".to_string()], - redirect_uris: Some(vec![redirect_uri.to_string()]), - issuer_url: Some(BUILDER_ID_START_URL.to_string()), - }; - self.post_json("/client/register", &payload).await - } - - pub(crate) async fn start_device_authorization( - &self, - client_id: &str, - client_secret: &str, - ) -> Result { - let payload = StartDeviceAuthRequest { - client_id: client_id.to_string(), - client_secret: client_secret.to_string(), - start_url: BUILDER_ID_START_URL.to_string(), - }; - self.post_json("/device_authorization", &payload).await - } - - pub(crate) async fn create_token_device_code( - &self, - client_id: &str, - client_secret: &str, - device_code: &str, - ) -> Result { - let payload = CreateTokenDeviceCodeRequest { - client_id: client_id.to_string(), - client_secret: client_secret.to_string(), - device_code: device_code.to_string(), - grant_type: "urn:ietf:params:oauth:grant-type:device_code".to_string(), - }; - self.post_json_result("/token", &payload).await - } - - pub(crate) async fn create_token_auth_code( - &self, - client_id: &str, - client_secret: &str, - code: &str, - code_verifier: &str, - redirect_uri: &str, - ) -> Result { - let payload = CreateTokenAuthCodeRequest { - client_id: client_id.to_string(), - client_secret: client_secret.to_string(), - code: code.to_string(), - code_verifier: code_verifier.to_string(), - redirect_uri: redirect_uri.to_string(), - grant_type: "authorization_code".to_string(), - }; - self.post_json("/token", &payload).await - } - - pub(crate) async fn refresh_token_with_region( - &self, - client_id: &str, - client_secret: &str, - refresh_token: &str, - region: &str, - ) -> Result { - let payload = RefreshTokenRequest { - client_id: client_id.to_string(), - client_secret: client_secret.to_string(), - refresh_token: refresh_token.to_string(), - grant_type: "refresh_token".to_string(), - }; - let endpoint = oidc_endpoint_for_region(region); - let url = format!("{endpoint}/token"); - let host = format!("oidc.{region}.amazonaws.com"); - let response = self - .http - .post(url) - .header("Content-Type", "application/json") - .header("Host", host) - .header("Connection", "keep-alive") - .header("x-amz-user-agent", IDC_AMZ_USER_AGENT) - .header("Accept", "*/*") - .header("Accept-Language", "*") - .header("sec-fetch-mode", "cors") - .header("User-Agent", IDC_USER_AGENT) - .header("Accept-Encoding", "br, gzip, deflate") - .json(&payload) - .send() - .await - .map_err(|err| format!("IDC refresh request failed: {err}"))?; - let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(|err| format!("Failed to read IDC refresh response: {err}"))?; - if !status.is_success() { - return Err(format!( - "IDC token refresh failed (status {})", - status.as_u16() - )); - } - serde_json::from_slice(&bytes) - .map_err(|err| format!("Failed to parse IDC refresh response: {err}")) - } - - pub(crate) async fn fetch_profile_arn(&self, access_token: &str) -> Option { - if let Some(arn) = self.try_list_profiles(access_token).await { - return Some(arn); - } - self.try_list_customizations(access_token).await - } - - async fn try_list_profiles(&self, access_token: &str) -> Option { - let payload = json!({"origin": "AI_EDITOR"}); - let value = self - .post_codewhisperer(access_token, CW_TARGET_LIST_PROFILES, &payload) - .await - .ok()?; - parse_profile_arn_from_profiles(&value) - } - - async fn try_list_customizations(&self, access_token: &str) -> Option { - let payload = json!({"origin": "AI_EDITOR"}); - let value = self - .post_codewhisperer(access_token, CW_TARGET_LIST_CUSTOMIZATIONS, &payload) - .await - .ok()?; - parse_profile_arn_from_customizations(&value) - } - - async fn post_codewhisperer( - &self, - access_token: &str, - target: &str, - payload: &Value, - ) -> Result { - let response = self - .http - .post(CODEWHISPERER_ENDPOINT) - .header("Content-Type", CODEWHISPERER_CONTENT_TYPE) - .header("x-amz-target", target) - .header("Authorization", format!("Bearer {access_token}")) - .header("Accept", CODEWHISPERER_ACCEPT) - .json(payload) - .send() - .await - .map_err(|err| format!("CodeWhisperer request failed: {err}"))?; - let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(|err| format!("Failed to read CodeWhisperer response: {err}"))?; - if !status.is_success() { - return Err(format!( - "CodeWhisperer request failed (status {})", - status.as_u16() - )); - } - serde_json::from_slice(&bytes) - .map_err(|err| format!("Failed to parse CodeWhisperer response: {err}")) - } - - pub(crate) async fn refresh_builder_token( - &self, - client_id: &str, - client_secret: &str, - refresh_token: &str, - ) -> Result { - let payload = RefreshTokenRequest { - client_id: client_id.to_string(), - client_secret: client_secret.to_string(), - refresh_token: refresh_token.to_string(), - grant_type: "refresh_token".to_string(), - }; - let response: CreateTokenResponse = self.post_json("/token", &payload).await?; - Ok(KiroTokenRecord { - access_token: response.access_token, - refresh_token: response.refresh_token, - profile_arn: None, - expires_at: expires_at_from_seconds(response.expires_in), - auth_method: "builder-id".to_string(), - provider: "AWS".to_string(), - client_id: Some(client_id.to_string()), - client_secret: Some(client_secret.to_string()), - email: None, - last_refresh: Some(now_rfc3339()), - start_url: None, - region: None, - }) - } - - async fn post_json Deserialize<'de>>( - &self, - path: &str, - payload: &TReq, - ) -> Result { - self.post_json_result(path, payload).await.map_err(|err| match err { - TokenPollError::Pending => "Authorization pending.".to_string(), - TokenPollError::SlowDown => "Slow down.".to_string(), - TokenPollError::Other(message) => message, - }) - } - - async fn post_json_result Deserialize<'de>>( - &self, - path: &str, - payload: &TReq, - ) -> Result { - let url = format!("{SSO_OIDC_ENDPOINT}{path}"); - let response = self - .http - .post(url) - .header("Content-Type", "application/json") - .header("User-Agent", KIRO_USER_AGENT) - .json(payload) - .send() - .await - .map_err(|err| TokenPollError::Other(format!("OIDC request failed: {err}")))?; - let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(|err| TokenPollError::Other(format!("Failed to read OIDC response: {err}")))?; - if status.is_success() { - return serde_json::from_slice(&bytes) - .map_err(|err| TokenPollError::Other(format!("Failed to parse OIDC response: {err}"))); - } - if status == reqwest::StatusCode::BAD_REQUEST { - if let Ok(error) = serde_json::from_slice::(&bytes) { - return match error.error.as_str() { - "authorization_pending" => Err(TokenPollError::Pending), - "slow_down" => Err(TokenPollError::SlowDown), - _ => Err(TokenPollError::Other(format!( - "OIDC error: {}", - error.error - ))), - }; - } - } - Err(TokenPollError::Other(format!( - "OIDC request failed (status {})", - status.as_u16() - ))) - } -} - -pub(crate) async fn refresh_builder_token( - record: &KiroTokenRecord, - proxy_url: Option<&str>, -) -> Result { - let client_id = record - .client_id - .as_deref() - .ok_or_else(|| "Missing OIDC client_id.".to_string())?; - let client_secret = record - .client_secret - .as_deref() - .ok_or_else(|| "Missing OIDC client_secret.".to_string())?; - let client = SsoOidcClient::new(proxy_url)?; - client - .refresh_builder_token(client_id, client_secret, &record.refresh_token) - .await -} - -pub(crate) async fn refresh_idc_token( - record: &KiroTokenRecord, - proxy_url: Option<&str>, -) -> Result { - let client_id = record - .client_id - .as_deref() - .ok_or_else(|| "Missing OIDC client_id.".to_string())?; - let client_secret = record - .client_secret - .as_deref() - .ok_or_else(|| "Missing OIDC client_secret.".to_string())?; - let region = record - .region - .as_deref() - .filter(|value| !value.trim().is_empty()) - .unwrap_or(DEFAULT_IDC_REGION); - let client = SsoOidcClient::new(proxy_url)?; - let response = client - .refresh_token_with_region( - client_id, - client_secret, - &record.refresh_token, - region, - ) - .await?; - Ok(KiroTokenRecord { - access_token: response.access_token, - refresh_token: response.refresh_token, - profile_arn: record.profile_arn.clone(), - expires_at: expires_at_from_seconds(response.expires_in), - auth_method: "idc".to_string(), - provider: "AWS".to_string(), - client_id: Some(client_id.to_string()), - client_secret: Some(client_secret.to_string()), - email: record.email.clone(), - last_refresh: Some(now_rfc3339()), - start_url: record.start_url.clone(), - region: Some(region.to_string()), - }) -} - -pub(crate) fn build_auth_code_url( - client_id: &str, - redirect_uri: &str, - state: &str, - code_challenge: &str, -) -> String { - let scopes = AUTH_CODE_SCOPES.join(","); - let query = form_urlencoded::Serializer::new(String::new()) - .append_pair("response_type", "code") - .append_pair("client_id", client_id) - .append_pair("redirect_uri", redirect_uri) - .append_pair("scopes", &scopes) - .append_pair("state", state) - .append_pair("code_challenge", code_challenge) - .append_pair("code_challenge_method", "S256") - .finish(); - format!("{SSO_OIDC_ENDPOINT}/authorize?{query}") -} - -fn oidc_endpoint_for_region(region: &str) -> String { - let trimmed = region.trim(); - let region = if trimmed.is_empty() { - DEFAULT_IDC_REGION - } else { - trimmed - }; - format!("https://oidc.{region}.amazonaws.com") -} - -fn parse_profile_arn_from_profiles(value: &Value) -> Option { - value - .get("profileArn") - .and_then(Value::as_str) - .map(|value| value.to_string()) - .or_else(|| { - value - .get("profiles") - .and_then(Value::as_array) - .and_then(|items| items.first()) - .and_then(|item| item.get("arn")) - .and_then(Value::as_str) - .map(|value| value.to_string()) - }) -} - -fn parse_profile_arn_from_customizations(value: &Value) -> Option { - value - .get("profileArn") - .and_then(Value::as_str) - .map(|value| value.to_string()) - .or_else(|| { - value - .get("customizations") - .and_then(Value::as_array) - .and_then(|items| items.first()) - .and_then(|item| item.get("arn")) - .and_then(Value::as_str) - .map(|value| value.to_string()) - }) -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct RegisterClientRequest { - client_name: String, - client_type: String, - scopes: Vec, - grant_types: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - redirect_uris: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - issuer_url: Option, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct StartDeviceAuthRequest { - client_id: String, - client_secret: String, - start_url: String, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct CreateTokenDeviceCodeRequest { - client_id: String, - client_secret: String, - device_code: String, - grant_type: String, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct CreateTokenAuthCodeRequest { - client_id: String, - client_secret: String, - code: String, - code_verifier: String, - redirect_uri: String, - grant_type: String, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -struct RefreshTokenRequest { - client_id: String, - client_secret: String, - refresh_token: String, - grant_type: String, -} - -#[derive(Debug)] -pub(crate) enum TokenPollError { - Pending, - SlowDown, - Other(String), -} - - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -struct OidcError { - error: String, -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct RegisterClientResponse { - pub(crate) client_id: String, - pub(crate) client_secret: String, -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct StartDeviceAuthResponse { - pub(crate) device_code: String, - pub(crate) user_code: String, - pub(crate) verification_uri: String, - pub(crate) verification_uri_complete: String, - pub(crate) expires_in: i64, - pub(crate) interval: i64, -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct CreateTokenResponse { - pub(crate) access_token: String, - pub(crate) refresh_token: String, - pub(crate) expires_in: i64, -} diff --git a/src-tauri/src/kiro/store.rs b/src-tauri/src/kiro/store.rs deleted file mode 100644 index 147f746..0000000 --- a/src-tauri/src/kiro/store.rs +++ /dev/null @@ -1,514 +0,0 @@ -use std::collections::HashMap; -use std::path::{Path, PathBuf}; - -use serde::Deserialize; -use tauri::AppHandle; -use time::format_description::well_known::Rfc3339; -use time::OffsetDateTime; -use tokio::sync::RwLock; - -use crate::app_proxy::AppProxyState; -use crate::proxy::config::config_dir_path; - -use super::oauth; -use super::sso_oidc; -use super::types::{KiroAccountStatus, KiroAccountSummary, KiroTokenRecord}; -use super::util::{expires_at_from_seconds, extract_email_from_jwt, now_rfc3339, sanitize_id_part}; - -const KIRO_AUTH_DIR_NAME: &str = "kiro-auth"; - -pub(crate) struct KiroAccountStore { - dir: PathBuf, - cache: RwLock>, - app_proxy: AppProxyState, -} - -impl KiroAccountStore { - pub(crate) fn new(app: &AppHandle, app_proxy: AppProxyState) -> Result { - let dir = config_dir_path(app)?.join(KIRO_AUTH_DIR_NAME); - Ok(Self { - dir, - cache: RwLock::new(HashMap::new()), - app_proxy, - }) - } - - pub(crate) fn dir(&self) -> &Path { - &self.dir - } - - pub(crate) async fn import_ide_tokens( - &self, - directory: PathBuf, - ) -> Result, String> { - if directory.as_os_str().is_empty() { - return Err("Directory is required.".to_string()); - } - let mut entries = match tokio::fs::read_dir(&directory).await { - Ok(entries) => entries, - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - return Err("Selected directory not found.".to_string()); - } - Err(err) => { - return Err(format!("Failed to read selected directory: {err}")); - } - }; - let mut imported = Vec::new(); - // 仅扫描所选目录本层的 JSON 文件,忽略无效内容。 - while let Some(entry) = entries - .next_entry() - .await - .map_err(|err| format!("Failed to read directory entry: {err}"))? - { - let path = entry.path(); - let file_type = entry - .file_type() - .await - .map_err(|err| format!("Failed to read entry type: {err}"))?; - if !file_type.is_file() || !is_json_file(&path) { - continue; - } - let Some(record) = load_ide_token_record(&path).await else { - continue; - }; - if let Ok(summary) = self.save_new_account(record).await { - imported.push(summary); - } - } - if imported.is_empty() { - return Err("No valid Kiro token JSON files found.".to_string()); - } - Ok(imported) - } - - pub(crate) async fn import_kam_export( - &self, - path: PathBuf, - ) -> Result, String> { - if path.as_os_str().is_empty() { - return Err("File path is required.".to_string()); - } - if !tokio::fs::try_exists(&path).await.unwrap_or(false) { - return Err("Selected file not found.".to_string()); - } - let contents = tokio::fs::read_to_string(&path) - .await - .map_err(|err| format!("Failed to read JSON file: {err}"))?; - let data: KamExportData = serde_json::from_str(&contents) - .map_err(|err| format!("Invalid Kiro account JSON file: {err}"))?; - let mut imported = Vec::new(); - for account in data.accounts { - let Some(record) = kam_account_to_record(account) else { - continue; - }; - if let Ok(summary) = self.save_new_account(record).await { - imported.push(summary); - } - } - if imported.is_empty() { - return Err("No valid Kiro accounts found in JSON file.".to_string()); - } - Ok(imported) - } - - pub(crate) async fn list_accounts(&self) -> Result, String> { - self.refresh_cache().await?; - let cache = self.cache.read().await; - let mut items: Vec = cache - .iter() - .map(|(account_id, record)| KiroAccountSummary { - account_id: account_id.clone(), - provider: record.provider.clone(), - auth_method: record.auth_method.clone(), - email: record.email.clone(), - expires_at: record.expires_at().map(|value| value.format(&time::format_description::well_known::Rfc3339).unwrap_or_else(|_| record.expires_at.clone())), - status: record.status(), - }) - .collect(); - items.sort_by(|left, right| left.account_id.cmp(&right.account_id)); - Ok(items) - } - - pub(crate) async fn get_access_token(&self, account_id: &str) -> Result { - let record = self.load_account(account_id).await?; - let refreshed = self.refresh_if_needed(account_id, record).await?; - Ok(refreshed.access_token) - } - - pub(crate) async fn get_account_record( - &self, - account_id: &str, - ) -> Result { - let record = self.load_account(account_id).await?; - self.refresh_if_needed(account_id, record).await - } - - pub(crate) async fn refresh_account(&self, account_id: &str) -> Result<(), String> { - let record = self.load_account(account_id).await?; - let refreshed = self.refresh_record(account_id, record).await?; - let summary = self - .save_record(account_id.to_string(), refreshed) - .await?; - if matches!(summary.status, KiroAccountStatus::Expired) { - return Err("Kiro token refresh failed.".to_string()); - } - Ok(()) - } - - pub(crate) async fn save_record( - &self, - account_id: String, - record: KiroTokenRecord, - ) -> Result { - self.ensure_dir().await?; - let path = self.account_path(&account_id); - let payload = serde_json::to_string_pretty(&record) - .map_err(|err| format!("Failed to serialize token record: {err}"))?; - tokio::fs::write(&path, payload) - .await - .map_err(|err| format!("Failed to write token record: {err}"))?; - let mut cache = self.cache.write().await; - cache.insert(account_id.clone(), record.clone()); - Ok(KiroAccountSummary { - account_id, - provider: record.provider.clone(), - auth_method: record.auth_method.clone(), - email: record.email.clone(), - expires_at: record.expires_at().map(|value| value.format(&time::format_description::well_known::Rfc3339).unwrap_or_else(|_| record.expires_at.clone())), - status: record.status(), - }) - } - - pub(crate) async fn save_new_account( - &self, - mut record: KiroTokenRecord, - ) -> Result { - if record.email.is_none() { - record.email = extract_email_from_jwt(&record.access_token); - } - let provider = record.provider.trim().to_ascii_lowercase(); - let id_part_source = record - .email - .as_deref() - .or(record.profile_arn.as_deref()) - .unwrap_or_default(); - let mut id_part = sanitize_id_part(id_part_source); - if id_part.is_empty() { - id_part = format!("{}", OffsetDateTime::now_utc().unix_timestamp()); - } - let account_id = self.unique_account_id(&provider, &id_part).await?; - self.save_record(account_id, record).await - } - - pub(crate) async fn delete_account(&self, account_id: &str) -> Result<(), String> { - let path = self.account_path(account_id); - if tokio::fs::try_exists(&path).await.unwrap_or(false) { - tokio::fs::remove_file(&path) - .await - .map_err(|err| format!("Failed to delete token record: {err}"))?; - } - let mut cache = self.cache.write().await; - cache.remove(account_id); - Ok(()) - } - - async fn refresh_if_needed( - &self, - account_id: &str, - record: KiroTokenRecord, - ) -> Result { - if !record.is_expired() { - return Ok(record); - } - self.refresh_record(account_id, record).await - } - - async fn refresh_record( - &self, - account_id: &str, - record: KiroTokenRecord, - ) -> Result { - let proxy_url = self.app_proxy_url().await; - let refreshed = match record.auth_method.as_str() { - "builder-id" => sso_oidc::refresh_builder_token(&record, proxy_url.as_deref()).await?, - "idc" => sso_oidc::refresh_idc_token(&record, proxy_url.as_deref()).await?, - "social" => oauth::refresh_social_token(&record, proxy_url.as_deref()).await?, - _ => return Err("Unsupported Kiro auth method.".to_string()), - }; - let summary = self.save_record(account_id.to_string(), refreshed.clone()).await?; - if matches!(summary.status, KiroAccountStatus::Expired) { - return Err("Kiro token refresh failed.".to_string()); - } - Ok(refreshed) - } - - async fn load_account(&self, account_id: &str) -> Result { - if let Some(record) = self.cache.read().await.get(account_id).cloned() { - return Ok(record); - } - self.refresh_cache().await?; - self.cache - .read() - .await - .get(account_id) - .cloned() - .ok_or_else(|| format!("Kiro account not found: {account_id}")) - } - - pub(crate) async fn app_proxy_url(&self) -> Option { - self.app_proxy.read().await.clone() - } - - async fn refresh_cache(&self) -> Result<(), String> { - let mut cache = HashMap::new(); - let dir = self.dir.clone(); - let mut entries = match tokio::fs::read_dir(&dir).await { - Ok(entries) => entries, - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - let mut guard = self.cache.write().await; - guard.clear(); - return Ok(()); - } - Err(err) => return Err(format!("Failed to read Kiro auth directory: {err}")), - }; - - while let Some(entry) = entries - .next_entry() - .await - .map_err(|err| format!("Failed to read Kiro auth entry: {err}"))? - { - let path = entry.path(); - if path.extension().and_then(|ext| ext.to_str()) != Some("json") { - continue; - } - let file_name = match path.file_name().and_then(|name| name.to_str()) { - Some(name) => name.to_string(), - None => continue, - }; - let contents = match tokio::fs::read_to_string(&path).await { - Ok(contents) => contents, - Err(_) => continue, - }; - let record: KiroTokenRecord = match serde_json::from_str(&contents) { - Ok(record) => record, - Err(_) => continue, - }; - cache.insert(file_name, record); - } - - let mut guard = self.cache.write().await; - *guard = cache; - Ok(()) - } - - async fn ensure_dir(&self) -> Result<(), String> { - tokio::fs::create_dir_all(&self.dir) - .await - .map_err(|err| format!("Failed to create Kiro auth dir: {err}")) - } - - async fn unique_account_id(&self, provider: &str, id_part: &str) -> Result { - self.ensure_dir().await?; - let mut suffix = 0u32; - loop { - let candidate = if suffix == 0 { - format!("kiro-{provider}-{id_part}.json") - } else { - format!("kiro-{provider}-{id_part}-{suffix}.json") - }; - if !tokio::fs::try_exists(self.account_path(&candidate)) - .await - .unwrap_or(false) - { - return Ok(candidate); - } - suffix += 1; - } - } - - fn account_path(&self, account_id: &str) -> PathBuf { - self.dir.join(account_id) - } -} - -fn is_json_file(path: &Path) -> bool { - path.extension() - .and_then(|ext| ext.to_str()) - .is_some_and(|ext| ext.eq_ignore_ascii_case("json")) -} - -async fn load_ide_token_record(path: &Path) -> Option { - let contents = tokio::fs::read_to_string(path).await.ok()?; - let token: KiroIdeTokenFile = serde_json::from_str(&contents).ok()?; - token.into_record().ok() -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -struct KamExportData { - accounts: Vec, -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -struct KamAccount { - email: Option, - idp: Option, - credentials: Option, -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -struct KamCredentials { - access_token: Option, - refresh_token: Option, - client_id: Option, - client_secret: Option, - region: Option, - start_url: Option, - expires_at: Option, - auth_method: Option, - provider: Option, -} - -fn kam_account_to_record(account: KamAccount) -> Option { - let credentials = account.credentials?; - let access_token = credentials.access_token?.trim().to_string(); - let refresh_token = credentials.refresh_token?.trim().to_string(); - if access_token.is_empty() || refresh_token.is_empty() { - return None; - } - let provider = credentials - .provider - .filter(|value| !value.trim().is_empty()) - .or(account.idp.filter(|value| !value.trim().is_empty())) - .unwrap_or_else(|| "AWS".to_string()); - let auth_method = normalize_auth_method( - credentials.auth_method.as_deref(), - Some(provider.as_str()), - ); - let expires_at = credentials - .expires_at - .and_then(format_expires_at) - .unwrap_or_else(|| expires_at_from_seconds(3600)); - Some(KiroTokenRecord { - access_token, - refresh_token, - profile_arn: None, - expires_at, - auth_method, - provider, - client_id: credentials.client_id, - client_secret: credentials.client_secret, - email: account.email.filter(|value| !value.trim().is_empty()), - last_refresh: Some(now_rfc3339()), - start_url: credentials.start_url, - region: credentials.region, - }) -} - -fn normalize_auth_method(raw: Option<&str>, provider: Option<&str>) -> String { - let raw_value = raw.unwrap_or("").trim().to_ascii_lowercase(); - if matches!(raw_value.as_str(), "idc") { - return "idc".to_string(); - } - if matches!(raw_value.as_str(), "social") { - return "social".to_string(); - } - if matches!(raw_value.as_str(), "builder-id" | "builder_id") { - return "builder-id".to_string(); - } - let provider_value = provider.unwrap_or("").trim().to_ascii_lowercase(); - if provider_value.contains("google") || provider_value.contains("github") { - return "social".to_string(); - } - if provider_value.contains("idc") - || provider_value.contains("enterprise") - || provider_value.contains("iam") - { - return "idc".to_string(); - } - "builder-id".to_string() -} - -fn format_expires_at(value: i64) -> Option { - let (seconds, nanos) = if value >= 10_000_000_000 { - let secs = value / 1000; - let ms = value % 1000; - (secs, ms * 1_000_000) - } else { - (value, 0) - }; - let nanos_total = i128::from(seconds) - .checked_mul(1_000_000_000)? - .checked_add(i128::from(nanos))?; - OffsetDateTime::from_unix_timestamp_nanos(nanos_total) - .ok()? - .format(&Rfc3339) - .ok() -} - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -struct KiroIdeTokenFile { - access_token: String, - refresh_token: String, - profile_arn: Option, - expires_at: Option, - auth_method: Option, - provider: Option, - client_id: Option, - client_secret: Option, - email: Option, - start_url: Option, - region: Option, - last_refresh: Option, -} - -impl KiroIdeTokenFile { - fn into_record(self) -> Result { - if self.access_token.trim().is_empty() { - return Err("Missing access token.".to_string()); - } - if self.refresh_token.trim().is_empty() { - return Err("Missing refresh token.".to_string()); - } - let provider = self - .provider - .filter(|value| !value.trim().is_empty()) - .unwrap_or_else(|| "AWS".to_string()); - // Default to Builder ID when metadata is missing in IDE token files. - let auth_method = self - .auth_method - .filter(|value| !value.trim().is_empty()) - .unwrap_or_else(|| { - if provider.eq_ignore_ascii_case("google") { - "social".to_string() - } else { - "builder-id".to_string() - } - }); - let expires_at = match self.expires_at.as_deref() { - Some(value) if !value.trim().is_empty() => value.to_string(), - _ => expires_at_from_seconds(3600), - }; - let last_refresh = self - .last_refresh - .filter(|value| !value.trim().is_empty()) - .unwrap_or_else(now_rfc3339); - Ok(KiroTokenRecord { - access_token: self.access_token, - refresh_token: self.refresh_token, - profile_arn: self.profile_arn, - expires_at, - auth_method, - provider, - client_id: self.client_id, - client_secret: self.client_secret, - email: self.email.filter(|value| !value.trim().is_empty()), - last_refresh: Some(last_refresh), - start_url: self.start_url, - region: self.region, - }) - } -} diff --git a/src-tauri/src/kiro/types.rs b/src-tauri/src/kiro/types.rs deleted file mode 100644 index 8dd75fe..0000000 --- a/src-tauri/src/kiro/types.rs +++ /dev/null @@ -1,112 +0,0 @@ -use serde::{Deserialize, Serialize}; -use time::format_description::well_known::Rfc3339; -use time::OffsetDateTime; - -#[derive(Clone, Serialize, Deserialize)] -pub(crate) struct KiroTokenRecord { - pub(crate) access_token: String, - pub(crate) refresh_token: String, - pub(crate) profile_arn: Option, - pub(crate) expires_at: String, - pub(crate) auth_method: String, - pub(crate) provider: String, - pub(crate) client_id: Option, - pub(crate) client_secret: Option, - pub(crate) email: Option, - pub(crate) last_refresh: Option, - pub(crate) start_url: Option, - pub(crate) region: Option, -} - -impl KiroTokenRecord { - pub(crate) fn expires_at(&self) -> Option { - let value = self.expires_at.trim(); - if value.is_empty() { - return None; - } - OffsetDateTime::parse(value, &Rfc3339).ok() - } - - pub(crate) fn is_expired(&self) -> bool { - let Some(expires_at) = self.expires_at() else { - return true; - }; - OffsetDateTime::now_utc() >= expires_at - } - - pub(crate) fn status(&self) -> KiroAccountStatus { - if self.is_expired() { - KiroAccountStatus::Expired - } else { - KiroAccountStatus::Active - } - } -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "snake_case")] -pub(crate) enum KiroAccountStatus { - Active, - Expired, -} - -#[derive(Clone, Serialize)] -pub(crate) struct KiroAccountSummary { - pub(crate) account_id: String, - pub(crate) provider: String, - pub(crate) auth_method: String, - pub(crate) email: Option, - pub(crate) expires_at: Option, - pub(crate) status: KiroAccountStatus, -} - -#[derive(Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub(crate) enum KiroLoginMethod { - Aws, - AwsAuthcode, - Google, -} - -impl std::str::FromStr for KiroLoginMethod { - type Err = String; - - fn from_str(value: &str) -> Result { - match value.trim().to_ascii_lowercase().as_str() { - "" | "aws" | "builder-id" | "builder_id" => Ok(Self::Aws), - "aws_authcode" | "aws-authcode" | "builder-authcode" | "builder_authcode" => { - Ok(Self::AwsAuthcode) - } - "google" => Ok(Self::Google), - other => Err(format!("Unsupported login method: {other}")), - } - } -} - -#[derive(Clone, Serialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub(crate) enum KiroLoginStatus { - Waiting, - Success, - Error, -} - -#[derive(Clone, Serialize)] -pub(crate) struct KiroLoginStartResponse { - pub(crate) state: String, - pub(crate) method: KiroLoginMethod, - pub(crate) login_url: Option, - pub(crate) verification_uri: Option, - pub(crate) verification_uri_complete: Option, - pub(crate) user_code: Option, - pub(crate) interval_seconds: Option, - pub(crate) expires_at: Option, -} - -#[derive(Clone, Serialize)] -pub(crate) struct KiroLoginPollResponse { - pub(crate) state: String, - pub(crate) status: KiroLoginStatus, - pub(crate) error: Option, - pub(crate) account: Option, -} diff --git a/src-tauri/src/kiro/util.rs b/src-tauri/src/kiro/util.rs deleted file mode 100644 index c5bd320..0000000 --- a/src-tauri/src/kiro/util.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub(crate) use crate::oauth_util::{ - expires_at_from_seconds, - extract_email_from_jwt, - generate_pkce, - generate_state, - now_rfc3339, - sanitize_id_part, -};