diff --git a/README.md b/README.md index 184536f..9261cd5 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ Add the plugin to your `opencode.json` or `opencode.jsonc`: "variants": { "low": { "thinkingConfig": { "thinkingBudget": 8192 } }, "medium": { "thinkingConfig": { "thinkingBudget": 16384 } }, - "max": { "thinkingConfig": { "thinkingBudget": 32768 } } + "high": { "thinkingConfig": { "thinkingBudget": 32768 } } } }, "claude-haiku-4-5": { @@ -58,8 +58,18 @@ Add the plugin to your `opencode.json` or `opencode.jsonc`: "variants": { "low": { "thinkingConfig": { "thinkingBudget": 8192 } }, "medium": { "thinkingConfig": { "thinkingBudget": 16384 } }, - "max": { "thinkingConfig": { "thinkingBudget": 32768 } } + "high": { "thinkingConfig": { "thinkingBudget": 32768 } } } + }, + "claude-sonnet-4-5-1m": { + "name": "Claude Sonnet 4.5 1M", + "limit": { "context": 1000000, "output": 64000 }, + "modalities": { "input": ["text", "image", "pdf"], "output": ["text"] } + }, + "qwen3-coder-480b": { + "name": "Qwen3 Coder 480B", + "limit": { "context": 200000, "output": 64000 }, + "modalities": { "input": ["text"], "output": ["text"] } } } } @@ -108,7 +118,9 @@ The plugin supports extensive configuration options. Edit `~/.config/opencode/ki "usage_sync_max_retries": 3, "auth_server_port_start": 19847, "auth_server_port_range": 10, + "builder_id_start_url": "https://view.awsapps.com/start", "usage_tracking_enabled": true, + "usage_toast_enabled": false, "enable_log_api_request": false } ``` @@ -117,7 +129,7 @@ The plugin supports extensive configuration options. Edit `~/.config/opencode/ki - `auto_sync_kiro_cli`: Automatically sync sessions from Kiro CLI (default: `true`). - `account_selection_strategy`: Account rotation strategy (`sticky`, `round-robin`, `lowest-usage`). -- `default_region`: AWS region (`us-east-1`, `us-west-2`). +- `default_region`: AWS region (e.g. `us-east-1`, `eu-west-1`). - `rate_limit_retry_delay_ms`: Delay between rate limit retries (1000-60000ms). - `rate_limit_max_retries`: Maximum retry attempts for rate limits (0-10). - `max_request_iterations`: Maximum loop iterations to prevent hangs (10-1000). @@ -126,7 +138,16 @@ The plugin supports extensive configuration options. Edit `~/.config/opencode/ki - `usage_sync_max_retries`: Retry attempts for usage sync (0-5). - `auth_server_port_start`: Starting port for auth server (1024-65535). - `auth_server_port_range`: Number of ports to try (1-100). +- `builder_id_start_url`: Default AWS start URL shown in the auth window (you can override it in the browser UI). + +## Authentication UI + +The browser page is a single window: + +- Enter `Start URL` and `Region`, then click `Begin`. +- The AWS verification page opens only when you click `Open Browser`. - `usage_tracking_enabled`: Enable usage tracking and toast notifications. +- `usage_toast_enabled`: Show per-account usage toasts (default: `false`). - `enable_log_api_request`: Enable detailed API request logging. ## Storage diff --git a/package.json b/package.json index 5107980..d41d1c6 100644 --- a/package.json +++ b/package.json @@ -1,12 +1,13 @@ { "name": "@zhafron/opencode-kiro-auth", - "version": "1.5.1", + "version": "1.5.2", "description": "OpenCode plugin for AWS Kiro (CodeWhisperer) providing access to Claude models", "type": "module", "main": "dist/index.js", "types": "dist/index.d.ts", "scripts": { "build": "tsc -p tsconfig.build.json", + "test": "npm run build && node --test", "format": "prettier --write 'src/**/*.ts'", "typecheck": "tsc --noEmit", "prepare": "husky" diff --git a/src/constants.ts b/src/constants.ts index 40801fc..405ae35 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -1,9 +1,9 @@ import type { KiroRegion } from './plugin/types' -const VALID_REGIONS: readonly KiroRegion[] = ['us-east-1', 'us-west-2'] +const REGION_REGEX = /^[a-z]{2}-[a-z-]+-\d+$/ export function isValidRegion(region: string): region is KiroRegion { - return VALID_REGIONS.includes(region as KiroRegion) + return REGION_REGEX.test(region) } export function normalizeRegion(region: string | undefined): KiroRegion { @@ -38,14 +38,21 @@ export const KIRO_CONSTANTS = { } export const MODEL_MAPPING: Record = { - 'claude-haiku-4-5': 'claude-haiku-4.5', + 'claude-haiku-4-5': 'CLAUDE_HAIKU_4_5_20251001_V1_0', + 'claude-haiku-4-5-thinking': 'CLAUDE_HAIKU_4_5_20251001_V1_0', 'claude-sonnet-4-5': 'CLAUDE_SONNET_4_5_20250929_V1_0', 'claude-sonnet-4-5-thinking': 'CLAUDE_SONNET_4_5_20250929_V1_0', - 'claude-sonnet-4-5-20250929': 'CLAUDE_SONNET_4_5_20250929_V1_0', + 'claude-sonnet-4-5-1m': 'CLAUDE_SONNET_4_5_20250929_LONG_V1_0', + 'claude-sonnet-4-5-1m-thinking': 'CLAUDE_SONNET_4_5_20250929_LONG_V1_0', 'claude-opus-4-5': 'CLAUDE_OPUS_4_5_20251101_V1_0', 'claude-opus-4-5-thinking': 'CLAUDE_OPUS_4_5_20251101_V1_0', - 'claude-sonnet-4-20250514': 'CLAUDE_SONNET_4_20250514_V1_0', - 'claude-3-7-sonnet-20250219': 'CLAUDE_3_7_SONNET_20250219_V1_0' + 'claude-sonnet-4': 'CLAUDE_SONNET_4_20250514_V1_0', + 'claude-3-7-sonnet': 'CLAUDE_3_7_SONNET_20250219_V1_0', + 'nova-swe': 'AGI_NOVA_SWE_V1_5', + 'gpt-oss-120b': 'OPENAI_GPT_OSS_120B_1_0', + 'qwen3-coder-480b': 'QWEN3_CODER_480B_A35B_1_0', + 'minimax-m2': 'MINIMAX_MINIMAX_M2', + 'kimi-k2-thinking': 'MOONSHOT_KIMI_K2_THINKING' } export const SUPPORTED_MODELS = Object.keys(MODEL_MAPPING) diff --git a/src/core/account/account-selector.ts b/src/core/account/account-selector.ts index 1bf26ba..43dc7f8 100644 --- a/src/core/account/account-selector.ts +++ b/src/core/account/account-selector.ts @@ -7,6 +7,8 @@ type ToastFunction = (message: string, variant: 'info' | 'warning' | 'success' | interface AccountSelectorConfig { auto_sync_kiro_cli: boolean account_selection_strategy: 'sticky' | 'round-robin' | 'lowest-usage' + usage_tracking_enabled: boolean + usage_toast_enabled: boolean } export class AccountSelector { @@ -61,6 +63,8 @@ export class AccountSelector { } if ( + this.config.usage_tracking_enabled && + this.config.usage_toast_enabled && this.accountManager.shouldShowUsageToast() && acc.usedCount !== undefined && acc.limitCount !== undefined diff --git a/src/core/auth/auth-handler.ts b/src/core/auth/auth-handler.ts index 880fe7b..8e148aa 100644 --- a/src/core/auth/auth-handler.ts +++ b/src/core/auth/auth-handler.ts @@ -31,7 +31,7 @@ export class AuthHandler { return [] } - const idcMethod = new IdcAuthMethod(this.config, this.repository) + const idcMethod = new IdcAuthMethod(this.config, this.repository, this.accountManager) return [ { diff --git a/src/core/auth/idc-auth-method.ts b/src/core/auth/idc-auth-method.ts index 0c32a58..68927f2 100644 --- a/src/core/auth/idc-auth-method.ts +++ b/src/core/auth/idc-auth-method.ts @@ -1,7 +1,7 @@ import { exec } from 'node:child_process' import type { AccountRepository } from '../../infrastructure/database/account-repository.js' import type { KiroIDCTokenResult } from '../../kiro/oauth-idc.js' -import { authorizeKiroIDC } from '../../kiro/oauth-idc.js' +import type { AccountManager } from '../../plugin/accounts.js' import { createDeterministicAccountId } from '../../plugin/accounts.js' import { promptAddAnotherAccount, promptDeleteAccount, promptLoginMode } from '../../plugin/cli.js' import * as logger from '../../plugin/logger.js' @@ -26,7 +26,8 @@ const openBrowser = (url: string) => { export class IdcAuthMethod { constructor( private config: any, - private repository: AccountRepository + private repository: AccountRepository, + private accountManager: AccountManager ) {} async authorize(inputs?: any): Promise<{ @@ -90,9 +91,8 @@ export class IdcAuthMethod { } while (true) { try { - const authData = await authorizeKiroIDC(region) const { url, waitForAuth } = await startIDCAuthServer( - authData, + { defaultRegion: region, defaultStartUrl: this.config.builder_id_start_url }, this.config.auth_server_port_start, this.config.auth_server_port_range ) @@ -108,9 +108,9 @@ export class IdcAuthMethod { clientSecret: res.clientSecret }) if (!u.email) { - console.log('\n[Error] Failed to fetch account email. Skipping...\n') - continue + console.log('\n[Warn] Failed to fetch account email; saving with fallback email.\n') } + const email = u.email || res.email || 'builder-id@aws.amazon.com' accounts.push(res as KiroIDCTokenResult) if (accounts.length === 1 && startFresh) { const allAccounts = await this.repository.findAll() @@ -119,10 +119,10 @@ export class IdcAuthMethod { await this.repository.delete(acc.id) } } - const id = createDeterministicAccountId(u.email, 'idc', res.clientId) + const id = createDeterministicAccountId(email, 'idc', res.clientId) const acc: ManagedAccount = { id, - email: u.email, + email, authMethod: 'idc', region, clientId: res.clientId, @@ -137,6 +137,7 @@ export class IdcAuthMethod { limitCount: u.limitCount } await this.repository.save(acc) + this.accountManager.addAccount(acc) const currentCount = (await this.repository.findAll()).length console.log(`\n[Success] Added: ${u.email} (Quota: ${u.usedCount}/${u.limitCount})\n`) if (!(await promptAddAnotherAccount(currentCount))) break @@ -159,9 +160,8 @@ export class IdcAuthMethod { private async handleSingleLogin(region: KiroRegion, resolve: any): Promise { try { - const authData = await authorizeKiroIDC(region) const { url, waitForAuth } = await startIDCAuthServer( - authData, + { defaultRegion: region, defaultStartUrl: this.config.builder_id_start_url }, this.config.auth_server_port_start, this.config.auth_server_port_range ) @@ -173,20 +173,30 @@ export class IdcAuthMethod { callback: async () => { try { const res = await waitForAuth() - const u = await fetchUsageLimits({ - refresh: '', - access: res.accessToken, - expires: res.expiresAt, - authMethod: 'idc', - region, - clientId: res.clientId, - clientSecret: res.clientSecret - }) - if (!u.email) throw new Error('No email') - const id = createDeterministicAccountId(u.email, 'idc', res.clientId) + + let u: any = {} + try { + u = await fetchUsageLimits({ + refresh: '', + access: res.accessToken, + expires: res.expiresAt, + authMethod: 'idc', + region, + clientId: res.clientId, + clientSecret: res.clientSecret + }) + } catch (e) { + logger.warn( + 'Failed to fetch usage/email after auth; saving account with fallback email', + e + ) + } + + const email = u.email || res.email || 'builder-id@aws.amazon.com' + const id = createDeterministicAccountId(email, 'idc', res.clientId) const acc: ManagedAccount = { id, - email: u.email, + email, authMethod: 'idc', region, clientId: res.clientId, @@ -201,6 +211,7 @@ export class IdcAuthMethod { limitCount: u.limitCount } await this.repository.save(acc) + this.accountManager.addAccount(acc) return { type: 'success', key: res.accessToken } } catch (e: any) { return { type: 'failed' } diff --git a/src/core/request/error-handler.ts b/src/core/request/error-handler.ts index adece42..91c19b9 100644 --- a/src/core/request/error-handler.ts +++ b/src/core/request/error-handler.ts @@ -101,13 +101,25 @@ export class ErrorHandler { let isPermanent = false try { const errorBody = await response.text() - const errorData = JSON.parse(errorBody) - if (errorData.reason === 'INVALID_MODEL_ID') { - throw new Error(`Invalid model: ${errorData.message}`) - } - if (errorData.reason === 'TEMPORARILY_SUSPENDED') { - errorReason = 'Account Suspended' - isPermanent = true + try { + const errorData = JSON.parse(errorBody) + if (errorData.reason === 'INVALID_MODEL_ID') { + throw new Error(`Invalid model: ${errorData.message}`) + } + if (errorData.reason === 'TEMPORARILY_SUSPENDED') { + errorReason = 'Account Suspended' + isPermanent = true + } else if (errorData.reason || errorData.message) { + const detail = errorData.reason + ? `${errorData.reason}${errorData.message ? `: ${errorData.message}` : ''}` + : errorData.message + errorReason = `${errorReason} (${detail})` + } + } catch (parseError) { + if (errorBody) { + const trimmed = errorBody.replace(/\s+/g, ' ').trim().slice(0, 160) + if (trimmed) errorReason = `${errorReason} (${trimmed})` + } } } catch (e) { if (e instanceof Error && e.message.includes('Invalid model')) { diff --git a/src/core/request/request-handler.ts b/src/core/request/request-handler.ts index 9f8b8e3..304daf1 100644 --- a/src/core/request/request-handler.ts +++ b/src/core/request/request-handler.ts @@ -12,6 +12,7 @@ import { TokenRefresher } from '../auth/token-refresher' import { ErrorHandler } from './error-handler' import { ResponseHandler } from './response-handler' import { RetryStrategy } from './retry-strategy' +import { resolveThinkingConfig } from './thinking' type ToastFunction = (message: string, variant: 'info' | 'warning' | 'success' | 'error') => void @@ -55,8 +56,9 @@ export class RequestHandler { ): Promise { const body = init?.body ? JSON.parse(init.body) : {} const model = this.extractModel(url) || body.model || 'claude-sonnet-4-5' - const think = model.endsWith('-thinking') || !!body.providerOptions?.thinkingConfig - const budget = body.providerOptions?.thinkingConfig?.thinkingBudget || 20000 + const thinking = resolveThinkingConfig(model, body) + const think = thinking.enabled + const budget = thinking.budget let reductionFactor = 1.0 let retry = 0 @@ -100,7 +102,6 @@ export class RequestHandler { try { const res = await fetch(prep.url, prep.init) - if (apiTimestamp) { this.logResponse(res, prep, apiTimestamp) } @@ -135,7 +136,7 @@ export class RequestHandler { continue } - this.logError(prep, res, acc, apiTimestamp) + await this.logError(prep, res, acc, apiTimestamp) throw new Error(`Kiro Error: ${res.status}`) } catch (e) { const networkResult = await this.errorHandler.handleNetworkError( @@ -189,11 +190,12 @@ export class RequestHandler { try { b = prep.init.body ? JSON.parse(prep.init.body as string) : null } catch {} + const headers = this.redactHeaders(prep.init.headers) logger.logApiRequest( { url: prep.url, method: prep.init.method, - headers: prep.init.headers, + headers, body: b, conversationId: prep.conversationId, model: prep.effectiveModel, @@ -220,20 +222,28 @@ export class RequestHandler { ) } - private logError( + private async logError( prep: PreparedRequest, res: Response, acc: ManagedAccount, apiTimestamp: string | null - ): void { + ): Promise { const h: any = {} res.headers.forEach((v, k) => { h[k] = v }) + let errorBody: string | undefined + try { + errorBody = await res.text() + if (errorBody) { + errorBody = errorBody.replace(/\s+/g, ' ').trim().slice(0, 1000) + } + } catch {} const rData = { status: res.status, statusText: res.statusText, headers: h, + body: errorBody, error: `Kiro Error: ${res.status}`, conversationId: prep.conversationId, model: prep.effectiveModel @@ -242,12 +252,13 @@ export class RequestHandler { try { lastB = prep.init.body ? JSON.parse(prep.init.body as string) : null } catch {} + const headers = this.redactHeaders(prep.init.headers) if (!this.config.enable_log_api_request) { logger.logApiError( { url: prep.url, method: prep.init.method, - headers: prep.init.headers, + headers, body: lastB, conversationId: prep.conversationId, model: prep.effectiveModel, @@ -259,6 +270,14 @@ export class RequestHandler { } } + private redactHeaders(headers: any): any { + if (!headers || typeof headers !== 'object') return headers + const clone = { ...headers } + if ('Authorization' in clone) clone.Authorization = 'REDACTED' + if ('authorization' in clone) clone.authorization = 'REDACTED' + return clone + } + private allAccountsPermanentlyUnhealthy(): boolean { const accounts = this.accountManager.getAccounts() if (accounts.length === 0) { diff --git a/src/core/request/thinking.ts b/src/core/request/thinking.ts new file mode 100644 index 0000000..c9dbe1e --- /dev/null +++ b/src/core/request/thinking.ts @@ -0,0 +1,56 @@ +type RawBody = { + variant?: unknown + providerOptions?: { + variant?: unknown + modelVariant?: unknown + thinkingConfig?: { + thinkingBudget?: unknown + } + } +} + +export type ThinkingVariant = 'low' | 'medium' | 'high' | 'max' + +export function resolveThinkingConfig( + model: string, + body: unknown, + defaults: { budget: number } = { budget: 20000 } +): { enabled: boolean; budget: number; variant?: string } { + const b = (body || {}) as RawBody + + const rawVariant = + (typeof b.variant === 'string' && b.variant) || + (typeof b.providerOptions?.variant === 'string' && b.providerOptions.variant) || + (typeof b.providerOptions?.modelVariant === 'string' && b.providerOptions.modelVariant) || + undefined + + const explicitBudget = b.providerOptions?.thinkingConfig?.thinkingBudget + const budgetFromBody = + typeof explicitBudget === 'number' && Number.isFinite(explicitBudget) && explicitBudget > 0 + ? explicitBudget + : undefined + + const budgetFromVariant = variantToBudget(rawVariant) + const enabled = + model.endsWith('-thinking') || + budgetFromBody !== undefined || + b.providerOptions?.thinkingConfig !== undefined || + budgetFromVariant !== undefined + + return { + enabled, + budget: budgetFromBody ?? budgetFromVariant ?? defaults.budget, + variant: rawVariant + } +} + +function variantToBudget(rawVariant: string | undefined): number | undefined { + if (!rawVariant) return undefined + + const v = rawVariant.toLowerCase() + if (v === 'low') return 8192 + if (v === 'medium') return 16384 + // "high" is the documented name, but "max" is kept as backward compat + if (v === 'high' || v === 'max') return 32768 + return undefined +} diff --git a/src/infrastructure/transformers/history-builder.ts b/src/infrastructure/transformers/history-builder.ts index 43f2342..43adc10 100644 --- a/src/infrastructure/transformers/history-builder.ts +++ b/src/infrastructure/transformers/history-builder.ts @@ -36,7 +36,6 @@ export function buildHistory( history.push({ userInputMessage: { content: system, - modelId: resolved, origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR } }) @@ -46,7 +45,7 @@ export function buildHistory( const m = msgs[i] if (!m) continue if (m.role === 'user') { - const uim: any = { content: '', modelId: resolved, origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR } + const uim: any = { content: '', origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR } const trs: any[] = [] if (Array.isArray(m.content)) { @@ -97,7 +96,6 @@ export function buildHistory( history.push({ userInputMessage: { content: 'Tool results provided.', - modelId: resolved, origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR, userInputMessageContext: { toolResults: deduplicateToolResults(trs) } } diff --git a/src/kiro/oauth-idc.ts b/src/kiro/oauth-idc.ts index 2e5dd24..6d069b4 100644 --- a/src/kiro/oauth-idc.ts +++ b/src/kiro/oauth-idc.ts @@ -1,5 +1,5 @@ -import { KIRO_AUTH_SERVICE, KIRO_CONSTANTS, buildUrl, normalizeRegion } from '../constants' -import type { KiroRegion } from '../plugin/types' +import { KIRO_AUTH_SERVICE, KIRO_CONSTANTS, buildUrl, normalizeRegion } from '../constants.js' +import type { KiroRegion } from '../plugin/types.js' export interface KiroIDCAuthorization { verificationUrl: string @@ -24,10 +24,18 @@ export interface KiroIDCTokenResult { authMethod: 'idc' } -export async function authorizeKiroIDC(region?: KiroRegion): Promise { +export async function authorizeKiroIDC( + region?: KiroRegion, + builderIdStartUrl?: string +): Promise { const effectiveRegion = normalizeRegion(region) const ssoOIDCEndpoint = buildUrl(KIRO_AUTH_SERVICE.SSO_OIDC_ENDPOINT, effectiveRegion) + const startUrl = await resolveStartUrl( + builderIdStartUrl || KIRO_AUTH_SERVICE.BUILDER_ID_START_URL, + KIRO_CONSTANTS.USER_AGENT + ) + try { const registerResponse = await fetch(`${ssoOIDCEndpoint}/client/register`, { method: 'POST', @@ -66,7 +74,7 @@ export async function authorizeKiroIDC(region?: KiroRegion): Promise { + const candidate = normalizeStartUrl(input) + const parsed = new URL(candidate) + + // Best effort: follow redirects to discover the canonical access portal hostname. + // Some orgs use vanity domains that redirect to the underlying d-*.awsapps.com portal. + try { + const controller = new AbortController() + const timeout = setTimeout(() => controller.abort(), 4000) + const res = await fetch(candidate, { + method: 'GET', + redirect: 'follow', + headers: { 'User-Agent': userAgent }, + signal: controller.signal + }) + clearTimeout(timeout) + + const finalUrl = res.url ? new URL(res.url) : null + if (finalUrl) { + const host = finalUrl.hostname.toLowerCase() + if (host.endsWith('.awsapps.com') || host === 'view.awsapps.com') { + return `${finalUrl.origin}/start` + } + } + } catch { + // Ignore and fall back to candidate. + } + + // If there was no useful redirect, return normalized origin/start. + return `${parsed.origin}/start` +} + export async function pollKiroIDCToken( clientId: string, clientSecret: string, diff --git a/src/opencode-config.ts b/src/opencode-config.ts new file mode 100644 index 0000000..cc8bb30 --- /dev/null +++ b/src/opencode-config.ts @@ -0,0 +1,44 @@ +import type { Config } from '@opencode-ai/sdk' + +import { KIRO_CONSTANTS, buildUrl, normalizeRegion } from './constants.js' + +function isNonEmptyString(value: unknown): value is string { + return typeof value === 'string' && value.trim().length > 0 +} + +export function getKiroOpenAICompatibleBaseURL(region: string | undefined): string { + const normalizedRegion = normalizeRegion(region) + const template = KIRO_CONSTANTS.BASE_URL.replace('/generateAssistantResponse', '') + return buildUrl(template, normalizedRegion) +} + +/** + * Ensure OpenCode's provider options include a baseURL. + * + * OpenCode wires `provider..options.baseURL` into the bundled + * `@ai-sdk/openai-compatible` provider. If missing, it can attempt to call + * `undefined/chat/completions`. + */ +export function ensureProviderBaseURL( + config: Config, + providerId: string, + baseURL: string +): boolean { + if (!config.provider) { + config.provider = {} + } + + const provider = config.provider[providerId] ?? {} + config.provider[providerId] = provider + + if (!provider.options) { + provider.options = {} + } + + if (isNonEmptyString(provider.options.baseURL)) { + return false + } + + provider.options.baseURL = baseURL + return true +} diff --git a/src/plugin.ts b/src/plugin.ts index aa987e2..9a6cbde 100644 --- a/src/plugin.ts +++ b/src/plugin.ts @@ -1,3 +1,4 @@ +import type { Config as OpencodeConfig } from '@opencode-ai/sdk' import { KIRO_CONSTANTS } from './constants.js' import { AuthHandler } from './core/auth/auth-handler.js' import { RequestHandler } from './core/request/request-handler.js' @@ -6,6 +7,8 @@ import { AccountRepository } from './infrastructure/database/account-repository. import { AccountManager } from './plugin/accounts.js' import { loadConfig } from './plugin/config/index.js' +import { ensureProviderBaseURL, getKiroOpenAICompatibleBaseURL } from './opencode-config.js' + type ToastFunction = (message: string, variant: string) => void const KIRO_PROVIDER_ID = 'kiro' @@ -29,14 +32,33 @@ export const createKiroPlugin = const requestHandler = new RequestHandler(accountManager, config, repository) return { + config: async (opencodeConfig: OpencodeConfig) => { + const baseURL = getKiroOpenAICompatibleBaseURL(config.default_region) + ensureProviderBaseURL(opencodeConfig, id, baseURL) + + // OpenCode wires `provider..options` into the bundled @ai-sdk/openai-compatible. + // We need to ensure requests to /chat/completions are intercepted and translated to + // CodeWhisperer/Kiro APIs. + const provider = opencodeConfig.provider?.[id] + if (provider) { + provider.options = provider.options || {} + if (!provider.options.fetch) { + provider.options.fetch = (input: any, init?: any) => + requestHandler.handle(input, init, showToast) + } + } + }, auth: { provider: id, loader: async (getAuth: any) => { - await getAuth() + const stored = (await getAuth()) || {} await authHandler.initialize() return { - apiKey: '', + // OpenCode uses apiKey presence to determine "connected". + // We don't require an OpenAI-style key for requests, but returning the stored + // key (set by the auth method callback) makes the UI reflect connection state. + apiKey: stored.apiKey || stored.key || '', baseURL: KIRO_CONSTANTS.BASE_URL.replace('/generateAssistantResponse', '').replace( '{{region}}', config.default_region || 'us-east-1' diff --git a/src/plugin/accounts.ts b/src/plugin/accounts.ts index f5e9004..1f4d17e 100644 --- a/src/plugin/accounts.ts +++ b/src/plugin/accounts.ts @@ -95,25 +95,36 @@ export class AccountManager { } return !(a.rateLimitResetTime && now < a.rateLimitResetTime) }) + + const hasRealEmail = available.some((a) => !a.email.endsWith('@awsapps.local')) + const selectable = hasRealEmail + ? available.filter((a) => !a.email.endsWith('@awsapps.local')) + : available + let selected: ManagedAccount | undefined - if (available.length > 0) { + if (selectable.length > 0) { if (this.strategy === 'sticky') { - selected = available.find((_, i) => i === this.cursor) || available[0] + selected = selectable.find((_, i) => i === this.cursor) || selectable[0] } else if (this.strategy === 'round-robin') { - selected = available[this.cursor % available.length] - this.cursor = (this.cursor + 1) % available.length + selected = selectable[this.cursor % selectable.length] + this.cursor = (this.cursor + 1) % selectable.length } else if (this.strategy === 'lowest-usage') { - selected = [...available].sort( + selected = [...selectable].sort( (a, b) => (a.usedCount || 0) - (b.usedCount || 0) || (a.lastUsed || 0) - (b.lastUsed || 0) )[0] } } if (!selected) { - const fallback = this.accounts - .filter((a) => !a.isHealthy && a.failCount < 10 && !isPermanentError(a.unhealthyReason)) - .sort( - (a, b) => (a.usedCount || 0) - (b.usedCount || 0) || (a.lastUsed || 0) - (b.lastUsed || 0) - )[0] + const fallbackPool = this.accounts.filter( + (a) => !a.isHealthy && a.failCount < 10 && !isPermanentError(a.unhealthyReason) + ) + const preferFallback = fallbackPool.some((a) => !a.email.endsWith('@awsapps.local')) + ? fallbackPool.filter((a) => !a.email.endsWith('@awsapps.local')) + : fallbackPool + + const fallback = preferFallback.sort( + (a, b) => (a.usedCount || 0) - (b.usedCount || 0) || (a.lastUsed || 0) - (b.lastUsed || 0) + )[0] if (fallback) { fallback.isHealthy = true delete fallback.unhealthyReason diff --git a/src/plugin/auth-page.ts b/src/plugin/auth-page.ts index 055ef7f..c9aa8dd 100644 --- a/src/plugin/auth-page.ts +++ b/src/plugin/auth-page.ts @@ -231,12 +231,6 @@ export function getIDCAuthHtml( }).catch(() => {}); } - window.addEventListener('load', () => { - setTimeout(() => { - window.open(verificationUrl, '_blank'); - }, 500); - }); - async function checkStatus() { try { const response = await fetch(statusUrl); @@ -245,7 +239,9 @@ export function getIDCAuthHtml( if (data.status === 'success') { window.location.href = '/success'; } else if (data.status === 'failed' || data.status === 'timeout') { - window.location.href = '/error?message=' + encodeURIComponent(data.message || 'Authentication failed'); + window.location.href = + '/error?message=' + + encodeURIComponent(data.error || data.message || 'Authentication failed'); } } catch (error) { console.error('Status check failed:', error); @@ -259,6 +255,492 @@ export function getIDCAuthHtml( ` } +export function getIDCCombinedHtml( + defaultStartUrl: string, + defaultRegion: string, + beginUrl: string, + statusUrl: string +): string { + const escapedStartUrl = escapeHtml(defaultStartUrl) + const escapedRegion = escapeHtml(defaultRegion) + const escapedBeginUrl = escapeHtml(beginUrl) + const escapedStatusUrl = escapeHtml(statusUrl) + + return ` + + + + + AWS Builder ID Authentication + + + +
+

