diff --git a/README.md b/README.md index f534f2e..de344ce 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,7 @@ Notes: | `retryable_failure_cooldown_secs` | `15` | Cooldown window after retryable failures that should temporarily sideline an upstream. `0` disables cooldown. Reloading or restarting the running proxy resets current cooldown state. | | `tray_token_rate.enabled` | `true` | macOS tray live rate; harmless elsewhere. | | `tray_token_rate.format` | `split` | `combined` (`total`), `split` (`↑in ↓out`), `both` (`total | ↑in ↓out`). | -| `upstream_strategy` | `priority_fill_first` | `priority_fill_first` (default) keeps trying the highest-priority group in list order; `priority_round_robin` rotates within each priority group. | +| `upstream_strategy` | `{ "order": "fill_first", "dispatch": { "type": "serial" } }` | Structured strategy object. `order` controls candidate ordering inside one priority group; `dispatch` controls serial / hedged / race execution. | ### Upstream entries (`upstreams[]`) | Field | Default | Notes | @@ -147,7 +147,14 @@ Notes: - **Gemini**: `upstream.api_key` → `x-goog-api-key` → query `?key=...` → error. ## Load balancing & retries -- Priorities: higher `priority` groups first; inside a group use list order (fill-first) or round-robin (if `priority_round_robin`). +- Priorities: higher `priority` groups first. +- `upstream_strategy.order` controls selection inside the same priority group: + - `fill_first`: keep the configured list order. + - `round_robin`: rotate the starting point across requests. +- `upstream_strategy.dispatch` controls how requests are launched inside one priority group: + - `{"type":"serial"}`: try one candidate at a time. + - `{"type":"hedged","delay_ms":2000,"max_parallel":2}`: launch the first candidate immediately, then add one more attempt after `delay_ms` if the prior attempt is still unresolved, up to `max_parallel`. 
+ - `{"type":"race","max_parallel":3}`: launch up to `max_parallel` candidates immediately and take the first successful result. - Retryable conditions: network timeout/connect errors, or status 400/401/403/404/408/422/429/307/5xx (including 504/524). Retries stay within the same provider's priority groups. - Cooldown conditions: `401/403/408/429/5xx` will temporarily move the failed upstream behind ready peers for `retryable_failure_cooldown_secs` (default `15`); `400/404/422/307` stay retryable but do not trigger cross-request cooldown. - `/v1/messages` only: after the chosen native provider is exhausted (retryable errors), the proxy can fall back to the other native provider (`anthropic` ↔ `kiro`) if it is configured. diff --git a/README.zh-CN.md b/README.zh-CN.md index 22999c7..ecb9e88 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -103,7 +103,7 @@ pnpm exec tsc --noEmit | `retryable_failure_cooldown_secs` | `15` | 对适合短时降级的可重试失败施加冷却窗口;`0` 表示关闭冷却。重载或重启运行中的代理会重置当前冷却状态 | | `tray_token_rate.enabled` | `true` | macOS 托盘实时速率;其他平台无害 | | `tray_token_rate.format` | `split` | `combined`(总数) / `split`(↑入 ↓出) / `both`(总数 | ↑入 ↓出) | -| `upstream_strategy` | `priority_fill_first` | `priority_fill_first` 默认先填满高优先级;`priority_round_robin` 在同组内轮询 | +| `upstream_strategy` | `{ "order": "fill_first", "dispatch": { "type": "serial" } }` | 结构化策略对象。`order` 控制同一优先级组内的候选顺序;`dispatch` 控制串行 / hedged / race 派发方式 | ### 上游条目(`upstreams[]`) | 字段 | 默认值 | 说明 | @@ -147,7 +147,14 @@ pnpm exec tsc --noEmit - **Gemini**:`upstream.api_key` → `x-goog-api-key` → 查询参数 `?key=` → 报错 ## 负载均衡与重试 -- 优先级:高优先级组先尝试;组内按列表顺序(fill-first)或轮询(round-robin) +- 优先级:高优先级组先尝试。 +- `upstream_strategy.order` 控制同一优先级组内的选择顺序: + - `fill_first`:保持配置列表顺序。 + - `round_robin`:跨请求轮换起点。 +- `upstream_strategy.dispatch` 控制同一优先级组内的发起方式: + - `{"type":"serial"}`:一次只尝试一个候选。 + - `{"type":"hedged","delay_ms":2000,"max_parallel":2}`:先立即发第一个;若 `delay_ms` 后仍未决,再补发下一个,最多并发到 `max_parallel`。 + - `{"type":"race","max_parallel":3}`:立即并发发起最多 
`max_parallel` 个候选,谁先成功就返回谁。 - 可重试条件:网络超时/连接错误,或状态码 400/401/403/404/408/422/429/307/5xx(包含 504/524);重试只在同一 provider 的优先级组内进行 - 冷却条件:`401/403/408/429/5xx` 会让失败 upstream 在 `retryable_failure_cooldown_secs`(默认 `15`)内被暂时后置;`400/404/422/307` 仍可重试,但不会触发跨请求冷却 - 仅 `/v1/messages`:当命中的 native provider(`anthropic`/`kiro`)被耗尽(仍是可重试错误)时,若另一个 native provider 已配置,会自动 fallback(Anthropic ↔ Kiro) diff --git a/crates/token_proxy_core/src/proxy/config/io.rs b/crates/token_proxy_core/src/proxy/config/io.rs index 903ed8e..7cd01a8 100644 --- a/crates/token_proxy_core/src/proxy/config/io.rs +++ b/crates/token_proxy_core/src/proxy/config/io.rs @@ -10,6 +10,8 @@ const DEFAULT_CONFIG_HEADER: &str = concat!( "// Token Proxy config (JSONC). Comments and trailing commas are supported.\n", "// log_level (optional): silent|error|warn|info|debug|trace. Default: silent.\n", "// upstream_no_data_timeout_secs (optional): upstream no-data timeout in seconds. Minimum: 3. Default: 120.\n", + "// upstream_strategy (optional): { order: \"fill_first\"|\"round_robin\", dispatch: { type: \"serial\"|\"hedged\"|\"race\", ... } }.\n", + "// Example hedged: { \"order\": \"round_robin\", \"dispatch\": { \"type\": \"hedged\", \"delay_ms\": 2000, \"max_parallel\": 2 } }\n", "// upstreams[].api_keys (optional): one or more API keys for the same upstream. Example: [\"key-a\", \"key-b\"].\n", "// app_proxy_url (optional): http(s)://... | socks5(h)://... 
(used for app updates and upstream proxy reuse).\n", "// upstreams[].proxy_url (optional): empty => direct; \"$app_proxy_url\" => use app_proxy_url; or an explicit proxy URL.\n", @@ -218,3 +220,7 @@ async fn ensure_parent_dir(path: &Path) -> Result<(), String> { ); Ok(()) } + +#[cfg(test)] +#[path = "io.test.rs"] +mod tests; diff --git a/crates/token_proxy_core/src/proxy/config/io.test.rs b/crates/token_proxy_core/src/proxy/config/io.test.rs new file mode 100644 index 0000000..39e8955 --- /dev/null +++ b/crates/token_proxy_core/src/proxy/config/io.test.rs @@ -0,0 +1,28 @@ +use super::*; +use std::path::Path; + +#[test] +fn parse_config_file_migrates_legacy_upstream_strategy_string() { + let parsed = parse_config_file( + r#" + { + "host": "127.0.0.1", + "port": 9208, + "upstream_strategy": "priority_fill_first", + "upstreams": [] + } + "#, + Path::new("/tmp/config.jsonc"), + ) + .expect("legacy config should migrate"); + + assert!(parsed.migrated); + assert_eq!( + parsed.config.upstream_strategy.order, + crate::proxy::config::UpstreamOrderStrategy::FillFirst + ); + assert_eq!( + parsed.config.upstream_strategy.dispatch, + crate::proxy::config::UpstreamDispatchStrategy::Serial + ); +} diff --git a/crates/token_proxy_core/src/proxy/config/migrate.rs b/crates/token_proxy_core/src/proxy/config/migrate.rs index ab15241..aec4ced 100644 --- a/crates/token_proxy_core/src/proxy/config/migrate.rs +++ b/crates/token_proxy_core/src/proxy/config/migrate.rs @@ -32,7 +32,16 @@ pub(super) fn migrate_config_json(root: &mut Value) -> bool { .is_some_and(|obj| obj.contains_key("api_key")) }) }); - let is_legacy_config = had_legacy_enable || had_legacy_provider || had_legacy_api_key; + let had_legacy_upstream_strategy = root_obj + .get("upstream_strategy") + .and_then(Value::as_str) + .is_some_and(|value| { + matches!(value.trim(), "priority_fill_first" | "priority_round_robin") + }); + let is_legacy_config = had_legacy_enable + || had_legacy_provider + || had_legacy_api_key + || 
had_legacy_upstream_strategy; // 仅当检测到旧字段时才进行迁移;否则避免对新配置做“默认填充”,改变用户语义。 if !is_legacy_config { @@ -45,6 +54,7 @@ pub(super) fn migrate_config_json(root: &mut Value) -> bool { let mut changed = false; changed |= had_legacy_enable; + changed |= migrate_legacy_upstream_strategy(root_obj); let Some(upstreams_value) = root_obj.get_mut("upstreams") else { return changed; @@ -60,6 +70,32 @@ pub(super) fn migrate_config_json(root: &mut Value) -> bool { changed } +fn migrate_legacy_upstream_strategy(root_obj: &mut Map) -> bool { + let Some(value) = root_obj.get("upstream_strategy").and_then(Value::as_str) else { + return false; + }; + let order = match value.trim() { + "priority_fill_first" => "fill_first", + "priority_round_robin" => "round_robin", + _ => return false, + }; + + root_obj.insert( + "upstream_strategy".to_string(), + Value::Object(Map::from_iter([ + ("order".to_string(), Value::String(order.to_string())), + ( + "dispatch".to_string(), + Value::Object(Map::from_iter([( + "type".to_string(), + Value::String("serial".to_string()), + )])), + ), + ])), + ); + true +} + fn migrate_single_upstream(upstream: &mut Value, legacy_enable_conversion: bool) -> bool { let Some(obj) = upstream.as_object_mut() else { return false; diff --git a/crates/token_proxy_core/src/proxy/config/migrate.test.rs b/crates/token_proxy_core/src/proxy/config/migrate.test.rs index abc166c..97fc580 100644 --- a/crates/token_proxy_core/src/proxy/config/migrate.test.rs +++ b/crates/token_proxy_core/src/proxy/config/migrate.test.rs @@ -141,3 +141,52 @@ fn migrate_api_key_to_api_keys() { assert_eq!(keys.len(), 1); assert_eq!(keys[0].as_str(), Some("key-1")); } +#[test] +fn migrate_legacy_upstream_strategy_string_to_structured_fill_first_serial() { + let mut value = parse_json( + r#" + { + "host": "127.0.0.1", + "port": 9208, + "upstream_strategy": "priority_fill_first", + "upstreams": [] + } + "#, + ); + + let changed = migrate_config_json(&mut value); + assert!(changed); + + assert_eq!( + 
value["upstream_strategy"], + serde_json::json!({ + "order": "fill_first", + "dispatch": { "type": "serial" } + }) + ); +} + +#[test] +fn migrate_legacy_upstream_strategy_string_to_structured_round_robin_serial() { + let mut value = parse_json( + r#" + { + "host": "127.0.0.1", + "port": 9208, + "upstream_strategy": "priority_round_robin", + "upstreams": [] + } + "#, + ); + + let changed = migrate_config_json(&mut value); + assert!(changed); + + assert_eq!( + value["upstream_strategy"], + serde_json::json!({ + "order": "round_robin", + "dispatch": { "type": "serial" } + }) + ); +} diff --git a/crates/token_proxy_core/src/proxy/config/mod.rs b/crates/token_proxy_core/src/proxy/config/mod.rs index cc28538..d0f549c 100644 --- a/crates/token_proxy_core/src/proxy/config/mod.rs +++ b/crates/token_proxy_core/src/proxy/config/mod.rs @@ -13,7 +13,8 @@ const MIN_UPSTREAM_NO_DATA_TIMEOUT_SECS: u64 = 3; pub use types::{ ConfigResponse, HeaderOverride, InboundApiFormat, KiroPreferredEndpoint, ProviderUpstreams, ProxyConfig, ProxyConfigFile, TrayTokenRateConfig, TrayTokenRateFormat, UpstreamConfig, - UpstreamGroup, UpstreamOverrides, UpstreamRuntime, UpstreamStrategy, + UpstreamDispatchRuntime, UpstreamDispatchStrategy, UpstreamGroup, UpstreamOrderStrategy, + UpstreamOverrides, UpstreamRuntime, UpstreamStrategy, UpstreamStrategyRuntime, }; pub async fn read_config(paths: &TokenProxyPaths) -> Result { @@ -75,7 +76,7 @@ fn build_runtime_config(config: ProxyConfigFile) -> Result upstream_no_data_timeout: resolve_upstream_no_data_timeout( config.upstream_no_data_timeout_secs, )?, - upstream_strategy: config.upstream_strategy, + upstream_strategy: resolve_upstream_strategy(config.upstream_strategy)?, upstreams, kiro_preferred_endpoint: config.kiro_preferred_endpoint, antigravity_user_agent: config.antigravity_user_agent, @@ -103,6 +104,46 @@ fn resolve_upstream_no_data_timeout(value: u64) -> Result { Ok(duration) } +fn resolve_upstream_strategy(value: UpstreamStrategy) -> Result { + 
let dispatch = match value.dispatch { + UpstreamDispatchStrategy::Serial => UpstreamDispatchRuntime::Serial, + UpstreamDispatchStrategy::Hedged { + delay_ms, + max_parallel, + } => UpstreamDispatchRuntime::Hedged { + delay: resolve_hedged_delay(delay_ms)?, + max_parallel: resolve_parallel_attempts("hedged", max_parallel)?, + }, + UpstreamDispatchStrategy::Race { max_parallel } => UpstreamDispatchRuntime::Race { + max_parallel: resolve_parallel_attempts("race", max_parallel)?, + }, + }; + Ok(UpstreamStrategyRuntime { + order: value.order, + dispatch, + }) +} + +fn resolve_hedged_delay(value: u64) -> Result { + if value == 0 { + return Err("upstream_strategy.dispatch.delay_ms must be at least 1.".to_string()); + } + let duration = Duration::from_millis(value); + if Instant::now().checked_add(duration).is_none() { + return Err("upstream_strategy.dispatch.delay_ms is too large.".to_string()); + } + Ok(duration) +} + +fn resolve_parallel_attempts(dispatch: &str, value: u64) -> Result { + if value < 2 { + return Err(format!( + "upstream_strategy.dispatch.max_parallel must be at least 2 for {dispatch}." 
+ )); + } + usize::try_from(value) + .map_err(|_| "upstream_strategy.dispatch.max_parallel is too large.".to_string()) +} fn resolve_max_request_body_bytes(value: Option) -> usize { let value = value.unwrap_or(DEFAULT_MAX_REQUEST_BODY_BYTES); let value = if value == 0 { diff --git a/crates/token_proxy_core/src/proxy/config/mod.test.rs b/crates/token_proxy_core/src/proxy/config/mod.test.rs index 3c7dbd4..b004f1e 100644 --- a/crates/token_proxy_core/src/proxy/config/mod.test.rs +++ b/crates/token_proxy_core/src/proxy/config/mod.test.rs @@ -101,6 +101,97 @@ fn build_runtime_config_maps_upstream_no_data_timeout_secs() { assert_eq!(runtime.upstream_no_data_timeout, Duration::from_secs(3)); } +#[test] +fn build_runtime_config_maps_hedged_strategy() { + let mut config = ProxyConfigFile::default(); + config.upstream_strategy = UpstreamStrategy { + order: UpstreamOrderStrategy::RoundRobin, + dispatch: UpstreamDispatchStrategy::Hedged { + delay_ms: 250, + max_parallel: 3, + }, + }; + + let runtime = build_runtime_config(config).expect("runtime config"); + + assert_eq!( + runtime.upstream_strategy.order, + UpstreamOrderStrategy::RoundRobin + ); + assert_eq!( + runtime.upstream_strategy.dispatch, + UpstreamDispatchRuntime::Hedged { + delay: Duration::from_millis(250), + max_parallel: 3, + } + ); +} + +#[test] +fn build_runtime_config_maps_race_strategy() { + let mut config = ProxyConfigFile::default(); + config.upstream_strategy = UpstreamStrategy { + order: UpstreamOrderStrategy::RoundRobin, + dispatch: UpstreamDispatchStrategy::Race { max_parallel: 4 }, + }; + + let runtime = build_runtime_config(config).expect("runtime config"); + + assert_eq!( + runtime.upstream_strategy.order, + UpstreamOrderStrategy::RoundRobin + ); + assert_eq!( + runtime.upstream_strategy.dispatch, + UpstreamDispatchRuntime::Race { max_parallel: 4 } + ); +} + +#[test] +fn build_runtime_config_rejects_hedged_strategy_with_zero_delay() { + let mut config = ProxyConfigFile::default(); + 
config.upstream_strategy = UpstreamStrategy { + order: UpstreamOrderStrategy::FillFirst, + dispatch: UpstreamDispatchStrategy::Hedged { + delay_ms: 0, + max_parallel: 2, + }, + }; + + let result = build_runtime_config(config); + + assert!(result.is_err()); +} + +#[test] +fn build_runtime_config_rejects_hedged_strategy_with_max_parallel_below_two() { + let mut config = ProxyConfigFile::default(); + config.upstream_strategy = UpstreamStrategy { + order: UpstreamOrderStrategy::FillFirst, + dispatch: UpstreamDispatchStrategy::Hedged { + delay_ms: 250, + max_parallel: 1, + }, + }; + + let result = build_runtime_config(config); + + assert!(result.is_err()); +} + +#[test] +fn build_runtime_config_rejects_race_strategy_with_max_parallel_below_two() { + let mut config = ProxyConfigFile::default(); + config.upstream_strategy = UpstreamStrategy { + order: UpstreamOrderStrategy::FillFirst, + dispatch: UpstreamDispatchStrategy::Race { max_parallel: 1 }, + }; + + let result = build_runtime_config(config); + + assert!(result.is_err()); +} + #[test] fn build_runtime_config_rejects_upstream_no_data_timeout_below_minimum() { let mut config = ProxyConfigFile::default(); diff --git a/crates/token_proxy_core/src/proxy/config/types.rs b/crates/token_proxy_core/src/proxy/config/types.rs index 71484b7..d3def88 100644 --- a/crates/token_proxy_core/src/proxy/config/types.rs +++ b/crates/token_proxy_core/src/proxy/config/types.rs @@ -85,19 +85,41 @@ impl InboundApiFormat { } } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] -pub enum UpstreamStrategy { - PriorityRoundRobin, - PriorityFillFirst, +pub enum UpstreamOrderStrategy { + FillFirst, + RoundRobin, } -impl Default for UpstreamStrategy { +impl Default for UpstreamOrderStrategy { fn default() -> Self { - Self::PriorityFillFirst + Self::FillFirst } } +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(tag = 
"type", rename_all = "snake_case")] +pub enum UpstreamDispatchStrategy { + #[default] + Serial, + Hedged { + delay_ms: u64, + max_parallel: u64, + }, + Race { + max_parallel: u64, + }, +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Default)] +pub struct UpstreamStrategy { + #[serde(default)] + pub order: UpstreamOrderStrategy, + #[serde(default)] + pub dispatch: UpstreamDispatchStrategy, +} + #[derive(Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum TrayTokenRateFormat { @@ -240,12 +262,40 @@ impl Default for ProxyConfigFile { retryable_failure_cooldown_secs: default_retryable_failure_cooldown_secs(), upstream_no_data_timeout_secs: default_upstream_no_data_timeout_secs(), tray_token_rate: TrayTokenRateConfig::default(), - upstream_strategy: UpstreamStrategy::PriorityFillFirst, + upstream_strategy: UpstreamStrategy::default(), upstreams: Vec::new(), } } } +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct UpstreamStrategyRuntime { + pub order: UpstreamOrderStrategy, + pub dispatch: UpstreamDispatchRuntime, +} + +impl Default for UpstreamStrategyRuntime { + fn default() -> Self { + Self { + order: UpstreamOrderStrategy::default(), + dispatch: UpstreamDispatchRuntime::default(), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Default)] +pub enum UpstreamDispatchRuntime { + #[default] + Serial, + Hedged { + delay: std::time::Duration, + max_parallel: usize, + }, + Race { + max_parallel: usize, + }, +} + #[derive(Clone)] pub struct ProxyConfig { pub host: String, @@ -255,7 +305,7 @@ pub struct ProxyConfig { pub max_request_body_bytes: usize, pub retryable_failure_cooldown: std::time::Duration, pub upstream_no_data_timeout: std::time::Duration, - pub upstream_strategy: UpstreamStrategy, + pub upstream_strategy: UpstreamStrategyRuntime, pub upstreams: HashMap, pub kiro_preferred_endpoint: Option, pub antigravity_user_agent: Option, diff --git a/crates/token_proxy_core/src/proxy/config/types.test.rs 
b/crates/token_proxy_core/src/proxy/config/types.test.rs index 3df0f12..a798c16 100644 --- a/crates/token_proxy_core/src/proxy/config/types.test.rs +++ b/crates/token_proxy_core/src/proxy/config/types.test.rs @@ -196,3 +196,17 @@ fn proxy_config_file_defaults_retryable_failure_cooldown_to_15_seconds() { assert_eq!(config.retryable_failure_cooldown_secs, 15); } + +#[test] +fn proxy_config_file_defaults_upstream_strategy_to_fill_first_serial() { + let config = ProxyConfigFile::default(); + + assert_eq!( + config.upstream_strategy.order, + UpstreamOrderStrategy::FillFirst + ); + assert_eq!( + config.upstream_strategy.dispatch, + UpstreamDispatchStrategy::Serial + ); +} diff --git a/crates/token_proxy_core/src/proxy/http.test.rs b/crates/token_proxy_core/src/proxy/http.test.rs index b033e36..8a4fd90 100644 --- a/crates/token_proxy_core/src/proxy/http.test.rs +++ b/crates/token_proxy_core/src/proxy/http.test.rs @@ -12,7 +12,7 @@ fn config_with_local(key: &str) -> ProxyConfig { max_request_body_bytes: 1024, retryable_failure_cooldown: std::time::Duration::from_secs(15), upstream_no_data_timeout: std::time::Duration::from_secs(120), - upstream_strategy: crate::proxy::config::UpstreamStrategy::PriorityFillFirst, + upstream_strategy: crate::proxy::config::UpstreamStrategyRuntime::default(), upstreams: HashMap::new(), kiro_preferred_endpoint: None, antigravity_user_agent: None, @@ -28,7 +28,7 @@ fn config_without_local() -> ProxyConfig { max_request_body_bytes: 1024, retryable_failure_cooldown: std::time::Duration::from_secs(15), upstream_no_data_timeout: std::time::Duration::from_secs(120), - upstream_strategy: crate::proxy::config::UpstreamStrategy::PriorityFillFirst, + upstream_strategy: crate::proxy::config::UpstreamStrategyRuntime::default(), upstreams: HashMap::new(), kiro_preferred_endpoint: None, antigravity_user_agent: None, diff --git a/crates/token_proxy_core/src/proxy/openai_compat.rs b/crates/token_proxy_core/src/proxy/openai_compat.rs index 20ccf07..87f6214 
100644 --- a/crates/token_proxy_core/src/proxy/openai_compat.rs +++ b/crates/token_proxy_core/src/proxy/openai_compat.rs @@ -32,6 +32,7 @@ pub(crate) enum FormatTransform { ResponsesToChat, ResponsesToAnthropic, AnthropicToResponses, + AnthropicToCodex, ChatToAnthropic, AnthropicToChat, GeminiToAnthropic, @@ -45,6 +46,7 @@ pub(crate) enum FormatTransform { ResponsesToCodex, CodexToChat, CodexToResponses, + CodexToAnthropic, } pub(crate) fn inbound_format(path: &str) -> Option { @@ -71,6 +73,11 @@ pub(crate) async fn transform_request_body( FormatTransform::AnthropicToResponses => { anthropic_compat::anthropic_request_to_responses(body, http_clients).await } + FormatTransform::AnthropicToCodex => { + let intermediate = + anthropic_compat::anthropic_request_to_responses(body, http_clients).await?; + codex_compat::responses_request_to_codex(&intermediate, model_hint) + } FormatTransform::ChatToAnthropic => { let intermediate = chat_request_to_responses(body)?; anthropic_compat::responses_request_to_anthropic(&intermediate, http_clients).await @@ -93,7 +100,9 @@ pub(crate) async fn transform_request_body( FormatTransform::ResponsesToCodex => { codex_compat::responses_request_to_codex(body, model_hint) } - FormatTransform::CodexToChat | FormatTransform::CodexToResponses => Ok(body.clone()), + FormatTransform::CodexToChat + | FormatTransform::CodexToResponses + | FormatTransform::CodexToAnthropic => Ok(body.clone()), } } @@ -112,6 +121,9 @@ pub(crate) fn transform_response_body( FormatTransform::AnthropicToResponses => { anthropic_compat::anthropic_response_to_responses(bytes) } + FormatTransform::AnthropicToCodex => { + Err("Codex response conversion is handled upstream.".to_string()) + } FormatTransform::ChatToAnthropic => { let intermediate = chat_response_to_responses(bytes)?; anthropic_compat::responses_response_to_anthropic(&intermediate, model_hint) @@ -129,6 +141,10 @@ pub(crate) fn transform_response_body( FormatTransform::KiroToAnthropic => { Err("Kiro response 
conversion is handled upstream.".to_string()) } + FormatTransform::CodexToAnthropic => { + let intermediate = codex_compat::codex_response_to_responses(bytes, None)?; + anthropic_compat::responses_response_to_anthropic(&intermediate, model_hint) + } FormatTransform::CodexToChat | FormatTransform::CodexToResponses => { Err("Codex response conversion is handled upstream.".to_string()) } diff --git a/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs b/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs index 98f3b81..e0c91f4 100644 --- a/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs +++ b/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs @@ -139,6 +139,9 @@ fn convert_success_body( FormatTransform::CodexToResponses => { convert_codex_to_responses_body(bytes, context, usage, log, request_body) } + FormatTransform::CodexToAnthropic => { + convert_codex_to_anthropic_body(bytes, context, usage, log, request_body) + } _ if transform != FormatTransform::None => { convert_generic_body(transform, bytes, context, usage, log) } @@ -227,6 +230,35 @@ fn convert_codex_to_responses_body( }) } +fn convert_codex_to_anthropic_body( + bytes: &Bytes, + context: &mut LogContext, + usage: UsageSnapshot, + log: Arc, + request_body: Option<&str>, +) -> Result { + let responses = match codex_compat::codex_response_to_responses(bytes, request_body) { + Ok(converted) => converted, + Err(message) => { + return Err(respond_transform_error(context, usage, log, message)); + } + }; + let anthropic = match transform_response_body( + FormatTransform::ResponsesToAnthropic, + &responses, + context.model.as_deref(), + ) { + Ok(converted) => converted, + Err(message) => { + return Err(respond_transform_error(context, usage, log, message)); + } + }; + Ok(ConvertedBody { + output: anthropic, + usage, + }) +} + fn convert_generic_body( transform: FormatTransform, bytes: &Bytes, @@ -355,6 +387,7 @@ fn provider_for_tokens(transform: FormatTransform, 
provider: &str) -> &str { FormatTransform::KiroToAnthropic => "anthropic", FormatTransform::CodexToChat => "openai", FormatTransform::CodexToResponses => "openai-response", + FormatTransform::CodexToAnthropic => "anthropic", _ if provider == PROVIDER_ANTIGRAVITY => PROVIDER_GEMINI, _ => provider, } diff --git a/crates/token_proxy_core/src/proxy/response/dispatch/stream.rs b/crates/token_proxy_core/src/proxy/response/dispatch/stream.rs index 366528d..36e04c0 100644 --- a/crates/token_proxy_core/src/proxy/response/dispatch/stream.rs +++ b/crates/token_proxy_core/src/proxy/response/dispatch/stream.rs @@ -271,6 +271,9 @@ fn stream_for_composed_transform( FormatTransform::GeminiToResponses => { stream_gemini_to_responses(upstream, context, log, request_tracker) } + FormatTransform::CodexToAnthropic => { + stream_codex_to_anthropic(upstream, context, log, request_tracker) + } _ => streaming::stream_with_logging(upstream, context, log, request_tracker).boxed(), } } @@ -318,6 +321,30 @@ fn stream_anthropic_to_chat( .boxed() } +fn stream_codex_to_anthropic( + upstream: UpstreamBytesStream, + context: LogContext, + log: Arc, + request_tracker: RequestTokenTracker, +) -> ResponseStream { + let intermediate_log = Arc::new(LogWriter::new(None)); + let intermediate_tracker = RequestTokenTracker::disabled(); + let responses_stream = codex_compat::stream_codex_to_responses( + upstream, + context.clone(), + intermediate_log, + intermediate_tracker, + ) + .boxed(); + responses_to_anthropic::stream_responses_to_anthropic( + responses_stream, + context, + log, + request_tracker, + ) + .boxed() +} + fn stream_gemini_to_anthropic( upstream: UpstreamBytesStream, context: LogContext, diff --git a/crates/token_proxy_core/src/proxy/server.rs b/crates/token_proxy_core/src/proxy/server.rs index 3c76dd2..5c8de8a 100644 --- a/crates/token_proxy_core/src/proxy/server.rs +++ b/crates/token_proxy_core/src/proxy/server.rs @@ -236,6 +236,7 @@ fn resolve_anthropic_plan( inbound_format, &[ 
PROVIDER_RESPONSES, + PROVIDER_CODEX, PROVIDER_CHAT, PROVIDER_GEMINI, PROVIDER_ANTIGRAVITY, @@ -251,6 +252,12 @@ fn resolve_anthropic_plan( request_transform: FormatTransform::AnthropicToResponses, response_transform: FormatTransform::ResponsesToAnthropic, }, + PROVIDER_CODEX => DispatchPlan { + provider: PROVIDER_CODEX, + outbound_path: Some(CODEX_RESPONSES_PATH), + request_transform: FormatTransform::AnthropicToCodex, + response_transform: FormatTransform::CodexToAnthropic, + }, PROVIDER_CHAT => DispatchPlan { provider: PROVIDER_CHAT, outbound_path: Some(CHAT_PATH), @@ -545,6 +552,18 @@ fn build_retry_fallback_plan(path: &str, provider: &'static str) -> Option DispatchPlan { + provider: PROVIDER_RESPONSES, + outbound_path: Some(RESPONSES_PATH), + request_transform: FormatTransform::AnthropicToResponses, + response_transform: FormatTransform::ResponsesToAnthropic, + }, + PROVIDER_CODEX => DispatchPlan { + provider: PROVIDER_CODEX, + outbound_path: Some(CODEX_RESPONSES_PATH), + request_transform: FormatTransform::AnthropicToCodex, + response_transform: FormatTransform::CodexToAnthropic, + }, _ => return None, }); } @@ -587,6 +606,8 @@ fn resolve_retry_fallback_provider( let fallback = match primary_provider { PROVIDER_ANTHROPIC => PROVIDER_KIRO, PROVIDER_KIRO => PROVIDER_ANTHROPIC, + PROVIDER_RESPONSES => PROVIDER_CODEX, + PROVIDER_CODEX => PROVIDER_RESPONSES, _ => return None, }; return Some((fallback, Some(InboundApiFormat::AnthropicMessages))); diff --git a/crates/token_proxy_core/src/proxy/server.test.rs b/crates/token_proxy_core/src/proxy/server.test.rs index 912c58c..99c0707 100644 --- a/crates/token_proxy_core/src/proxy/server.test.rs +++ b/crates/token_proxy_core/src/proxy/server.test.rs @@ -22,8 +22,8 @@ use tokio::{runtime::Runtime, sync::RwLock, task::JoinHandle}; use crate::logging::LogLevel; use crate::paths::TokenProxyPaths; use crate::proxy::config::{ - InboundApiFormat, ProviderUpstreams, ProxyConfig, UpstreamGroup, UpstreamRuntime, - 
UpstreamStrategy, + InboundApiFormat, ProviderUpstreams, ProxyConfig, UpstreamDispatchRuntime, UpstreamGroup, + UpstreamOrderStrategy, UpstreamRuntime, UpstreamStrategyRuntime, }; const FORMATS_ALL: &[InboundApiFormat] = &[ @@ -127,7 +127,10 @@ fn config_with_runtime_upstreams( max_request_body_bytes: 20 * 1024 * 1024, retryable_failure_cooldown: std::time::Duration::from_secs(15), upstream_no_data_timeout: std::time::Duration::from_secs(120), - upstream_strategy: UpstreamStrategy::PriorityRoundRobin, + upstream_strategy: UpstreamStrategyRuntime { + order: UpstreamOrderStrategy::RoundRobin, + dispatch: UpstreamDispatchRuntime::Serial, + }, upstreams: provider_map, kiro_preferred_endpoint: None, antigravity_user_agent: None, @@ -144,6 +147,7 @@ struct RecordedRequest { struct MockUpstreamState { status: StatusCode, body: Value, + delay_ms: u64, requests: Arc>>, } @@ -178,6 +182,9 @@ async fn mock_upstream_handler( path: uri.path().to_string(), body: json_body, }); + if state.delay_ms > 0 { + tokio::time::sleep(std::time::Duration::from_millis(state.delay_ms)).await; + } ( state.status, [(axum::http::header::CONTENT_TYPE, "application/json")], @@ -187,10 +194,19 @@ async fn mock_upstream_handler( } async fn spawn_mock_upstream(status: StatusCode, body: Value) -> MockUpstream { + spawn_mock_upstream_with_delay(status, body, 0).await +} + +async fn spawn_mock_upstream_with_delay( + status: StatusCode, + body: Value, + delay_ms: u64, +) -> MockUpstream { let requests = Arc::new(Mutex::new(Vec::new())); let state = Arc::new(MockUpstreamState { status, body, + delay_ms, requests: requests.clone(), }); let app = Router::new() @@ -212,6 +228,195 @@ async fn spawn_mock_upstream(status: StatusCode, body: Value) -> MockUpstream { } } +#[test] +fn responses_request_hedged_delay_prefers_faster_same_priority_upstream() { + run_async(async { + let slow_primary = spawn_mock_upstream_with_delay( + StatusCode::OK, + json!({ + "id": "resp_from_slow_primary", + "object": "response", + 
"created_at": 123, + "model": "gpt-5", + "status": "completed", + "output": [ + { + "type": "message", + "id": "msg_1", + "status": "completed", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "from slow primary" } + ] + } + ], + "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 } + }), + 300, + ) + .await; + let fast_secondary = spawn_mock_upstream( + StatusCode::OK, + json!({ + "id": "resp_from_fast_secondary", + "object": "response", + "created_at": 123, + "model": "gpt-5", + "status": "completed", + "output": [ + { + "type": "message", + "id": "msg_1", + "status": "completed", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "from fast secondary" } + ] + } + ], + "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 } + }), + ) + .await; + + let mut config = config_with_runtime_upstreams(&[ + ( + PROVIDER_RESPONSES, + 10, + "responses-primary", + slow_primary.base_url.as_str(), + FORMATS_RESPONSES, + ), + ( + PROVIDER_RESPONSES, + 10, + "responses-secondary", + fast_secondary.base_url.as_str(), + FORMATS_RESPONSES, + ), + ]); + config.upstream_strategy = UpstreamStrategyRuntime { + order: UpstreamOrderStrategy::FillFirst, + dispatch: UpstreamDispatchRuntime::Hedged { + delay: std::time::Duration::from_millis(50), + max_parallel: 2, + }, + }; + + let data_dir = next_test_data_dir("responses_hedged_request"); + let state = build_test_state_handle(config, data_dir.clone()).await; + + let (status, json) = send_responses_request(state).await; + let primary_requests = slow_primary.requests(); + let secondary_requests = fast_secondary.requests(); + + slow_primary.abort(); + fast_secondary.abort(); + let _ = std::fs::remove_dir_all(&data_dir); + + assert_eq!(status, StatusCode::OK); + assert_eq!( + json["output"][0]["content"][0]["text"].as_str(), + Some("from fast secondary") + ); + assert_eq!(primary_requests.len(), 1); + assert_eq!(secondary_requests.len(), 1); + }); +} + +#[test] +fn 
responses_request_race_prefers_faster_same_priority_upstream() { + run_async(async { + let slow_primary = spawn_mock_upstream_with_delay( + StatusCode::OK, + json!({ + "id": "resp_from_slow_primary", + "object": "response", + "created_at": 123, + "model": "gpt-5", + "status": "completed", + "output": [ + { + "type": "message", + "id": "msg_1", + "status": "completed", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "from slow primary" } + ] + } + ], + "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 } + }), + 300, + ) + .await; + let fast_secondary = spawn_mock_upstream( + StatusCode::OK, + json!({ + "id": "resp_from_fast_secondary", + "object": "response", + "created_at": 123, + "model": "gpt-5", + "status": "completed", + "output": [ + { + "type": "message", + "id": "msg_1", + "status": "completed", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "from fast secondary" } + ] + } + ], + "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 } + }), + ) + .await; + + let mut config = config_with_runtime_upstreams(&[ + ( + PROVIDER_RESPONSES, + 10, + "responses-primary", + slow_primary.base_url.as_str(), + FORMATS_RESPONSES, + ), + ( + PROVIDER_RESPONSES, + 10, + "responses-secondary", + fast_secondary.base_url.as_str(), + FORMATS_RESPONSES, + ), + ]); + config.upstream_strategy = UpstreamStrategyRuntime { + order: UpstreamOrderStrategy::RoundRobin, + dispatch: UpstreamDispatchRuntime::Race { max_parallel: 2 }, + }; + + let data_dir = next_test_data_dir("responses_race_request"); + let state = build_test_state_handle(config, data_dir.clone()).await; + + let (status, json) = send_responses_request(state).await; + let primary_requests = slow_primary.requests(); + let secondary_requests = fast_secondary.requests(); + + slow_primary.abort(); + fast_secondary.abort(); + let _ = std::fs::remove_dir_all(&data_dir); + + assert_eq!(status, StatusCode::OK); + assert_eq!( + 
json["output"][0]["content"][0]["text"].as_str(), + Some("from fast secondary") + ); + assert_eq!(primary_requests.len(), 1); + assert_eq!(secondary_requests.len(), 1); + }); +} + fn next_test_data_dir(label: &str) -> PathBuf { let stamp = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -410,6 +615,42 @@ async fn send_responses_request(state: ProxyStateHandle) -> (StatusCode, Value) (status, json) } +async fn send_anthropic_messages_request( + state: ProxyStateHandle, + stream: bool, +) -> (StatusCode, Value) { + let response = proxy_request( + State(state), + Method::POST, + Uri::from_static("/v1/messages"), + axum::http::HeaderMap::new(), + Body::from( + json!({ + "model": "claude-sonnet-4-5", + "max_tokens": 64, + "stream": stream, + "messages": [ + { + "role": "user", + "content": [ + { "type": "text", "text": "hi from claude" } + ] + } + ] + }) + .to_string(), + ), + ) + .await; + + let status = response.status(); + let body = to_bytes(response.into_body(), usize::MAX) + .await + .expect("proxy response bytes"); + let json = serde_json::from_slice(&body).expect("proxy response json"); + (status, json) +} + #[test] fn responses_request_uses_chat_compat_for_coding_plan_runtime_upstream() { run_async(async { @@ -551,6 +792,141 @@ fn responses_request_falls_back_from_524_to_codex() { )); } +#[test] +fn anthropic_messages_request_routes_to_codex() { + run_async(async { + let codex = spawn_mock_upstream( + StatusCode::OK, + json!({ + "id": "resp_from_codex", + "object": "response", + "created_at": 123, + "model": "gpt-5-codex", + "status": "completed", + "output": [ + { + "type": "message", + "id": "msg_1", + "status": "completed", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "from codex for claude" } + ] + } + ], + "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 } + }), + ) + .await; + + let config = config_with_runtime_upstreams(&[( + PROVIDER_CODEX, + 10, + "codex-primary", + codex.base_url.as_str(), + FORMATS_ALL, 
+ )]); + let data_dir = next_test_data_dir("anthropic_messages_codex_direct"); + let state = build_test_state_handle(config, data_dir.clone()).await; + + let (status, json) = send_anthropic_messages_request(state, false).await; + let requests = codex.requests(); + + codex.abort(); + let _ = std::fs::remove_dir_all(&data_dir); + + assert_eq!(status, StatusCode::OK); + assert_eq!(json["type"], json!("message")); + assert_eq!(json["role"], json!("assistant")); + assert_eq!(json["content"][0]["type"], json!("text")); + assert_eq!(json["content"][0]["text"], json!("from codex for claude")); + assert_eq!(requests.len(), 1); + assert_eq!(requests[0].path, CODEX_RESPONSES_PATH); + assert_eq!(requests[0].body["input"][0]["role"].as_str(), Some("user")); + assert_eq!( + requests[0].body["input"][0]["content"][0]["type"].as_str(), + Some("input_text") + ); + assert_eq!( + requests[0].body["input"][0]["content"][0]["text"].as_str(), + Some("hi from claude") + ); + }); +} + +#[test] +fn anthropic_messages_request_falls_back_from_responses_to_codex() { + run_async(async { + let responses = spawn_mock_upstream( + StatusCode::BAD_REQUEST, + json!({ + "error": { "message": "responses upstream rejected request" } + }), + ) + .await; + let codex = spawn_mock_upstream( + StatusCode::OK, + json!({ + "id": "resp_from_codex", + "object": "response", + "created_at": 123, + "model": "gpt-5-codex", + "status": "completed", + "output": [ + { + "type": "message", + "id": "msg_1", + "status": "completed", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "fallback from codex for claude" } + ] + } + ], + "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 } + }), + ) + .await; + + let config = config_with_runtime_upstreams(&[ + ( + PROVIDER_RESPONSES, + 10, + "responses-primary", + responses.base_url.as_str(), + FORMATS_ALL, + ), + ( + PROVIDER_CODEX, + 5, + "codex-fallback", + codex.base_url.as_str(), + FORMATS_ALL, + ), + ]); + let data_dir = 
next_test_data_dir("anthropic_messages_responses_to_codex_fallback"); + let state = build_test_state_handle(config, data_dir.clone()).await; + + let (status, json) = send_anthropic_messages_request(state, false).await; + let responses_requests = responses.requests(); + let codex_requests = codex.requests(); + + responses.abort(); + codex.abort(); + let _ = std::fs::remove_dir_all(&data_dir); + + assert_eq!(status, StatusCode::OK); + assert_eq!( + json["content"][0]["text"].as_str(), + Some("fallback from codex for claude") + ); + assert_eq!(responses_requests.len(), 1); + assert_eq!(responses_requests[0].path, RESPONSES_PATH); + assert_eq!(codex_requests.len(), 1); + assert_eq!(codex_requests[0].path, CODEX_RESPONSES_PATH); + }); +} + #[test] fn responses_request_skips_recently_failed_same_provider_upstream() { run_async(async { @@ -601,7 +977,10 @@ fn responses_request_skips_recently_failed_same_provider_upstream() { FORMATS_RESPONSES, ), ]); - config.upstream_strategy = UpstreamStrategy::PriorityFillFirst; + config.upstream_strategy = UpstreamStrategyRuntime { + order: UpstreamOrderStrategy::FillFirst, + dispatch: UpstreamDispatchRuntime::Serial, + }; let data_dir = next_test_data_dir("responses_same_provider_cooldown"); let state = build_test_state_handle(config, data_dir.clone()).await; @@ -685,7 +1064,10 @@ fn responses_request_cooldowns_same_provider_upstream_after_401() { FORMATS_RESPONSES, ), ]); - config.upstream_strategy = UpstreamStrategy::PriorityFillFirst; + config.upstream_strategy = UpstreamStrategyRuntime { + order: UpstreamOrderStrategy::FillFirst, + dispatch: UpstreamDispatchRuntime::Serial, + }; let data_dir = next_test_data_dir("responses_same_provider_cooldown_401"); let state = build_test_state_handle(config, data_dir.clone()).await; @@ -769,7 +1151,10 @@ fn responses_request_does_not_cooldown_same_provider_upstream_after_400() { FORMATS_RESPONSES, ), ]); - config.upstream_strategy = UpstreamStrategy::PriorityFillFirst; + 
config.upstream_strategy = UpstreamStrategyRuntime { + order: UpstreamOrderStrategy::FillFirst, + dispatch: UpstreamDispatchRuntime::Serial, + }; let data_dir = next_test_data_dir("responses_same_provider_no_cooldown_400"); let state = build_test_state_handle(config, data_dir.clone()).await; @@ -853,7 +1238,10 @@ fn responses_request_reload_resets_existing_cooldown_and_applies_new_duration() FORMATS_RESPONSES, ), ]); - config.upstream_strategy = UpstreamStrategy::PriorityFillFirst; + config.upstream_strategy = UpstreamStrategyRuntime { + order: UpstreamOrderStrategy::FillFirst, + dispatch: UpstreamDispatchRuntime::Serial, + }; config.retryable_failure_cooldown = std::time::Duration::from_secs(15); let data_dir = next_test_data_dir("responses_same_provider_reload_resets_cooldown"); diff --git a/crates/token_proxy_core/src/proxy/service.test.rs b/crates/token_proxy_core/src/proxy/service.test.rs index d5821f5..5a7f9bb 100644 --- a/crates/token_proxy_core/src/proxy/service.test.rs +++ b/crates/token_proxy_core/src/proxy/service.test.rs @@ -20,7 +20,7 @@ fn config_with_addr_and_body_limit( max_request_body_bytes, retryable_failure_cooldown: Duration::from_secs(15), upstream_no_data_timeout: Duration::from_secs(120), - upstream_strategy: crate::proxy::config::UpstreamStrategy::PriorityFillFirst, + upstream_strategy: crate::proxy::config::UpstreamStrategyRuntime::default(), upstreams: HashMap::new(), kiro_preferred_endpoint: None, antigravity_user_agent: None, diff --git a/crates/token_proxy_core/src/proxy/upstream.rs b/crates/token_proxy_core/src/proxy/upstream.rs index b680aaa..ff9e611 100644 --- a/crates/token_proxy_core/src/proxy/upstream.rs +++ b/crates/token_proxy_core/src/proxy/upstream.rs @@ -7,7 +7,13 @@ use axum::{ }, response::Response, }; -use std::{sync::Arc, time::Instant}; +use futures_util::stream::{FuturesUnordered, StreamExt}; +use std::{ + future::Future, + pin::Pin, + sync::Arc, + time::{Duration, Instant}, +}; const GEMINI_API_KEY_QUERY: &str = 
"key"; const LOCAL_UPSTREAM_ID: &str = "local"; @@ -28,7 +34,7 @@ use utils::resolve_group_start; use crate::proxy::redact::redact_query_param_value; use super::{ - config::{InboundApiFormat, ProviderUpstreams, UpstreamRuntime}, + config::{InboundApiFormat, ProviderUpstreams, UpstreamDispatchRuntime, UpstreamRuntime}, gemini, http, http::RequestAuth, inbound::detect_inbound_api_format, @@ -219,6 +225,57 @@ fn merge_group_result(state: &mut ForwardAttemptState, result: GroupAttemptResul false } +type GroupAttemptFuture<'a> = Pin + Send + 'a>>; + +#[derive(Clone, Copy)] +enum CompletionLaunchMode { + FillToCapacity, + SingleSlot, +} + +#[derive(Clone, Copy)] +struct GroupDispatchPlan { + initial_parallel: usize, + max_parallel: usize, + hedge_delay: Option, + completion_launch_mode: CompletionLaunchMode, +} + +impl GroupDispatchPlan { + fn from_dispatch(dispatch: &UpstreamDispatchRuntime) -> Self { + match dispatch { + UpstreamDispatchRuntime::Serial => Self { + initial_parallel: 1, + max_parallel: 1, + hedge_delay: None, + completion_launch_mode: CompletionLaunchMode::SingleSlot, + }, + UpstreamDispatchRuntime::Hedged { + delay, + max_parallel, + } => Self { + initial_parallel: 1, + max_parallel: *max_parallel, + hedge_delay: Some(*delay), + completion_launch_mode: CompletionLaunchMode::SingleSlot, + }, + UpstreamDispatchRuntime::Race { max_parallel } => Self { + initial_parallel: *max_parallel, + max_parallel: *max_parallel, + hedge_delay: None, + completion_launch_mode: CompletionLaunchMode::FillToCapacity, + }, + } + } + + fn completion_launch_slots(self, in_flight_len: usize) -> usize { + match self.completion_launch_mode { + CompletionLaunchMode::FillToCapacity => self.max_parallel.saturating_sub(in_flight_len), + CompletionLaunchMode::SingleSlot => usize::from(in_flight_len < self.max_parallel), + } + } +} + pub(super) struct PreparedUpstreamRequest { upstream_path_with_query: String, upstream_url: String, @@ -376,24 +433,288 @@ async fn try_group_upstreams( 
response_transform: FormatTransform, request_detail: Option, ) -> GroupAttemptResult { - let mut result = GroupAttemptResult::new(); let start = resolve_group_start(state, provider, group_index, items.len()); let order = state.upstream_selector.order_group( - state.config.upstream_strategy.clone(), + state.config.upstream_strategy.order, provider, items, start, ); - for item_index in order { - let upstream = &items[item_index]; - if let Some(inbound_format) = inbound_format { - if !upstream.supports_inbound(inbound_format) { - continue; + let eligible_order = filter_eligible_upstreams(order, items, inbound_format); + if eligible_order.is_empty() { + return GroupAttemptResult::new(); + } + dispatch_group_upstreams( + state, + method, + provider, + items, + &eligible_order, + inbound_path, + upstream_path_with_query, + headers, + body, + meta, + request_auth, + response_transform, + request_detail, + GroupDispatchPlan::from_dispatch(&state.config.upstream_strategy.dispatch), + ) + .await +} + +fn filter_eligible_upstreams( + order: Vec, + items: &[UpstreamRuntime], + inbound_format: Option, +) -> Vec { + order + .into_iter() + .filter(|item_index| { + inbound_format.is_none_or(|format| items[*item_index].supports_inbound(format)) + }) + .collect() +} + +fn apply_group_attempt_outcome( + state: &ProxyState, + provider: &str, + upstream: &UpstreamRuntime, + result: &mut GroupAttemptResult, + outcome: AttemptOutcome, +) -> bool { + match &outcome { + AttemptOutcome::Success(_) => { + state + .upstream_selector + .clear_cooldown(provider, upstream.selector_key.as_str()); + } + AttemptOutcome::Retryable { + should_cooldown: true, + .. 
+ } => { + state + .upstream_selector + .mark_retryable_failure(provider, upstream.selector_key.as_str()); + } + _ => {} + } + if !matches!(outcome, AttemptOutcome::SkippedAuth) { + result.attempted += 1; + } + apply_attempt_outcome(result, outcome) +} + +async fn dispatch_group_upstreams( + state: &ProxyState, + method: Method, + provider: &str, + items: &[UpstreamRuntime], + order: &[usize], + inbound_path: &str, + upstream_path_with_query: &str, + headers: &HeaderMap, + body: &ReplayableBody, + meta: &RequestMeta, + request_auth: &RequestAuth, + response_transform: FormatTransform, + request_detail: Option, + dispatch_plan: GroupDispatchPlan, +) -> GroupAttemptResult { + let mut result = GroupAttemptResult::new(); + let mut in_flight: FuturesUnordered> = FuturesUnordered::new(); + let mut next_to_launch = 0usize; + + launch_group_attempts( + &mut in_flight, + &mut next_to_launch, + dispatch_plan.initial_parallel.min(order.len()), + state, + &method, + provider, + items, + order, + inbound_path, + upstream_path_with_query, + headers, + body, + meta, + request_auth, + response_transform, + &request_detail, + ); + + let mut hedge_timer = next_hedge_timer( + dispatch_plan.hedge_delay, + next_to_launch < order.len(), + in_flight.len(), + dispatch_plan.max_parallel, + ); + while next_to_launch < order.len() || !in_flight.is_empty() { + if in_flight.is_empty() { + let remaining = order.len() - next_to_launch; + launch_group_attempts( + &mut in_flight, + &mut next_to_launch, + dispatch_plan.initial_parallel.min(remaining), + state, + &method, + provider, + items, + order, + inbound_path, + upstream_path_with_query, + headers, + body, + meta, + request_auth, + response_transform, + &request_detail, + ); + hedge_timer = next_hedge_timer( + dispatch_plan.hedge_delay, + next_to_launch < order.len(), + in_flight.len(), + dispatch_plan.max_parallel, + ); + continue; + } + + let completed = if let Some(timer) = hedge_timer.as_mut() { + tokio::select! 
{ + maybe = in_flight.next(), if !in_flight.is_empty() => maybe, + _ = timer.as_mut(), if next_to_launch < order.len() => { + launch_group_attempts( + &mut in_flight, + &mut next_to_launch, + 1, + state, + &method, + provider, + items, + order, + inbound_path, + upstream_path_with_query, + headers, + body, + meta, + request_auth, + response_transform, + &request_detail, + ); + None + } + } + } else { + in_flight.next().await + }; + + if let Some((item_index, outcome)) = completed { + let upstream = &items[item_index]; + if apply_group_attempt_outcome(state, provider, upstream, &mut result, outcome) { + return result; + } + let immediate_slots = dispatch_plan + .completion_launch_slots(in_flight.len()) + .min(order.len().saturating_sub(next_to_launch)); + if immediate_slots > 0 { + launch_group_attempts( + &mut in_flight, + &mut next_to_launch, + immediate_slots, + state, + &method, + provider, + items, + order, + inbound_path, + upstream_path_with_query, + headers, + body, + meta, + request_auth, + response_transform, + &request_detail, + ); } } + + hedge_timer = next_hedge_timer( + dispatch_plan.hedge_delay, + next_to_launch < order.len(), + in_flight.len(), + dispatch_plan.max_parallel, + ); + } + + result +} + +fn launch_group_attempts<'a>( + in_flight: &mut FuturesUnordered>, + next_to_launch: &mut usize, + slots: usize, + state: &'a ProxyState, + method: &Method, + provider: &'a str, + items: &'a [UpstreamRuntime], + order: &'a [usize], + inbound_path: &'a str, + upstream_path_with_query: &'a str, + headers: &'a HeaderMap, + body: &'a ReplayableBody, + meta: &'a RequestMeta, + request_auth: &'a RequestAuth, + response_transform: FormatTransform, + request_detail: &Option, +) { + for _ in 0..slots { + let Some(item_index) = order.get(*next_to_launch).copied() else { + break; + }; + *next_to_launch += 1; + enqueue_group_attempt( + in_flight, + state, + method, + provider, + items, + item_index, + inbound_path, + upstream_path_with_query, + headers, + body, + 
meta, + request_auth, + response_transform, + request_detail, + ); + } +} + +fn enqueue_group_attempt<'a>( + in_flight: &mut FuturesUnordered>, + state: &'a ProxyState, + method: &Method, + provider: &'a str, + items: &'a [UpstreamRuntime], + item_index: usize, + inbound_path: &'a str, + upstream_path_with_query: &'a str, + headers: &'a HeaderMap, + body: &'a ReplayableBody, + meta: &'a RequestMeta, + request_auth: &'a RequestAuth, + response_transform: FormatTransform, + request_detail: &Option, +) { + let upstream = &items[item_index]; + let method = method.clone(); + let request_detail = request_detail.clone(); + in_flight.push(Box::pin(async move { let outcome = attempt::attempt_upstream( state, - method.clone(), + method, provider, upstream, inbound_path, @@ -403,33 +724,26 @@ async fn try_group_upstreams( meta, request_auth, response_transform, - request_detail.clone(), + request_detail, ) .await; - match &outcome { - AttemptOutcome::Success(_) => { - state - .upstream_selector - .clear_cooldown(provider, upstream.selector_key.as_str()); - } - AttemptOutcome::Retryable { - should_cooldown: true, - .. 
-        } => {
-            state
-                .upstream_selector
-                .mark_retryable_failure(provider, upstream.selector_key.as_str());
-        }
-        _ => {}
-    }
-    if !matches!(outcome, AttemptOutcome::SkippedAuth) {
-        result.attempted += 1;
-    }
-    if apply_attempt_outcome(&mut result, outcome) {
-        return result;
-    }
+        (item_index, outcome)
+    }));
+}
+
+fn next_hedge_timer(
+    hedged_request_delay: Option<Duration>,
+    has_pending_attempts: bool,
+    in_flight_len: usize,
+    max_parallel: usize,
+) -> Option<Pin<Box<tokio::time::Sleep>>> {
+    let Some(hedged_request_delay) = hedged_request_delay else {
+        return None;
+    };
+    if !has_pending_attempts || in_flight_len == 0 || in_flight_len >= max_parallel {
+        return None;
     }
-    result
+    Some(Box::pin(tokio::time::sleep(hedged_request_delay)))
 }
 
 async fn prepare_upstream_request(
diff --git a/crates/token_proxy_core/src/proxy/upstream/utils.rs b/crates/token_proxy_core/src/proxy/upstream/utils.rs
index 7369b46..1086e34 100644
--- a/crates/token_proxy_core/src/proxy/upstream/utils.rs
+++ b/crates/token_proxy_core/src/proxy/upstream/utils.rs
@@ -1,7 +1,7 @@
 use axum::http::StatusCode;
 use std::sync::atomic::Ordering;
 
-use super::super::{config::UpstreamStrategy, ProxyState};
+use super::super::{config::UpstreamOrderStrategy, ProxyState};
 use crate::proxy::redact::redact_query_param_value;
 
 pub(super) fn extract_query_param(path_with_query: &str, name: &str) -> Option<String> {
@@ -47,9 +47,9 @@ pub(super) fn resolve_group_start(
     group_index: usize,
     group_len: usize,
 ) -> usize {
-    match state.config.upstream_strategy {
-        UpstreamStrategy::PriorityFillFirst => 0,
-        UpstreamStrategy::PriorityRoundRobin => state
+    match state.config.upstream_strategy.order {
+        UpstreamOrderStrategy::FillFirst => 0,
+        UpstreamOrderStrategy::RoundRobin => state
             .cursors
             .get(provider)
             .and_then(|cursors| cursors.get(group_index))
diff --git a/crates/token_proxy_core/src/proxy/upstream_selector.rs b/crates/token_proxy_core/src/proxy/upstream_selector.rs
index f3f4fce..42ecc6a 100644
---
a/crates/token_proxy_core/src/proxy/upstream_selector.rs +++ b/crates/token_proxy_core/src/proxy/upstream_selector.rs @@ -4,7 +4,7 @@ use std::{ time::{Duration, Instant}, }; -use super::{config::UpstreamRuntime, config::UpstreamStrategy}; +use super::{config::UpstreamOrderStrategy, config::UpstreamRuntime}; #[derive(Hash, PartialEq, Eq)] struct CooldownKey { @@ -36,14 +36,14 @@ impl UpstreamSelectorRuntime { pub(crate) fn order_group( &self, - strategy: UpstreamStrategy, + order: UpstreamOrderStrategy, provider: &str, items: &[UpstreamRuntime], cursor_start: usize, ) -> Vec { - let base_order = match strategy { - UpstreamStrategy::PriorityFillFirst => (0..items.len()).collect(), - UpstreamStrategy::PriorityRoundRobin => (0..items.len()) + let base_order = match order { + UpstreamOrderStrategy::FillFirst => (0..items.len()).collect(), + UpstreamOrderStrategy::RoundRobin => (0..items.len()) .map(|offset| (cursor_start + offset) % items.len()) .collect(), }; diff --git a/crates/token_proxy_core/src/proxy/upstream_selector.test.rs b/crates/token_proxy_core/src/proxy/upstream_selector.test.rs index 1666a51..c7ea016 100644 --- a/crates/token_proxy_core/src/proxy/upstream_selector.test.rs +++ b/crates/token_proxy_core/src/proxy/upstream_selector.test.rs @@ -29,7 +29,7 @@ fn cooled_upstream_moves_behind_ready_candidates() { selector.mark_cooldown_until("responses", "a", Instant::now() + Duration::from_secs(10)); - let order = selector.order_group(UpstreamStrategy::PriorityFillFirst, "responses", &items, 0); + let order = selector.order_group(UpstreamOrderStrategy::FillFirst, "responses", &items, 0); assert_eq!(order, vec![1, 2, 0]); } @@ -43,7 +43,7 @@ fn all_cooled_upstreams_probe_earliest_expiry_first() { selector.mark_cooldown_until("responses", "b", Instant::now() + Duration::from_secs(5)); selector.mark_cooldown_until("responses", "c", Instant::now() + Duration::from_secs(10)); - let order = selector.order_group(UpstreamStrategy::PriorityFillFirst, "responses", 
&items, 0); + let order = selector.order_group(UpstreamOrderStrategy::FillFirst, "responses", &items, 0); assert_eq!(order, vec![1, 2, 0]); } @@ -56,7 +56,7 @@ fn clear_cooldown_restores_base_order() { selector.mark_cooldown_until("responses", "a", Instant::now() + Duration::from_secs(10)); selector.clear_cooldown("responses", "a"); - let order = selector.order_group(UpstreamStrategy::PriorityFillFirst, "responses", &items, 0); + let order = selector.order_group(UpstreamOrderStrategy::FillFirst, "responses", &items, 0); assert_eq!(order, vec![0, 1]); } @@ -68,7 +68,7 @@ fn zero_retryable_failure_cooldown_disables_cross_request_cooling() { selector.mark_retryable_failure("responses", "a"); - let order = selector.order_group(UpstreamStrategy::PriorityFillFirst, "responses", &items, 0); + let order = selector.order_group(UpstreamOrderStrategy::FillFirst, "responses", &items, 0); assert_eq!(order, vec![0, 1]); } @@ -80,7 +80,7 @@ fn extreme_retryable_failure_cooldown_does_not_panic() { let result = std::panic::catch_unwind(|| { selector.mark_retryable_failure("responses", "a"); - selector.order_group(UpstreamStrategy::PriorityFillFirst, "responses", &items, 0) + selector.order_group(UpstreamOrderStrategy::FillFirst, "responses", &items, 0) }); assert!(result.is_ok()); @@ -93,7 +93,7 @@ fn cooldown_distinguishes_runtime_items_with_same_logical_upstream_id() { selector.mark_retryable_failure("responses", "shared#1"); - let order = selector.order_group(UpstreamStrategy::PriorityFillFirst, "responses", &items, 0); + let order = selector.order_group(UpstreamOrderStrategy::FillFirst, "responses", &items, 0); assert_eq!(order, vec![1, 0]); } diff --git a/messages/en.json b/messages/en.json index 92f7f91..2791151 100644 --- a/messages/en.json +++ b/messages/en.json @@ -88,18 +88,23 @@ "auto_start_aria": "Enable launch at login", "auto_start_status_loading": "Loading autostart status...", "auto_start_status_error": "Failed to read autostart status: {message}", - 
"strategy_title": "Upstream Strategy", - "strategy_desc": "Choose how upstreams are selected globally.", - "strategy_label": "Strategy", - "strategy_placeholder": "Select strategy", - "strategy_help": "Priority round robin rotates within the highest priority group. Priority fill first uses the top entry until it fails.", - "upstream_strategy_priority_round_robin": "Priority Round Robin", - "upstream_strategy_priority_fill_first": "Priority Fill First", + "upstream_strategy_help": "Order controls candidate selection within the same priority group. Dispatch controls whether requests run serially, hedged after a delay, or race in parallel.", + "upstream_strategy_order_label": "Order", + "upstream_strategy_order_placeholder": "Select order", + "upstream_strategy_order_fill_first": "Fill First", + "upstream_strategy_order_round_robin": "Round Robin", + "upstream_strategy_dispatch_label": "Dispatch", + "upstream_strategy_dispatch_placeholder": "Select dispatch mode", + "upstream_strategy_dispatch_serial": "Serial", + "upstream_strategy_dispatch_hedged": "Hedged", + "upstream_strategy_dispatch_race": "Race", + "upstream_strategy_delay_ms_label": "Hedge Delay (ms)", + "upstream_strategy_max_parallel_label": "Max Parallel", "upstreams_title": "Upstreams", "upstreams_desc": "Define provider pools and credentials. Use provider names openai, openai-response, anthropic, gemini, kiro, codex, and antigravity.", "upstreams_add": "Add Upstream", "upstreams_empty": "No upstreams defined yet.", - "upstreams_tip": "Tip: Priority sorts upstreams in descending order. If priorities tie, the list order is used.", + "upstreams_tip": "Tip: Priority sorts upstreams in descending order. 
Order applies within the same priority group.", "upstreams_show_api_keys": "Show API Keys", "upstreams_hide_api_keys": "Hide API Keys", "upstreams_column_id": "Id", @@ -442,6 +447,8 @@ "error_upstream_kiro_account_required": "Upstream {id} requires a Kiro account.", "error_upstream_base_url_required": "Upstream {id} base URL is required.", "error_upstream_priority_integer": "Upstream {id} priority must be an integer.", + "error_upstream_strategy_delay_ms_positive_integer": "Hedge delay must be a positive integer.", + "error_upstream_strategy_max_parallel_min": "Max parallel must be an integer greater than or equal to {min}.", "error_upstream_multiple_api_keys_unsupported": "Upstream {id} does not support multiple API keys for account-based providers.", "error_model_mapping_pattern_required": "Upstream {id} mapping row {row} pattern is required.", "error_model_mapping_target_required": "Upstream {id} mapping row {row} target is required.", diff --git a/messages/zh.json b/messages/zh.json index a989555..f3d10df 100644 --- a/messages/zh.json +++ b/messages/zh.json @@ -88,18 +88,23 @@ "auto_start_aria": "启用开机启动", "auto_start_status_loading": "正在读取开机启动状态...", "auto_start_status_error": "读取开机启动状态失败:{message}", - "strategy_title": "上游策略", - "strategy_desc": "选择全局上游的选择方式。", - "strategy_label": "策略", - "strategy_placeholder": "选择策略", - "strategy_help": "Priority round robin 会在最高优先级组内轮询。Priority fill first 会优先使用最高项,直到失败为止。", - "upstream_strategy_priority_round_robin": "优先级轮询", - "upstream_strategy_priority_fill_first": "优先级填充优先", + "upstream_strategy_help": "Order 控制同一优先级组内的候选顺序;Dispatch 控制请求是串行、延迟 hedge,还是并行 race。", + "upstream_strategy_order_label": "顺序", + "upstream_strategy_order_placeholder": "选择顺序", + "upstream_strategy_order_fill_first": "优先填满", + "upstream_strategy_order_round_robin": "轮询", + "upstream_strategy_dispatch_label": "派发", + "upstream_strategy_dispatch_placeholder": "选择派发模式", + "upstream_strategy_dispatch_serial": "串行", + 
"upstream_strategy_dispatch_hedged": "Hedged", + "upstream_strategy_dispatch_race": "Race", + "upstream_strategy_delay_ms_label": "Hedge 延迟 (ms)", + "upstream_strategy_max_parallel_label": "最大并发数", "upstreams_title": "上游", "upstreams_desc": "定义提供商池与凭据。provider 名称使用 openai、openai-response、anthropic、gemini、kiro、codex、antigravity。", "upstreams_add": "添加上游", "upstreams_empty": "暂无上游。", - "upstreams_tip": "提示:Priority 按降序排序上游;同 Priority 时按列表顺序。", + "upstreams_tip": "提示:Priority 按降序排序上游;同一 Priority 组内再应用 Order。", "upstreams_show_api_keys": "显示 API Key", "upstreams_hide_api_keys": "隐藏 API Key", "upstreams_column_id": "ID", @@ -442,6 +447,8 @@ "error_upstream_provider_required": "上游 {id} 的 provider 为必填。", "error_upstream_base_url_required": "上游 {id} 的 base URL 为必填。", "error_upstream_priority_integer": "上游 {id} 的 priority 必须是整数。", + "error_upstream_strategy_delay_ms_positive_integer": "Hedge 延迟必须是正整数。", + "error_upstream_strategy_max_parallel_min": "最大并发数必须是不小于 {min} 的整数。", "error_upstream_multiple_api_keys_unsupported": "上游 {id} 在账号型 provider 上不支持配置多个 API key。", "error_model_mapping_pattern_required": "上游 {id} 第 {row} 行映射的匹配模式不能为空。", "error_model_mapping_target_required": "上游 {id} 第 {row} 行映射的目标模型不能为空。", diff --git a/src/features/config/cards/upstreams-card.tsx b/src/features/config/cards/upstreams-card.tsx index f0d8897..ff2db44 100644 --- a/src/features/config/cards/upstreams-card.tsx +++ b/src/features/config/cards/upstreams-card.tsx @@ -37,17 +37,17 @@ import { createEmptyUpstream } from "@/features/config/form"; import { useCodexAccounts } from "@/features/codex/use-codex-accounts"; import { useKiroAccounts } from "@/features/kiro/use-kiro-accounts"; import { useAntigravityAccounts } from "@/features/antigravity/use-antigravity-accounts"; -import type { UpstreamForm, UpstreamStrategy } from "@/features/config/types"; +import type { ConfigForm, UpstreamForm } from "@/features/config/types"; import { m } from "@/paraglide/messages.js"; type UpstreamsCardProps = { 
upstreams: UpstreamForm[]; appProxyUrl: string; - strategy: UpstreamStrategy; + strategy: ConfigForm["upstreamStrategy"]; showApiKeys: boolean; providerOptions: string[]; onToggleApiKeys: () => void; - onStrategyChange: (value: UpstreamStrategy) => void; + onStrategyChange: (value: ConfigForm["upstreamStrategy"]) => void; onAdd: (upstream: UpstreamForm) => void; onRemove: (index: number) => void; onChange: (index: number, patch: Partial) => void; diff --git a/src/features/config/cards/upstreams/table.tsx b/src/features/config/cards/upstreams/table.tsx index 7c24a96..86981cc 100644 --- a/src/features/config/cards/upstreams/table.tsx +++ b/src/features/config/cards/upstreams/table.tsx @@ -5,6 +5,7 @@ import { Ban, Check, Columns3, Copy, Eye, EyeOff, Pencil, Trash2 } from "lucide- import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { Select, @@ -24,27 +25,41 @@ import type { UpstreamColumnDefinition, UpstreamColumnId } from "@/features/conf import type { CodexAccountSummary } from "@/features/codex/types"; import type { KiroAccountSummary } from "@/features/kiro/types"; import type { AntigravityAccountSummary } from "@/features/antigravity/types"; -import { UPSTREAM_STRATEGIES, type UpstreamForm, type UpstreamStrategy } from "@/features/config/types"; +import { + UPSTREAM_DISPATCH_STRATEGIES, + UPSTREAM_ORDER_STRATEGIES, + type ConfigForm, + type UpstreamDispatchType, + type UpstreamForm, + type UpstreamOrderStrategy, +} from "@/features/config/types"; import { m } from "@/paraglide/messages.js"; type UpstreamsToolbarProps = { apiKeyVisible: boolean; showApiKeys: boolean; - strategy: UpstreamStrategy; + strategy: ConfigForm["upstreamStrategy"]; onToggleApiKeys: () => void; - onStrategyChange: (value: UpstreamStrategy) => void; + onStrategyChange: (value: ConfigForm["upstreamStrategy"]) => void; onAddClick: () => void; 
onColumnsClick: () => void; }; -const UPSTREAM_STRATEGY_VALUES: ReadonlySet = new Set( - UPSTREAM_STRATEGIES.map((strategy) => strategy.value) +const UPSTREAM_ORDER_VALUES: ReadonlySet = new Set( + UPSTREAM_ORDER_STRATEGIES.map((strategy) => strategy.value) +); +const UPSTREAM_DISPATCH_VALUES: ReadonlySet = new Set( + UPSTREAM_DISPATCH_STRATEGIES.map((strategy) => strategy.value) ); const CELL_PLACEHOLDER = "—"; const TOOLTIP_CONTENT_CLASS = "max-w-[560px] whitespace-pre-wrap break-words"; -function toUpstreamStrategy(value: string): UpstreamStrategy | null { - return UPSTREAM_STRATEGY_VALUES.has(value) ? (value as UpstreamStrategy) : null; +function toUpstreamOrderStrategy(value: string): UpstreamOrderStrategy | null { + return UPSTREAM_ORDER_VALUES.has(value) ? (value as UpstreamOrderStrategy) : null; +} + +function toUpstreamDispatchType(value: string): UpstreamDispatchType | null { + return UPSTREAM_DISPATCH_VALUES.has(value) ? (value as UpstreamDispatchType) : null; } type CellTooltipProps = { @@ -81,40 +96,62 @@ export function UpstreamsToolbar({ onAddClick, onColumnsClick, }: UpstreamsToolbarProps) { + const updateStrategy = (patch: Partial) => { + onStrategyChange({ + ...strategy, + ...patch, + }); + }; + const showsHedgeDelay = strategy.dispatchType === "hedged"; + const showsMaxParallel = strategy.dispatchType !== "serial"; + return ( -
-
- - -
-
-
-
+

{m.upstream_strategy_help()}

); } diff --git a/src/features/config/form.test.ts b/src/features/config/form.test.ts index fc13e31..478fb14 100644 --- a/src/features/config/form.test.ts +++ b/src/features/config/form.test.ts @@ -155,11 +155,22 @@ describe("config/form", () => { enabled: true, format: "split", }, - upstream_strategy: "priority_fill_first", + upstream_strategy: { + order: "fill_first", + dispatch: { + type: "serial", + }, + }, }); expect(form.upstreamNoDataTimeoutSecs).toBe("120"); expect(form.upstreams[0]?.apiKeys).toBe("key-a, key-b"); + expect(form.upstreamStrategy).toEqual({ + order: "fill_first", + dispatchType: "serial", + hedgeDelayMs: "2000", + maxParallel: "2", + }); }); it("serializes upstream no data timeout seconds", () => { @@ -171,6 +182,85 @@ describe("config/form", () => { expect(payload.upstream_no_data_timeout_secs).toBe(45); }); + it("serializes structured upstream strategy", () => { + const payload = toPayload({ + ...EMPTY_FORM, + upstreamStrategy: { + order: "round_robin", + dispatchType: "hedged", + hedgeDelayMs: "1500", + maxParallel: "3", + }, + }); + + expect(payload.upstream_strategy).toEqual({ + order: "round_robin", + dispatch: { + type: "hedged", + delay_ms: 1500, + max_parallel: 3, + }, + }); + }); + + it("validates hedged delay as positive integer", () => { + expect( + validate({ + ...EMPTY_FORM, + upstreamStrategy: { + ...EMPTY_FORM.upstreamStrategy, + dispatchType: "hedged", + hedgeDelayMs: "0", + }, + }).valid + ).toBe(false); + + expect( + validate({ + ...EMPTY_FORM, + upstreamStrategy: { + ...EMPTY_FORM.upstreamStrategy, + dispatchType: "hedged", + hedgeDelayMs: "1", + }, + }).valid + ).toBe(true); + }); + + it("validates race and hedged max parallel as integer >= 2", () => { + expect( + validate({ + ...EMPTY_FORM, + upstreamStrategy: { + ...EMPTY_FORM.upstreamStrategy, + dispatchType: "hedged", + maxParallel: "1", + }, + }).valid + ).toBe(false); + + expect( + validate({ + ...EMPTY_FORM, + upstreamStrategy: { + ...EMPTY_FORM.upstreamStrategy, + 
dispatchType: "race", + maxParallel: "1", + }, + }).valid + ).toBe(false); + + expect( + validate({ + ...EMPTY_FORM, + upstreamStrategy: { + ...EMPTY_FORM.upstreamStrategy, + dispatchType: "race", + maxParallel: "2", + }, + }).valid + ).toBe(true); + }); it("serializes openai compatibility upstream flags", () => { const upstream = createEmptyUpstream(); upstream.id = "glm-coding-plan"; diff --git a/src/features/config/form.ts b/src/features/config/form.ts index a918325..7b051d7 100644 --- a/src/features/config/form.ts +++ b/src/features/config/form.ts @@ -6,7 +6,9 @@ import { type ProxyConfigFile, type ProxyConfigFileBase, type TrayTokenRateConfig, + type UpstreamDispatchStrategy, type UpstreamForm, + type UpstreamStrategy, TRAY_TOKEN_RATE_FORMATS, } from "@/features/config/types"; import { createNativeInboundFormatSet, removeInboundFormatsInSet } from "@/features/config/inbound-formats"; @@ -19,8 +21,12 @@ const DEFAULT_TRAY_TOKEN_RATE: TrayTokenRateConfig = { const MIN_UPSTREAM_NO_DATA_TIMEOUT_SECS = 3; const DEFAULT_UPSTREAM_NO_DATA_TIMEOUT_SECS = 120; +const DEFAULT_HEDGE_DELAY_MS = 2000; +const DEFAULT_MAX_PARALLEL = 2; +const MIN_PARALLEL_ATTEMPTS = 2; const INTEGER_PATTERN = /^-?\d+$/; const NON_NEGATIVE_INTEGER_PATTERN = /^\d+$/; +const POSITIVE_INTEGER_PATTERN = /^[1-9]\d*$/; let modelMappingCounter = 0; const TRAY_TOKEN_RATE_FORMAT_VALUES: ReadonlySet = new Set( @@ -96,7 +102,12 @@ export const EMPTY_FORM: ConfigForm = { retryableFailureCooldownSecs: "15", upstreamNoDataTimeoutSecs: String(DEFAULT_UPSTREAM_NO_DATA_TIMEOUT_SECS), trayTokenRate: { ...DEFAULT_TRAY_TOKEN_RATE }, - upstreamStrategy: "priority_fill_first", + upstreamStrategy: { + order: "fill_first", + dispatchType: "serial", + hedgeDelayMs: String(DEFAULT_HEDGE_DELAY_MS), + maxParallel: String(DEFAULT_MAX_PARALLEL), + }, upstreams: [], }; @@ -171,7 +182,7 @@ export function toForm(config: ProxyConfigFile): ConfigForm { config.upstream_no_data_timeout_secs ?? 
DEFAULT_UPSTREAM_NO_DATA_TIMEOUT_SECS, ), trayTokenRate: normalizeTrayTokenRate(config.tray_token_rate), - upstreamStrategy: config.upstream_strategy, + upstreamStrategy: toUpstreamStrategyForm(config.upstream_strategy), upstreams: config.upstreams.map((upstream) => ({ id: upstream.id, providers: upstream.providers ?? [], @@ -219,7 +230,7 @@ export function toPayload(form: ConfigForm): ProxyConfigFile { form.upstreamNoDataTimeoutSecs, ), tray_token_rate: form.trayTokenRate, - upstream_strategy: form.upstreamStrategy, + upstream_strategy: toUpstreamStrategyPayload(form.upstreamStrategy), upstreams: form.upstreams.map((upstream) => { const providers = normalizeProviders(upstream.providers); const apiKeys = parseApiKeysInput(upstream.apiKeys); @@ -276,6 +287,10 @@ export function validate(form: ConfigForm) { message: m.error_upstream_no_data_timeout_secs_integer(), }; } + const upstreamStrategyError = validateUpstreamStrategy(form.upstreamStrategy); + if (upstreamStrategyError) { + return { valid: false, message: upstreamStrategyError }; + } const ids = new Set(); for (const upstream of form.upstreams) { @@ -430,6 +445,76 @@ function normalizeTrayTokenRate(value: TrayTokenRateConfig) { return value; } +function toUpstreamStrategyForm(strategy: UpstreamStrategy): ConfigForm["upstreamStrategy"] { + switch (strategy.dispatch.type) { + case "serial": + return { + order: strategy.order, + dispatchType: "serial", + hedgeDelayMs: String(DEFAULT_HEDGE_DELAY_MS), + maxParallel: String(DEFAULT_MAX_PARALLEL), + }; + case "hedged": + return { + order: strategy.order, + dispatchType: "hedged", + hedgeDelayMs: String(strategy.dispatch.delay_ms), + maxParallel: String(strategy.dispatch.max_parallel), + }; + case "race": + return { + order: strategy.order, + dispatchType: "race", + hedgeDelayMs: String(DEFAULT_HEDGE_DELAY_MS), + maxParallel: String(strategy.dispatch.max_parallel), + }; + } +} + +function toUpstreamStrategyPayload( + strategy: ConfigForm["upstreamStrategy"], +): 
UpstreamStrategy { + return { + order: strategy.order, + dispatch: toUpstreamDispatchPayload(strategy), + }; +} + +function toUpstreamDispatchPayload( + strategy: ConfigForm["upstreamStrategy"], +): UpstreamDispatchStrategy { + switch (strategy.dispatchType) { + case "serial": + return { type: "serial" }; + case "hedged": + return { + type: "hedged", + delay_ms: parsePositiveInteger(strategy.hedgeDelayMs, DEFAULT_HEDGE_DELAY_MS), + max_parallel: parseMinParallel(strategy.maxParallel, DEFAULT_MAX_PARALLEL), + }; + case "race": + return { + type: "race", + max_parallel: parseMinParallel(strategy.maxParallel, DEFAULT_MAX_PARALLEL), + }; + } +} + +function validateUpstreamStrategy(strategy: ConfigForm["upstreamStrategy"]) { + if (strategy.dispatchType === "serial") { + return ""; + } + if (strategy.dispatchType === "hedged" && !isPositiveInteger(strategy.hedgeDelayMs)) { + return m.error_upstream_strategy_delay_ms_positive_integer(); + } + if (!isValidMinParallel(strategy.maxParallel)) { + return m.error_upstream_strategy_max_parallel_min({ + min: String(MIN_PARALLEL_ATTEMPTS), + }); + } + return ""; +} + function normalizeProviders(values: readonly string[]) { const seen = new Set(); const output: string[] = []; @@ -603,3 +688,26 @@ function parseUpstreamNoDataTimeoutSecs(value: string) { const number = Number.parseInt(trimmed, 10); return Number.isFinite(number) ? number : DEFAULT_UPSTREAM_NO_DATA_TIMEOUT_SECS; } + +function isPositiveInteger(value: string) { + return POSITIVE_INTEGER_PATTERN.test(value.trim()); +} + +function parsePositiveInteger(value: string, fallback: number) { + const trimmed = value.trim(); + if (!POSITIVE_INTEGER_PATTERN.test(trimmed)) { + return fallback; + } + const number = Number.parseInt(trimmed, 10); + return Number.isFinite(number) ? 
number : fallback; +} + +function isValidMinParallel(value: string) { + const parsed = parsePositiveInteger(value, 0); + return parsed >= MIN_PARALLEL_ATTEMPTS; +} + +function parseMinParallel(value: string, fallback: number) { + const parsed = parsePositiveInteger(value, fallback); + return parsed >= MIN_PARALLEL_ATTEMPTS ? parsed : fallback; +} diff --git a/src/features/config/types.ts b/src/features/config/types.ts index cba2402..6274b9f 100644 --- a/src/features/config/types.ts +++ b/src/features/config/types.ts @@ -1,11 +1,29 @@ import { m } from "@/paraglide/messages.js"; -export const UPSTREAM_STRATEGIES = [ - { value: "priority_fill_first", label: () => m.upstream_strategy_priority_fill_first() }, - { value: "priority_round_robin", label: () => m.upstream_strategy_priority_round_robin() }, +export const UPSTREAM_ORDER_STRATEGIES = [ + { value: "fill_first", label: () => m.upstream_strategy_order_fill_first() }, + { value: "round_robin", label: () => m.upstream_strategy_order_round_robin() }, ] as const; -export type UpstreamStrategy = (typeof UPSTREAM_STRATEGIES)[number]["value"]; +export type UpstreamOrderStrategy = (typeof UPSTREAM_ORDER_STRATEGIES)[number]["value"]; + +export const UPSTREAM_DISPATCH_STRATEGIES = [ + { value: "serial", label: () => m.upstream_strategy_dispatch_serial() }, + { value: "hedged", label: () => m.upstream_strategy_dispatch_hedged() }, + { value: "race", label: () => m.upstream_strategy_dispatch_race() }, +] as const; + +export type UpstreamDispatchType = (typeof UPSTREAM_DISPATCH_STRATEGIES)[number]["value"]; + +export type UpstreamDispatchStrategy = + | { type: "serial" } + | { type: "hedged"; delay_ms: number; max_parallel: number } + | { type: "race"; max_parallel: number }; + +export type UpstreamStrategy = { + order: UpstreamOrderStrategy; + dispatch: UpstreamDispatchStrategy; +}; export const TRAY_TOKEN_RATE_FORMATS = [ { value: "combined", label: () => m.proxy_core_tray_token_rate_format_combined() }, @@ -174,6 +192,11 
@@ export type ConfigForm = { retryableFailureCooldownSecs: string; upstreamNoDataTimeoutSecs: string; trayTokenRate: TrayTokenRateConfig; - upstreamStrategy: UpstreamStrategy; + upstreamStrategy: { + order: UpstreamOrderStrategy; + dispatchType: UpstreamDispatchType; + hedgeDelayMs: string; + maxParallel: string; + }; upstreams: UpstreamForm[]; };