AWS Builder ID Authentication

+

Using the defaults below. If needed, edit them and click Begin. The AWS verification page opens only when you click "Open Browser".

+ +
+
+ + +
We normalize to https://<host>/start and may follow redirects to the canonical AWS access portal.
+
+
+ + +
Region for your IAM Identity Center instance (e.g. eu-west-1).
+
+
+ +
+
+
+ +
+ +
+
+
User Code
+
+
Click to copy
+
+ +
+
Verification URL
+ Open Browser +
+ +
+
+ Preparing authorization... +
+
+
+ + + +` +} + +export function getIDCStartHtml( + defaultStartUrl: string, + defaultRegion: string, + beginUrl: string +): string { + const escapedStartUrl = escapeHtml(defaultStartUrl) + const escapedRegion = escapeHtml(defaultRegion) + const escapedBeginUrl = escapeHtml(beginUrl) + + return ` + + + + + AWS Builder ID Authentication + + + +
+

AWS Builder ID Authentication

+

Choose the AWS start URL and begin the device authorization flow.

+
+
+ + +
Paste your org portal URL. We'll normalize it and may follow redirects to the underlying AWS access portal hostname.
+
+
+ + +
Region for your IAM Identity Center instance (e.g. us-east-1, eu-west-1).
+
+
+ +
+
+
+ +` +} + export function getSuccessHtml(): string { return ` diff --git a/src/plugin/config/loader.ts b/src/plugin/config/loader.ts index 5e20595..fb14665 100644 --- a/src/plugin/config/loader.ts +++ b/src/plugin/config/loader.ts @@ -152,11 +152,15 @@ function applyEnvOverrides(config: KiroConfig): KiroConfig { config.auth_server_port_range ), + builder_id_start_url: config.builder_id_start_url, + usage_tracking_enabled: parseBooleanEnv( env.KIRO_USAGE_TRACKING_ENABLED, config.usage_tracking_enabled ), + usage_toast_enabled: parseBooleanEnv(env.KIRO_USAGE_TOAST_ENABLED, config.usage_toast_enabled), + enable_log_api_request: parseBooleanEnv( env.KIRO_ENABLE_LOG_API_REQUEST, config.enable_log_api_request diff --git a/src/plugin/config/schema.ts b/src/plugin/config/schema.ts index ba781eb..bcbfb65 100644 --- a/src/plugin/config/schema.ts +++ b/src/plugin/config/schema.ts @@ -3,7 +3,8 @@ import { z } from 'zod' export const AccountSelectionStrategySchema = z.enum(['sticky', 'round-robin', 'lowest-usage']) export type AccountSelectionStrategy = z.infer -export const RegionSchema = z.enum(['us-east-1', 'us-west-2']) +const REGION_REGEX = /^[a-z]{2}-[a-z-]+-\d+$/ +export const RegionSchema = z.string().regex(REGION_REGEX) export type Region = z.infer export const KiroConfigSchema = z.object({ @@ -29,7 +30,10 @@ export const KiroConfigSchema = z.object({ auth_server_port_range: z.number().min(1).max(100).default(10), + builder_id_start_url: z.string().url().default('https://view.awsapps.com/start'), + usage_tracking_enabled: z.boolean().default(true), + usage_toast_enabled: z.boolean().default(false), auto_sync_kiro_cli: z.boolean().default(true), enable_log_api_request: z.boolean().default(false) }) @@ -47,7 +51,9 @@ export const DEFAULT_CONFIG: KiroConfig = { usage_sync_max_retries: 3, auth_server_port_start: 19847, auth_server_port_range: 10, + builder_id_start_url: 'https://view.awsapps.com/start', usage_tracking_enabled: true, + usage_toast_enabled: false, auto_sync_kiro_cli: true, enable_log_api_request: false } diff --git a/src/plugin/request.ts b/src/plugin/request.ts index 669463b..37861d3 100644 --- a/src/plugin/request.ts +++ b/src/plugin/request.ts @@ -13,10 +13,7 @@ import { mergeAdjacentMessages, truncate } from '../infrastructure/transformers/message-transformer.js' -import { - convertToolsToCodeWhisperer, - deduplicateToolResults -} from '../infrastructure/transformers/tool-transformer.js' +import { deduplicateToolResults } from '../infrastructure/transformers/tool-transformer.js' import { convertImagesToKiroFormat, extractAllImages, @@ -47,7 +44,7 @@ export function transformToCodeWhisperer( const msgs = mergeAdjacentMessages([...messages]) const lastMsg = msgs[msgs.length - 1] if (lastMsg && lastMsg.role === 'assistant' && getContentText(lastMsg) === '{') msgs.pop() - const cwTools = tools ? convertToolsToCodeWhisperer(tools) : [] + const cwTools: any[] = [] const toolResultLimit = Math.floor(250000 * reductionFactor) let history = buildHistory(msgs, resolved, sys, toolResultLimit) const historyLimit = Math.floor(850000 * reductionFactor) @@ -139,12 +136,14 @@ export function transformToCodeWhisperer( currentMessage: { userInputMessage: { content: curContent, - modelId: resolved, origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR } } } } + if (auth.profileArn) { + request.profileArn = auth.profileArn + } const toolUsesInHistory = history.flatMap((h) => h.assistantResponseMessage?.toolUses || []) const allToolUseIdsInHistory = new Set(toolUsesInHistory.map((tu) => tu.toolUseId)) const finalCurTrs: any[] = [] @@ -175,7 +174,6 @@ export function transformToCodeWhisperer( history.push({ userInputMessage: { content: 'Running tools...', - modelId: resolved, origin: KIRO_CONSTANTS.ORIGIN_AI_EDITOR } }) @@ -195,7 +193,9 @@ export function transformToCodeWhisperer( if (curImgs.length) uim.images = curImgs const ctx: any = {} if (finalCurTrs.length) ctx.toolResults = deduplicateToolResults(finalCurTrs) - if (cwTools.length) ctx.tools = cwTools + if (cwTools.length) { + ctx.tools = cwTools + } if (Object.keys(ctx).length) uim.userInputMessageContext = ctx const hasToolsInHistory = historyHasToolCalling(history) if (hasToolsInHistory) { diff --git a/src/plugin/server.ts b/src/plugin/server.ts index 42193f4..0dbc954 100644 --- a/src/plugin/server.ts +++ b/src/plugin/server.ts @@ -1,7 +1,9 @@ import { createServer, type Server, type ServerResponse } from 'node:http' -import { getErrorHtml, getIDCAuthHtml, getSuccessHtml } from './auth-page' -import * as logger from './logger' -import type { KiroRegion } from './types' +import { KIRO_CONSTANTS } from '../constants.js' +import { authorizeKiroIDC } from '../kiro/oauth-idc.js' +import { getErrorHtml, getIDCCombinedHtml, getSuccessHtml } from './auth-page.js' +import * as logger from './logger.js' +import type { KiroRegion } from './types.js' export interface KiroIDCTokenResult { email: string @@ -23,6 +25,11 @@ export interface IDCAuthData { region: KiroRegion } +interface IDCAuthServerOptions { + defaultRegion: KiroRegion + defaultStartUrl: string +} + async function tryPort(port: number): Promise { return new Promise((resolve) => { const testServer = createServer() @@ -47,7 +54,7 @@ async function findAvailablePort(startPort: number, range: number): Promise Promise }> { @@ -66,7 +73,10 @@ export async function startIDCAuthServer( let timeoutId: any = null let resolver: any = null let rejector: any = null - const status: any = { status: 'pending' } + const status: any = { status: 'idle' } + + let authData: IDCAuthData | null = null + let pollGeneration = 0 const cleanup = () => { if (timeoutId) clearTimeout(timeoutId) @@ -77,18 +87,31 @@ export async function startIDCAuthServer( res.end(html) } - const poll = async () => { + const poll = async (generation: number) => { try { - const body = { - grantType: 'urn:ietf:params:oauth:grant-type:device_code', - deviceCode: authData.deviceCode, - clientId: authData.clientId, - clientSecret: authData.clientSecret + if (!authData) { + return + } + if (generation !== pollGeneration) { + return } + // AWS SSO OIDC CreateToken expects JSON keys (clientId/clientSecret/deviceCode/grantType) + // matching the StartDeviceAuthorization flow. + const body = JSON.stringify({ + clientId: authData.clientId, + clientSecret: authData.clientSecret, + deviceCode: authData.deviceCode, + grantType: 'urn:ietf:params:oauth:grant-type:device_code' + }) + const res = await fetch(`https://oidc.${authData.region}.amazonaws.com/token`, { method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(body) + headers: { + 'Content-Type': 'application/json', + 'User-Agent': KIRO_CONSTANTS.USER_AGENT, + Accept: 'application/json' + }, + body }) const responseText = await res.text() @@ -138,7 +161,7 @@ export async function startIDCAuthServer( }) setTimeout(cleanup, 2000) } else if (d.error === 'authorization_pending') { - setTimeout(poll, authData.interval * 1000) + setTimeout(() => poll(generation), authData.interval * 1000) } else { status.status = 'failed' status.error = d.error_description || d.error @@ -156,22 +179,77 @@ export async function startIDCAuthServer( } server = createServer((req, res) => { - const u = req.url || '' - if (u === '/' || u.startsWith('/?')) + const parsed = new URL(req.url || '/', `http://127.0.0.1:${port}`) + const pathname = parsed.pathname + if (pathname === '/') { sendHtml( res, - getIDCAuthHtml( - authData.verificationUriComplete, - authData.userCode, + getIDCCombinedHtml( + options.defaultStartUrl, + options.defaultRegion, + `http://127.0.0.1:${port}/begin`, `http://127.0.0.1:${port}/status` ) ) - else if (u === '/status') { + } else if (pathname === '/begin') { + ;(async () => { + try { + const startUrl = parsed.searchParams.get('startUrl') || options.defaultStartUrl + const region = parsed.searchParams.get('region') || options.defaultRegion + + // Validate region format early to avoid confusing OIDC errors. + if (!/^[a-z]{2}-[a-z-]+-\d+$/.test(region)) { + throw new Error(`Invalid region: ${region}`) + } + + status.status = 'pending' + delete status.error + + pollGeneration++ + const generation = pollGeneration + + if (timeoutId) clearTimeout(timeoutId) + timeoutId = setTimeout(() => { + status.status = 'timeout' + logger.warn('Auth timeout waiting for authorization') + if (rejector) rejector(new Error('Timeout')) + cleanup() + }, 900000) + + const d = await authorizeKiroIDC(region, startUrl) + authData = d as unknown as IDCAuthData + + res.writeHead(200, { 'Content-Type': 'application/json' }) + res.end( + JSON.stringify({ + verificationUrl: authData.verificationUrl, + verificationUriComplete: authData.verificationUriComplete, + userCode: authData.userCode, + region: authData.region + }) + ) + + poll(generation) + } catch (e: any) { + const msg = e?.message || 'Failed to begin authentication' + status.status = 'failed' + status.error = msg + res.writeHead(400, { 'Content-Type': 'application/json' }) + res.end(JSON.stringify({ message: msg })) + } + })().catch(() => {}) + } else if (pathname === '/status') { res.writeHead(200, { 'Content-Type': 'application/json' }) - res.end(JSON.stringify(status)) - } else if (u === '/success') sendHtml(res, getSuccessHtml()) - else if (u === '/error') sendHtml(res, getErrorHtml(status.error || 'Failed')) - else { + const payload = { + ...status, + message: status.error + } + res.end(JSON.stringify(payload)) + } else if (pathname === '/success') sendHtml(res, getSuccessHtml()) + else if (pathname === '/error') { + const msg = parsed.searchParams.get('message') || status.error || 'Failed' + sendHtml(res, getErrorHtml(msg)) + } else { res.writeHead(404) res.end() } @@ -183,13 +261,6 @@ export async function startIDCAuthServer( reject(e) }) server.listen(port, '127.0.0.1', () => { - timeoutId = setTimeout(() => { - status.status = 'timeout' - logger.warn('Auth timeout waiting for authorization') - if (rejector) rejector(new Error('Timeout')) - cleanup() - }, 900000) - poll() resolve({ url: `http://127.0.0.1:${port}`, waitForAuth: () => diff --git a/src/plugin/storage/locked-operations.ts b/src/plugin/storage/locked-operations.ts index 7825bc2..9d4f291 100644 --- a/src/plugin/storage/locked-operations.ts +++ b/src/plugin/storage/locked-operations.ts @@ -21,6 +21,13 @@ export async function withDatabaseLock(dbPath: string, fn: () => Promise): if (!existsSync(dbPath)) { const dir = dbPath.substring(0, dbPath.lastIndexOf('/')) await fs.mkdir(dir, { recursive: true }) + + // If the DB was removed but WAL/SHM sidecars remain, SQLite can fail with + // "disk I/O error" when opening the newly created DB. + await fs.rm(`${dbPath}-wal`, { force: true }).catch(() => {}) + await fs.rm(`${dbPath}-shm`, { force: true }).catch(() => {}) + await fs.rm(lockPath, { force: true }).catch(() => {}) + await fs.writeFile(dbPath, '') } @@ -63,8 +70,25 @@ export function mergeAccounts( const existingAcc = accountMap.get(acc.id) if (existingAcc) { + const refreshChanged = + typeof acc.refreshToken === 'string' && acc.refreshToken !== existingAcc.refreshToken + const accessChanged = + typeof acc.accessToken === 'string' && acc.accessToken !== existingAcc.accessToken + const clientIdChanged = + typeof acc.clientId === 'string' && acc.clientId !== existingAcc.clientId + const clientSecretChanged = + typeof acc.clientSecret === 'string' && acc.clientSecret !== existingAcc.clientSecret + const incomingIsFresh = (acc.lastSync || 0) >= (existingAcc.lastSync || 0) + const allowRecovery = + refreshChanged || + accessChanged || + clientIdChanged || + clientSecretChanged || + (acc.isHealthy && incomingIsFresh) + const hasPermanentError = - isPermanentError(existingAcc.unhealthyReason) || isPermanentError(acc.unhealthyReason) + !allowRecovery && + (isPermanentError(existingAcc.unhealthyReason) || isPermanentError(acc.unhealthyReason)) accountMap.set(acc.id, { ...existingAcc, @@ -77,7 +101,13 @@ export function mergeAccounts( acc.rateLimitResetTime || 0 ), isHealthy: hasPermanentError ? false : existingAcc.isHealthy || acc.isHealthy, - failCount: Math.max(existingAcc.failCount || 0, acc.failCount || 0), + unhealthyReason: hasPermanentError + ? existingAcc.unhealthyReason || acc.unhealthyReason + : acc.unhealthyReason, + recoveryTime: hasPermanentError ? existingAcc.recoveryTime : acc.recoveryTime, + failCount: hasPermanentError + ? Math.max(existingAcc.failCount || 0, acc.failCount || 0) + : acc.failCount || 0, lastSync: Math.max(existingAcc.lastSync || 0, acc.lastSync || 0) }) } else { diff --git a/src/plugin/storage/sqlite-recovery.ts b/src/plugin/storage/sqlite-recovery.ts new file mode 100644 index 0000000..ab8284a --- /dev/null +++ b/src/plugin/storage/sqlite-recovery.ts @@ -0,0 +1,12 @@ +import { rmSync } from 'node:fs' + +export function cleanupSqliteSidecars(dbPath: string): void { + const candidates = [`${dbPath}-wal`, `${dbPath}-shm`, `${dbPath}.lock`] + for (const p of candidates) { + try { + rmSync(p, { force: true }) + } catch { + // Best-effort cleanup. + } + } +} diff --git a/src/plugin/storage/sqlite.ts b/src/plugin/storage/sqlite.ts index c7e29a8..3811eac 100644 --- a/src/plugin/storage/sqlite.ts +++ b/src/plugin/storage/sqlite.ts @@ -5,6 +5,7 @@ import { join } from 'node:path' import type { ManagedAccount } from '../types' import { deduplicateAccounts, mergeAccounts, withDatabaseLock } from './locked-operations' import { runMigrations } from './migrations' +import { cleanupSqliteSidecars } from './sqlite-recovery.js' function getBaseDir(): string { const p = process.platform @@ -25,7 +26,23 @@ export class KiroDatabase { if (!existsSync(dir)) mkdirSync(dir, { recursive: true }) this.db = new Database(path) this.db.run('PRAGMA busy_timeout = 5000') - this.init() + + try { + this.init() + } catch (e) { + if (!isRecoverableSqliteIoError(e)) throw e + + // Common case: kiro.db was deleted while kiro.db-wal/kiro.db-shm remained. + // SQLite can report this as "disk I/O error". Best-effort cleanup and retry once. + try { + this.db.close() + } catch {} + cleanupSqliteSidecars(this.path) + + this.db = new Database(path) + this.db.run('PRAGMA busy_timeout = 5000') + this.init() + } } private init() { this.db.run('PRAGMA journal_mode = WAL') @@ -56,14 +73,25 @@ export class KiroDatabase { is_healthy, unhealthy_reason, recovery_time, fail_count, last_used, used_count, limit_count, last_sync ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(refresh_token) DO UPDATE SET - id=excluded.id, email=excluded.email, auth_method=excluded.auth_method, - region=excluded.region, client_id=excluded.client_id, client_secret=excluded.client_secret, - profile_arn=excluded.profile_arn, access_token=excluded.access_token, expires_at=excluded.expires_at, - rate_limit_reset=excluded.rate_limit_reset, is_healthy=excluded.is_healthy, - unhealthy_reason=excluded.unhealthy_reason, recovery_time=excluded.recovery_time, - fail_count=excluded.fail_count, last_used=excluded.last_used, - used_count=excluded.used_count, limit_count=excluded.limit_count, last_sync=excluded.last_sync + ON CONFLICT(id) DO UPDATE SET + email=excluded.email, + auth_method=excluded.auth_method, + region=excluded.region, + client_id=excluded.client_id, + client_secret=excluded.client_secret, + profile_arn=excluded.profile_arn, + refresh_token=excluded.refresh_token, + access_token=excluded.access_token, + expires_at=excluded.expires_at, + rate_limit_reset=excluded.rate_limit_reset, + is_healthy=excluded.is_healthy, + unhealthy_reason=excluded.unhealthy_reason, + recovery_time=excluded.recovery_time, + fail_count=excluded.fail_count, + last_used=excluded.last_used, + used_count=excluded.used_count, + limit_count=excluded.limit_count, + last_sync=excluded.last_sync ` ) .run( @@ -162,6 +190,20 @@ export class KiroDatabase { } } +function isRecoverableSqliteIoError(e: unknown): boolean { + const msg = + e instanceof Error + ? e.message + : typeof e === 'string' + ? e + : typeof (e as any)?.message === 'string' + ? String((e as any).message) + : '' + + const m = msg.toLowerCase() + return m.includes('disk i/o error') || m.includes('sqlite_ioerr') +} + export function createDatabase(path?: string): KiroDatabase { return new KiroDatabase(path) } diff --git a/src/plugin/sync/idc-region.ts b/src/plugin/sync/idc-region.ts new file mode 100644 index 0000000..378c664 --- /dev/null +++ b/src/plugin/sync/idc-region.ts @@ -0,0 +1,13 @@ +let idcRegion: string | undefined + +export function setIdcRegionFromState(region: string | undefined): void { + if (typeof region === 'string' && region.trim()) { + idcRegion = region.trim() + return + } + idcRegion = undefined +} + +export function getIdcRegionFromState(): string | undefined { + return idcRegion +} diff --git a/src/plugin/sync/kiro-cli.ts b/src/plugin/sync/kiro-cli.ts index 8219d58..c43eff8 100644 --- a/src/plugin/sync/kiro-cli.ts +++ b/src/plugin/sync/kiro-cli.ts @@ -1,9 +1,11 @@ import { Database } from 'bun:sqlite' import { existsSync } from 'node:fs' +import { normalizeRegion } from '../../constants.js' import { createDeterministicAccountId } from '../accounts' import * as logger from '../logger' import { kiroDb } from '../storage/sqlite' import { fetchUsageLimits } from '../usage' +import { setIdcRegionFromState } from './idc-region' import { findClientCredsRecursive, getCliDbPath, @@ -20,26 +22,81 @@ export async function syncFromKiroCli() { cliDb.run('PRAGMA busy_timeout = 5000') const rows = cliDb.prepare('SELECT key, value FROM auth_kv').all() as any[] - const deviceRegRow = rows.find( + let profileArnFromState: string | undefined + try { + const idcRegionRow = cliDb + .prepare('SELECT value FROM state WHERE key = ?') + .get('auth.idc.region') as { value?: string } | undefined + const parsedRegion = safeJsonParse(idcRegionRow?.value) + if (typeof parsedRegion === 'string') { + setIdcRegionFromState(parsedRegion) + } + const profileRow = cliDb + .prepare('SELECT value FROM state WHERE key = ?') + .get('api.codewhisperer.profile') as { value?: string } | undefined + const profile = safeJsonParse(profileRow?.value) + if (profile && typeof profile.arn === 'string') { + profileArnFromState = profile.arn + } + } catch { + setIdcRegionFromState(undefined) + } + + const tokenRows = rows.filter((r) => typeof r?.key === 'string' && r.key.includes(':token')) + const parsedTokens = tokenRows + .map((row) => { + const data = safeJsonParse(row.value) + const expiresAt = normalizeExpiresAt(data?.expires_at ?? data?.expiresAt) + return { row, data, expiresAt } + }) + .filter((t) => t.data) + + const now = Date.now() + const validTokens = parsedTokens.filter((t) => t.expiresAt > now) + const candidates = validTokens.length ? validTokens : parsedTokens + + let tokenRowsToImport = tokenRows + if (candidates.length > 0) { + const maxExpiresAt = Math.max(...candidates.map((t) => t.expiresAt || 0)) + tokenRowsToImport = candidates.filter((t) => t.expiresAt === maxExpiresAt).map((t) => t.row) + } + + const deviceRegRows = rows.filter( (r) => typeof r?.key === 'string' && r.key.includes('device-registration') ) - const deviceReg = safeJsonParse(deviceRegRow?.value) - const regCreds = deviceReg ? findClientCredsRecursive(deviceReg) : {} + const deviceRegByKey = new Map() + for (const row of deviceRegRows) { + const deviceReg = safeJsonParse(row.value) + const regCreds = deviceReg ? findClientCredsRecursive(deviceReg) : {} + if (regCreds.clientId && regCreds.clientSecret) { + const baseKey = row.key.replace(':device-registration', '') + deviceRegByKey.set(baseKey, regCreds) + } + } + + const importedIds = new Set() - for (const row of rows) { + for (const row of tokenRowsToImport) { if (row.key.includes(':token')) { const data = safeJsonParse(row.value) if (!data) continue - const isIdc = row.key.includes('odic') + const isIdc = row.key.includes('odic') || row.key.includes('oidc') const authMethod = isIdc ? 'idc' : 'desktop' - const region = data.region || 'us-east-1' - const profileArn = data.profile_arn || data.profileArn - const accessToken = data.access_token || data.accessToken || '' + const profileArn = data.profile_arn || data.profileArn || profileArnFromState + const regionFromProfile = profileArn?.split(':')[3] + const region = normalizeRegion(regionFromProfile || data.region) const refreshToken = data.refresh_token || data.refreshToken if (!refreshToken) continue + const baseKey = row.key.replace(':token', '') + const regCreds = + deviceRegByKey.get(baseKey) || + deviceRegByKey.get(baseKey.replace('kirocli', 'codewhisperer')) || + deviceRegByKey.get(baseKey.replace('codewhisperer', 'kirocli')) || + {} + const clientId = data.client_id || data.clientId || (isIdc ? regCreds.clientId : undefined) const clientSecret = data.client_secret || data.clientSecret || (isIdc ? regCreds.clientSecret : undefined) @@ -108,7 +165,8 @@ export async function syncFromKiroCli() { if ( existingById && existingById.is_healthy === 1 && - existingById.expires_at >= cliExpiresAt + existingById.expires_at >= cliExpiresAt && + existingById.region === region ) continue @@ -123,27 +181,14 @@ export async function syncFromKiroCli() { if (placeholderId !== id) { const placeholderRow = all.find((a) => a.id === placeholderId) if (placeholderRow) { - await kiroDb.upsertAccount({ - id: placeholderId, - email: placeholderRow.email, - authMethod, - region: placeholderRow.region || region, - clientId, - clientSecret, - profileArn, - refreshToken: placeholderRow.refresh_token || refreshToken, - accessToken: placeholderRow.access_token || accessToken, - expiresAt: placeholderRow.expires_at || cliExpiresAt, - rateLimitResetTime: 0, - isHealthy: false, - failCount: 10, - unhealthyReason: 'Replaced by real email', - recoveryTime: Date.now() + 31536000000, - usedCount: placeholderRow.used_count || 0, - limitCount: placeholderRow.limit_count || 0, - lastSync: Date.now() - }) + usedCount = Math.max(usedCount, placeholderRow.used_count || 0) + limitCount = Math.max(limitCount, placeholderRow.limit_count || 0) } + + // We enforce a unique index on refresh_token. When we later insert the real-email + // account (different id) using the same refresh token, a placeholder row would + // violate that constraint. Delete it now; it will be recreated under the real id. + await kiroDb.deleteAccount(placeholderId) } } @@ -165,6 +210,19 @@ export async function syncFromKiroCli() { limitCount, lastSync: Date.now() }) + importedIds.add(id) + } + } + + const existing = kiroDb.getAccounts() + for (const acc of existing) { + if ( + typeof acc?.email === 'string' && + acc.email.endsWith('@awsapps.local') && + acc.auth_method === 'idc' && + !importedIds.has(acc.id) + ) { + await kiroDb.deleteAccount(acc.id) } } cliDb.close() diff --git a/src/plugin/token.ts b/src/plugin/token.ts index 225e39f..bb7670d 100644 --- a/src/plugin/token.ts +++ b/src/plugin/token.ts @@ -1,14 +1,19 @@ import crypto from 'node:crypto' import { decodeRefreshToken, encodeRefreshToken } from '../kiro/auth' import { KiroTokenRefreshError } from './errors' +import * as logger from './logger' +import { getIdcRegionFromState } from './sync/idc-region' import type { KiroAuthDetails, RefreshParts } from './types' export async function refreshAccessToken(auth: KiroAuthDetails): Promise { const p = decodeRefreshToken(auth.refresh) const isIdc = auth.authMethod === 'idc' + const idcRegion = isIdc ? getIdcRegionFromState() : undefined + const profileRegion = auth.profileArn?.split(':')[3] + const authRegion = idcRegion || profileRegion || auth.region const url = isIdc - ? `https://oidc.${auth.region}.amazonaws.com/token` - : `https://prod.${auth.region}.auth.desktop.kiro.dev/refreshToken` + ? `https://oidc.${authRegion}.amazonaws.com/token` + : `https://prod.${authRegion}.auth.desktop.kiro.dev/refreshToken` if (isIdc && (!p.clientId || !p.clientSecret)) { throw new KiroTokenRefreshError('Missing creds', 'MISSING_CREDENTIALS') @@ -53,6 +58,11 @@ export async function refreshAccessToken(auth: KiroAuthDetails): Promise userInputMessageContext?: { diff --git a/test/auth-page.test.js b/test/auth-page.test.js new file mode 100644 index 0000000..9d66d1a --- /dev/null +++ b/test/auth-page.test.js @@ -0,0 +1,23 @@ +import assert from 'node:assert/strict' +import { test } from 'node:test' + +import { getIDCAuthHtml } from '../dist/plugin/auth-page.js' +import { getIDCCombinedHtml } from '../dist/plugin/auth-page.js' + +test('IDC auth page uses error field when redirecting to /error', () => { + const html = getIDCAuthHtml('https://example.invalid', 'ABCD-1234', 'http://127.0.0.1:1/status') + assert.ok( + html.includes("encodeURIComponent(data.error || data.message || 'Authentication failed')"), + 'expected auth page to prefer data.error over data.message' + ) +}) + +test('combined IDC page does not auto-open the verification URL', () => { + const html = getIDCCombinedHtml( + 'https://example.invalid/start', + 'us-east-1', + 'http://127.0.0.1:1/begin', + 'http://127.0.0.1:1/status' + ) + assert.ok(!html.includes('window.open('), 'expected no automatic tab opening') +}) diff --git a/test/auth-server.test.js b/test/auth-server.test.js new file mode 100644 index 0000000..aaeacde --- /dev/null +++ b/test/auth-server.test.js @@ -0,0 +1,99 @@ +import assert from 'node:assert/strict' +import { test } from 'node:test' + +import { startIDCAuthServer } from '../dist/plugin/server.js' + +function makeResponse({ ok, status, body, headers }) { + return { + ok, + status, + headers: { + get: (k) => (headers ? headers[k.toLowerCase()] || headers[k] : null) + }, + async text() { + return body + }, + async json() { + return JSON.parse(body) + } + } +} + +async function sleep(ms) { + await new Promise((r) => setTimeout(r, ms)) +} + +test('auth server: /status includes message alias and /error supports query params', async () => { + const originalFetch = globalThis.fetch + + globalThis.fetch = async (url, init) => { + const s = String(url) + + if (s.endsWith('/client/register')) { + return makeResponse({ + ok: true, + status: 200, + body: JSON.stringify({ clientId: 'client-id', clientSecret: 'client-secret' }) + }) + } + + if (s.endsWith('/device_authorization')) { + return makeResponse({ + ok: true, + status: 200, + body: JSON.stringify({ + verificationUri: 'https://example.invalid/verify', + verificationUriComplete: 'https://example.invalid/verify?user_code=ABCD', + userCode: 'ABCD-1234', + deviceCode: 'device-code', + interval: 1, + expiresIn: 600 + }) + }) + } + + if (s.includes('/token')) { + return makeResponse({ + ok: false, + status: 400, + body: JSON.stringify({ + error: 'invalid_grant', + error_description: 'Invalid device code provided' + }) + }) + } + return originalFetch(url, init) + } + + try { + const { url } = await startIDCAuthServer( + { defaultRegion: 'us-east-1', defaultStartUrl: 'https://example.invalid/start' }, + 19857, + 20 + ) + + // Kick off auth. + await fetch( + `${url}/begin?startUrl=${encodeURIComponent('https://example.invalid/start')}®ion=${encodeURIComponent('us-east-1')}` + ) + + // Wait for polling to run and flip status. + let status = null + for (let i = 0; i < 50; i++) { + const res = await fetch(`${url}/status`) + status = await res.json() + if (status.status !== 'pending' && status.status !== 'idle') break + await sleep(50) + } + + assert.equal(status.status, 'failed') + assert.equal(status.error, 'Invalid device code provided') + assert.equal(status.message, 'Invalid device code provided') + + const errRes = await fetch(`${url}/error?message=Hello%20World`) + const html = await errRes.text() + assert.ok(html.includes('Hello World')) + } finally { + globalThis.fetch = originalFetch + } +}) diff --git a/test/idc-start-url.test.js b/test/idc-start-url.test.js new file mode 100644 index 0000000..48de6d2 --- /dev/null +++ b/test/idc-start-url.test.js @@ -0,0 +1,78 @@ +import assert from 'node:assert/strict' +import { test } from 'node:test' + +import { authorizeKiroIDC } from '../dist/kiro/oauth-idc.js' + +function makeResponse({ ok, status, body }) { + return { + ok, + status, + async text() { + return body + }, + async json() { + return JSON.parse(body) + } + } +} + +test('authorizeKiroIDC uses configurable builderIdStartUrl', async () => { + const originalFetch = globalThis.fetch + const calls = [] + + globalThis.fetch = async (url, init) => { + calls.push({ url: String(url), init }) + + if (String(url) === 'https://example.invalid/start') { + return { + ok: true, + status: 200, + url: 'https://d-1234567890.awsapps.com/start', + async text() { + return '' + }, + async json() { + return {} + } + } + } + + if (String(url).endsWith('/client/register')) { + return makeResponse({ + ok: true, + status: 200, + body: JSON.stringify({ clientId: 'cid', clientSecret: 'csec' }) + }) + } + + if (String(url).endsWith('/device_authorization')) { + return makeResponse({ + ok: true, + status: 200, + body: JSON.stringify({ + verificationUri: 'https://example.invalid/verify', + verificationUriComplete: 'https://example.invalid/verify?user_code=AAAA', + userCode: 'AAAA-BBBB', + deviceCode: 'device-code', + interval: 5, + expiresIn: 600 + }) + }) + } + + return makeResponse({ ok: false, status: 404, body: '' }) + } + + try { + const customStartUrl = 'https://example.invalid/start/#/?tab=accounts' + await authorizeKiroIDC('us-east-1', customStartUrl) + + const deviceCall = calls.find((c) => c.url.endsWith('/device_authorization')) + assert.ok(deviceCall, 'expected /device_authorization call') + + const body = JSON.parse(deviceCall.init.body) + assert.equal(body.startUrl, 'https://d-1234567890.awsapps.com/start') + } finally { + globalThis.fetch = originalFetch + } +}) diff --git a/test/opencode-config.test.js b/test/opencode-config.test.js new file mode 100644 index 0000000..beea089 --- /dev/null +++ b/test/opencode-config.test.js @@ -0,0 +1,33 @@ +import assert from 'node:assert/strict' +import { test } from 'node:test' + +import { ensureProviderBaseURL, getKiroOpenAICompatibleBaseURL } from '../dist/opencode-config.js' + +test('getKiroOpenAICompatibleBaseURL uses region and strips path', () => { + assert.equal(getKiroOpenAICompatibleBaseURL('us-east-1'), 'https://q.us-east-1.amazonaws.com') + assert.equal(getKiroOpenAICompatibleBaseURL('us-west-2'), 'https://q.us-west-2.amazonaws.com') +}) + +test('ensureProviderBaseURL sets provider..options.baseURL if missing', () => { + const cfg = { provider: { kiro: { models: {} } } } + const changed = ensureProviderBaseURL(cfg, 'kiro', 'https://q.us-west-2.amazonaws.com') + + assert.equal(changed, true) + assert.equal(cfg.provider.kiro.options.baseURL, 'https://q.us-west-2.amazonaws.com') +}) + +test('ensureProviderBaseURL does not override an existing baseURL', () => { + const cfg = { provider: { kiro: { options: { baseURL: 'https://example.invalid' }, models: {} } } } + const changed = ensureProviderBaseURL(cfg, 'kiro', 'https://q.us-east-1.amazonaws.com') + + assert.equal(changed, false) + assert.equal(cfg.provider.kiro.options.baseURL, 'https://example.invalid') +}) + +test('ensureProviderBaseURL treats empty string as missing', () => { + const cfg = { provider: { kiro: { options: { baseURL: ' ' }, models: {} } } } + const changed = ensureProviderBaseURL(cfg, 'kiro', 'https://q.us-east-1.amazonaws.com') + + assert.equal(changed, true) + assert.equal(cfg.provider.kiro.options.baseURL, 'https://q.us-east-1.amazonaws.com') +}) diff --git a/test/sqlite-recovery.test.js b/test/sqlite-recovery.test.js new file mode 100644 index 0000000..f81c934 --- /dev/null +++ b/test/sqlite-recovery.test.js @@ -0,0 +1,35 @@ +import assert from 'node:assert/strict' +import { test } from 'node:test' + +import { mkdirSync, rmSync, writeFileSync, existsSync } from 'node:fs' +import { tmpdir } from 'node:os' +import { join } from 'node:path' + +import { cleanupSqliteSidecars } from '../dist/plugin/storage/sqlite-recovery.js' + +test('cleanupSqliteSidecars removes orphaned -wal/-shm/.lock files', () => { + const dir = join(tmpdir(), `opencode-kiro-auth-sqlite-recovery-${Date.now()}`) + mkdirSync(dir, { recursive: true }) + + const dbPath = join(dir, 'kiro.db') + + const wal = `${dbPath}-wal` + const shm = `${dbPath}-shm` + const lock = `${dbPath}.lock` + + writeFileSync(wal, 'x') + writeFileSync(shm, 'x') + writeFileSync(lock, 'x') + + assert.equal(existsSync(wal), true) + assert.equal(existsSync(shm), true) + assert.equal(existsSync(lock), true) + + cleanupSqliteSidecars(dbPath) + + assert.equal(existsSync(wal), false) + assert.equal(existsSync(shm), false) + assert.equal(existsSync(lock), false) + + rmSync(dir, { recursive: true, force: true }) +}) diff --git a/test/thinking.test.js b/test/thinking.test.js new file mode 100644 index 0000000..f20c8d5 --- /dev/null +++ b/test/thinking.test.js @@ -0,0 +1,30 @@ +import assert from 'node:assert/strict' +import { test } from 'node:test' + +import { resolveThinkingConfig } from '../dist/core/request/thinking.js' + +test('resolveThinkingConfig: model suffix enables thinking with default budget', () => { + const r = resolveThinkingConfig('claude-sonnet-4-5-thinking', {}) + assert.equal(r.enabled, true) + assert.equal(r.budget, 20000) +}) + +test('resolveThinkingConfig: explicit thinkingConfig takes precedence', () => { + const r = resolveThinkingConfig('claude-sonnet-4-5', { + providerOptions: { thinkingConfig: { thinkingBudget: 12345 } } + }) + assert.equal(r.enabled, true) + assert.equal(r.budget, 12345) +}) + +test('resolveThinkingConfig: variant low/medium/high maps to budgets', () => { + assert.equal(resolveThinkingConfig('claude-sonnet-4-5', { variant: 'low' }).budget, 8192) + assert.equal(resolveThinkingConfig('claude-sonnet-4-5', { variant: 'medium' }).budget, 16384) + assert.equal(resolveThinkingConfig('claude-sonnet-4-5', { variant: 'high' }).budget, 32768) +}) + +test('resolveThinkingConfig: max is accepted as backward-compatible alias of high', () => { + const r = resolveThinkingConfig('claude-sonnet-4-5', { providerOptions: { variant: 'max' } }) + assert.equal(r.enabled, true) + assert.equal(r.budget, 32768) +})