diff --git a/Cargo.lock b/Cargo.lock index 308dd56..457f866 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -73,12 +73,41 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", +] + [[package]] name = "base64" version = "0.22.1" @@ -293,6 +322,24 @@ dependencies = [ "litrs", ] +[[package]] +name = "edgee-ai-gateway-core" +version = "0.1.0" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures", + "http", + "reqwest 0.13.2", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tower", + "tracing", +] + [[package]] name = "edgee-cli" version = "0.2.1" @@ -303,7 +350,7 @@ dependencies = [ "console 0.15.11", "dialoguer", "open", - "reqwest", + "reqwest 0.13.2", "self_update", "serde", "serde_json", @@ -313,6 +360,15 @@ dependencies = [ "uuid", ] +[[package]] +name = "edgee-compression-layer" +version = "0.1.0" +dependencies = [ + "edgee-ai-gateway-core", + "edgee-compressor", + "tower", +] + [[package]] name = "edgee-compressor" version = "0.1.0" @@ -398,6 +454,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.32" @@ -414,12 +485,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" +[[package]] +name = "futures-executor" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-io" version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.32" @@ -438,8 +531,10 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ + "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -854,6 +949,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.29" @@ -872,6 +976,12 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -922,6 +1032,29 @@ dependencies = [ "pathdiff", ] +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "pathdiff" version = "0.2.3" @@ -1056,7 +1189,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] @@ -1109,6 +1242,15 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.12.3" @@ -1179,6 +1321,40 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "reqwest" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801" +dependencies = [ + "base64", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "sync_wrapper", + "tokio", + "tokio-util", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", +] + [[package]] name = "ring" version = "0.17.14" @@ -1260,6 +1436,12 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "self-replace" version = "1.5.0" @@ -1282,7 +1464,7 @@ dependencies = [ "log", "quick-xml", "regex", - "reqwest", + "reqwest 0.12.28", "self-replace", "semver", "serde", @@ -1590,6 +1772,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", @@ -1985,6 +2168,19 @@ dependencies = [ "wasmparser", ] +[[package]] +name = "wasm-streams" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1ec4f6517c9e11ae630e200b2b65d193279042e28edd4a2cda233e46670bbb" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wasmparser" version = "0.244.0" diff --git a/Cargo.toml b/Cargo.toml index 0b29cea..53e2f50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,19 +1,28 @@ [workspace] -members = ["crates/cli", "crates/compressor"] -resolver = "2" +members = ["crates/*"] +resolver = "3" [workspace.dependencies] anyhow = "1" +async-trait = "0.1" +axum-core = "0.5" +bytes = "1" clap = "4" colored = "2" console = "0.15" dialoguer = "0.11" +futures = "0.3" +http = "1" +http-body-util = "0.1" open = "5" +reqwest = { version = "0.13", default-features = false } +self_update = { version = "0.43.1", default-features = false } serde = "1" -toml = "0.8" -uuid = "1" -tokio = "1" -reqwest = { version = "0.12", default-features = false } serde_json = "1" -self_update = { version = "0.43.1", default-features = false } +thiserror = "2" time = { version = "0.3" } +tokio = "1" +toml = "0.8" +tower = { version = "0.5", features = ["util"] } +tracing = "0.1" +uuid = "1" diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 2af9d8d..ede79d2 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -18,7 +18,7 @@ serde = { workspace = true, features = ["derive"] } toml.workspace = true uuid = { workspace = true, features = ["v4"] } tokio = { workspace = true, features = ["rt-multi-thread", "macros", "net", "signal", "io-util"] } -reqwest = { workspace = true, features = ["json", "rustls-tls"] } +reqwest = { workspace = true, features = ["json"] } serde_json.workspace = true self_update = { workspace = true, optional = true, features = ["reqwest", "rustls"] } time = { workspace = true, features = ["formatting", "parsing", "serde"] } diff --git a/crates/compression-layer/Cargo.toml b/crates/compression-layer/Cargo.toml new file mode 100644 index 0000000..f94f5d9 --- /dev/null +++ b/crates/compression-layer/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "edgee-compression-layer" +version = "0.1.0" +edition = "2024" +description = "Tower Layer that compresses LLM tool-result content before provider dispatch" + +[dependencies] +edgee-ai-gateway-core = { path = "../gateway-core" } +edgee-compressor = { path = "../compressor" } +tower = { workspace = true } diff --git a/crates/compression-layer/src/compress.rs b/crates/compression-layer/src/compress.rs new file mode 100644 index 0000000..22773dc --- /dev/null +++ b/crates/compression-layer/src/compress.rs @@ -0,0 +1,167 @@ +use std::collections::HashMap; + +use edgee_ai_gateway_core::{ + CompletionRequest, + types::{Message, MessageContent}, +}; + +use crate::config::{AgentType, CompressionConfig}; + +/// Walk `req.messages`, compressing tool-result content in-place. +/// +/// Two sweeps: +/// 1. Build `tool_call_id → (name, arguments)` from every AssistantMessage. +/// 2. For each ToolMessage, look up the tool name + arguments, compress the +/// content, and replace it if the compressor produced a shorter result. +pub fn compress_request( + config: &CompressionConfig, + mut req: CompletionRequest, +) -> CompletionRequest { + // Sweep 1 — index tool calls by id + let mut call_index: HashMap = HashMap::new(); + for msg in &req.messages { + if let Message::Assistant(a) = msg + && let Some(calls) = &a.tool_calls + { + for call in calls { + call_index.insert( + call.id.clone(), + (call.function.name.clone(), call.function.arguments.clone()), + ); + } + } + } + + // Sweep 2 — compress ToolMessage content + for msg in &mut req.messages { + if let Message::Tool(tool_msg) = msg { + let Some((name, arguments)) = call_index.get(&tool_msg.tool_call_id) else { + continue; + }; + + let compressor = match config.agent { + AgentType::Claude => edgee_compressor::claude_compressor_for(name), + AgentType::Codex => edgee_compressor::codex_compressor_for(name), + AgentType::OpenCode => edgee_compressor::opencode_compressor_for(name), + }; + + let Some(compressor) = compressor else { + continue; + }; + + let text = tool_msg.content.as_text(); + if let Some(compressed) = edgee_compressor::compress_claude_tool_with_segment_protection( + compressor, arguments, &text, + ) { + tool_msg.content = MessageContent::Text(compressed); + } + } + } + + req +} + +#[cfg(test)] +mod tests { + use edgee_ai_gateway_core::{ + CompletionRequest, + types::{ + AssistantMessage, FunctionCall, Message, MessageContent, ToolCall, ToolMessage, + UserMessage, + }, + }; + + use crate::config::{AgentType, CompressionConfig}; + + use super::compress_request; + + fn glob_output(n: usize) -> String { + // Produce `n` fake file paths spread across a few directories so the + // Glob compressor can actually group them (threshold: >30 paths). + let dirs = ["src/alpha", "src/beta", "src/gamma", "src/delta"]; + (0..n) + .map(|i| format!("{}/file_{i}.rs", dirs[i % dirs.len()])) + .collect::>() + .join("\n") + } + + #[test] + fn compresses_glob_tool_result() { + let output = glob_output(50); + let original_len = output.len(); + + let req = CompletionRequest::new( + "claude-3-5-sonnet".to_string(), + vec![ + Message::User(UserMessage { + name: None, + content: MessageContent::Text("list files".into()), + cache_control: None, + }), + Message::Assistant(AssistantMessage { + name: None, + content: None, + refusal: None, + cache_control: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".into(), + tool_type: "function".into(), + function: FunctionCall { + name: "Glob".into(), + arguments: r#"{"pattern":"**/*.rs"}"#.into(), + }, + }]), + }), + Message::Tool(ToolMessage { + tool_call_id: "call_1".into(), + content: MessageContent::Text(output), + }), + ], + ); + + let config = CompressionConfig { + agent: AgentType::Claude, + }; + let compressed = compress_request(&config, req); + + let tool_msg = compressed.messages.iter().find_map(|m| { + if let Message::Tool(t) = m { + Some(t) + } else { + None + } + }); + + let compressed_len = tool_msg.unwrap().content.as_text().len(); + assert!( + compressed_len < original_len, + "expected compression: {compressed_len} < {original_len}" + ); + } + + #[test] + fn skips_unknown_tool_call_id() { + let req = CompletionRequest::new( + "claude-3-5-sonnet".to_string(), + vec![Message::Tool(ToolMessage { + tool_call_id: "orphan".into(), + content: MessageContent::Text("some output".into()), + })], + ); + + let config = CompressionConfig { + agent: AgentType::Claude, + }; + let result = compress_request(&config, req); + + // Content should be unchanged + let tool_msg = result.messages.iter().find_map(|m| { + if let Message::Tool(t) = m { + Some(t) + } else { + None + } + }); + assert_eq!(tool_msg.unwrap().content.as_text(), "some output"); + } +} diff --git a/crates/compression-layer/src/config.rs b/crates/compression-layer/src/config.rs new file mode 100644 index 0000000..d78588b --- /dev/null +++ b/crates/compression-layer/src/config.rs @@ -0,0 +1,24 @@ +use std::sync::Arc; + +/// Which agent's tool-name conventions to use when dispatching compressors. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AgentType { + /// Claude Code — tool names: `Bash`, `Read`, `Grep`, `Glob` + Claude, + /// Codex CLI — tool names: `shell_command`, `read_file`, `grep`, `list_directory` + Codex, + /// OpenCode — tool names: `bash`, `read`, `grep`, `glob` + OpenCode, +} + +/// Configuration for the compression layer. +#[derive(Debug, Clone)] +pub struct CompressionConfig { + pub agent: AgentType, +} + +impl CompressionConfig { + pub fn new(agent: AgentType) -> Arc { + Arc::new(Self { agent }) + } +} diff --git a/crates/compression-layer/src/layer.rs b/crates/compression-layer/src/layer.rs new file mode 100644 index 0000000..870b6fb --- /dev/null +++ b/crates/compression-layer/src/layer.rs @@ -0,0 +1,34 @@ +use std::sync::Arc; + +use crate::{config::CompressionConfig, service::CompressionService}; + +/// Tower [`Layer`] that wraps a downstream service with tool-result compression. +/// +/// Construct via [`CompressionLayer::new`], then compose with Tower's +/// [`ServiceBuilder`](tower::ServiceBuilder): +/// +/// ```rust,ignore +/// let svc = ServiceBuilder::new() +/// .layer(CompressionLayer::new(CompressionConfig::new(AgentType::Claude))) +/// .service(dispatch_service); +/// ``` +#[derive(Clone)] +pub struct CompressionLayer { + config: Arc, +} + +impl CompressionLayer { + pub fn new(config: impl Into>) -> Self { + Self { + config: config.into(), + } + } +} + +impl tower::Layer for CompressionLayer { + type Service = CompressionService; + + fn layer(&self, inner: S) -> Self::Service { + CompressionService::new(inner, Arc::clone(&self.config)) + } +} diff --git a/crates/compression-layer/src/lib.rs b/crates/compression-layer/src/lib.rs new file mode 100644 index 0000000..7757524 --- /dev/null +++ b/crates/compression-layer/src/lib.rs @@ -0,0 +1,8 @@ +pub mod compress; +pub mod config; +pub mod layer; +pub mod service; + +pub use config::{AgentType, CompressionConfig}; +pub use layer::CompressionLayer; +pub use service::CompressionService; diff --git a/crates/compression-layer/src/service.rs b/crates/compression-layer/src/service.rs new file mode 100644 index 0000000..4420570 --- /dev/null +++ b/crates/compression-layer/src/service.rs @@ -0,0 +1,43 @@ +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use edgee_ai_gateway_core::CompletionRequest; +use tower::Service; + +use crate::{compress::compress_request, config::CompressionConfig}; + +/// Tower [`Service`] produced by [`CompressionLayer`](crate::CompressionLayer). +/// +/// Intercepts each [`CompletionRequest`], compresses tool-result content +/// in-place, then delegates to the wrapped inner service. +pub struct CompressionService { + inner: S, + config: Arc, +} + +impl CompressionService { + pub(crate) fn new(inner: S, config: Arc) -> Self { + Self { inner, config } + } +} + +impl Service for CompressionService +where + S: Service, +{ + type Response = S::Response; + type Error = S::Error; + // Compression is synchronous — no extra future wrapping needed. + type Future = S::Future; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: CompletionRequest) -> Self::Future { + let compressed = compress_request(&self.config, req); + self.inner.call(compressed) + } +} diff --git a/crates/gateway-core/Cargo.toml b/crates/gateway-core/Cargo.toml new file mode 100644 index 0000000..d47bcb5 --- /dev/null +++ b/crates/gateway-core/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "edgee-ai-gateway-core" +version = "0.1.0" +edition = "2024" +description = "Core LLM request→response pipeline for the Edgee AI Gateway" + +[features] +default = [] + +tokio = ["dep:reqwest", "dep:tokio"] + +[dependencies] +tower.workspace = true +futures.workspace = true +async-trait.workspace = true +thiserror.workspace = true +tracing.workspace = true +http.workspace = true +axum-core.workspace = true +bytes.workspace = true +serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true + +# Runtime-specific — only compiled with the `tokio` feature +reqwest = { workspace = true, default-features = false, features = ["stream"], optional = true } +tokio = { workspace = true, features = ["rt"], optional = true } + +[dev-dependencies] +tokio = { workspace = true, features = ["full"] } diff --git a/crates/gateway-core/src/backend/http.rs b/crates/gateway-core/src/backend/http.rs new file mode 100644 index 0000000..492dadd --- /dev/null +++ b/crates/gateway-core/src/backend/http.rs @@ -0,0 +1,55 @@ +use axum_core::body::Body; +use http::{Request, Response}; + +use crate::error::{Error, Result}; + +/// Abstract HTTP transport. +/// +/// Implementations exist for: +/// - [`ReqwestHttpClient`] (tokio feature, local/AWS backends) +/// - Platform-specific clients (e.g. Fastly backend — no tokio/reqwest required) +/// +/// Callers inject a concrete implementation at construction time via +/// [`crate::service::ProviderDispatchService::new`] or the passthrough services. +/// The core crate itself never depends on a specific runtime. +#[async_trait::async_trait] +pub trait HttpClient: Send + Sync { + async fn send(&self, req: Request) -> Result>; +} + +/// A [`HttpClient`] backed by [`reqwest`]. +/// +/// Only available when the `tokio` feature is enabled. Use this in local +/// development and on platforms that support the tokio async runtime (e.g. AWS). +/// +/// For Fastly Compute@Edge (`wasm32-wasip1`), provide your own [`HttpClient`] +/// implementation using the Fastly SDK instead. +#[cfg(feature = "tokio")] +pub struct ReqwestHttpClient(reqwest::Client); + +#[cfg(feature = "tokio")] +impl ReqwestHttpClient { + pub fn new(client: reqwest::Client) -> Self { + Self(client) + } +} + +#[cfg(feature = "tokio")] +#[async_trait::async_trait] +impl HttpClient for ReqwestHttpClient { + async fn send(&self, req: Request) -> Result> { + let req: reqwest::Request = req + .map(|body| reqwest::Body::wrap_stream(body.into_data_stream())) + .try_into() + .map_err(|e| Error::HttpClient(format!("Failed to convert request: {e}")))?; + + let resp = self + .0 + .execute(req) + .await + .map_err(|e| Error::HttpClient(format!("HTTP request failed: {e}")))?; + let resp = Response::from(resp); + + Ok(resp.map(Body::new)) + } +} diff --git a/crates/gateway-core/src/backend/mod.rs b/crates/gateway-core/src/backend/mod.rs new file mode 100644 index 0000000..3883215 --- /dev/null +++ b/crates/gateway-core/src/backend/mod.rs @@ -0,0 +1 @@ +pub mod http; diff --git a/crates/gateway-core/src/config.rs b/crates/gateway-core/src/config.rs new file mode 100644 index 0000000..51af301 --- /dev/null +++ b/crates/gateway-core/src/config.rs @@ -0,0 +1,26 @@ +/// Runtime configuration for a single LLM provider. +/// +/// API keys are provided by the caller at construction time; the core crate +/// never resolves, rotates, or validates credentials. +#[derive(Debug, Clone)] +pub struct ProviderConfig { + /// Provider API key (e.g. Anthropic `x-api-key` or OpenAI `Bearer` token). + pub api_key: String, + /// Override the provider's default base URL (e.g. for proxies or local stubs). + /// `None` means use the provider's production endpoint. + pub base_url: Option, +} + +impl ProviderConfig { + pub fn new(api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + base_url: None, + } + } + + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = Some(base_url.into()); + self + } +} diff --git a/crates/gateway-core/src/error.rs b/crates/gateway-core/src/error.rs new file mode 100644 index 0000000..4692f55 --- /dev/null +++ b/crates/gateway-core/src/error.rs @@ -0,0 +1,29 @@ +/// Crate-level error type. +/// +/// Each variant carries enough semantic information to determine its HTTP status +/// mapping, observability category, and whether it is retryable — without +/// inspecting the message string. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// The underlying HTTP client failed to send the request or receive the response. + #[error("http client error: {0}")] + HttpClient(String), + + /// The provider returned a non-2xx status code. + #[error("provider error: status={status}, body={body}")] + ProviderError { status: u16, body: String }, + + /// An error occurred while reading or parsing a streaming response. + #[error("stream error: {0}")] + Stream(String), + + /// JSON serialization or deserialization failed. + #[error("serialization error: {0}")] + Json(#[from] serde_json::Error), + + /// Building the outbound HTTP request failed (e.g. invalid URI or header value). + #[error("request build error: {0}")] + RequestBuild(String), +} + +pub type Result = std::result::Result; diff --git a/crates/gateway-core/src/lib.rs b/crates/gateway-core/src/lib.rs new file mode 100644 index 0000000..70118e8 --- /dev/null +++ b/crates/gateway-core/src/lib.rs @@ -0,0 +1,68 @@ +//! Core LLM request→response pipeline for the Edgee AI Gateway. +//! +//! # Architecture +//! +//! The pipeline is modelled as a Tower [`Service`] chain. This crate defines the +//! innermost service ([`service::ProviderDispatchService`]) and the foundational +//! types/traits that all other gateway crates depend on. +//! +//! ```text +//! CompletionRequest +//! │ +//! v +//! ┌──────────────────────┐ +//! │ [User layers] │ ← Any tower::Layer (compression, logging, …) +//! └──────┬───────────────┘ +//! │ +//! v +//! ┌──────────────────────┐ +//! │ ProviderDispatch │ ← Service +//! │ Service │ +//! └──────────────────────┘ +//! │ +//! v +//! GatewayResponse +//! ``` +//! +//! # Passthrough +//! +//! Two additional Tower services handle the passthrough path, where requests +//! arrive in provider-native format and are forwarded without translation: +//! +//! - [`passthrough::anthropic::AnthropicPassthroughService`] — `POST /v1/messages` +//! - [`passthrough::openai::OpenAIPassthroughService`] — `POST /v1/responses` +//! +//! # Platform compatibility +//! +//! This crate has **no hard dependency on tokio or reqwest**. Enable the `tokio` +//! feature to get a concrete [`backend::http::ReqwestHttpClient`] backed by reqwest. +//! On other platforms (e.g. Fastly `wasm32-wasip1`), provide your own +//! [`backend::http::HttpClient`] implementation. +//! +//! [`Service`]: tower::Service + +pub mod backend; +pub mod config; +pub mod error; +pub mod passthrough; +pub mod provider; +pub mod service; +pub mod types; + +// Flat re-exports for convenience +pub use backend::http::HttpClient; +#[cfg(feature = "tokio")] +pub use backend::http::ReqwestHttpClient; +pub use config::ProviderConfig; +pub use error::{Error, Result}; +pub use provider::Provider; +pub use service::ProviderDispatchService; +pub use types::{ + CompletionChunk, CompletionRequest, CompletionResponse, GatewayResponse, Message, + PassthroughRequest, Usage, +}; + +// ── Test utilities (compiled only for tests) ───────────────────────────── + +#[cfg(test)] +pub(crate) mod testing; diff --git a/crates/gateway-core/src/passthrough/anthropic.rs b/crates/gateway-core/src/passthrough/anthropic.rs new file mode 100644 index 0000000..c311c14 --- /dev/null +++ b/crates/gateway-core/src/passthrough/anthropic.rs @@ -0,0 +1,114 @@ +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use axum_core::body::Body; +use futures::future::BoxFuture; +use http::{Request, Response}; +use tower::Service; + +use crate::{ + PassthroughRequest, + backend::http::HttpClient, + config::ProviderConfig, + error::{Error, Result}, +}; + +/// Default Anthropic Messages API endpoint. +const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; + +/// Passthrough Tower service for the Anthropic Messages API. +/// +/// Forwards `POST /v1/messages` requests to Anthropic in their **native format** +/// without any translation. Headers supplied in the [`PassthroughRequest`] are +/// forwarded as-is (gateway-internal headers must already be stripped by the +/// caller — see [`crate::passthrough::SKIP_HEADERS`]). +/// +/// This is one of the two "distinct Tower `Service` implementations" for +/// passthrough described in the spec (§6 Milestone 1). +pub struct AnthropicPassthroughService { + client: Arc, + config: ProviderConfig, +} + +impl AnthropicPassthroughService { + pub fn new(client: Arc, config: ProviderConfig) -> Self { + Self { client, config } + } + + fn target_uri(&self) -> String { + let base = self.config.base_url.as_deref().unwrap_or(DEFAULT_BASE_URL); + format!("{base}/v1/messages") + } +} + +impl Service for AnthropicPassthroughService { + type Response = Response; + type Error = Error; + type Future = BoxFuture<'static, Result>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: PassthroughRequest) -> Self::Future { + let client = self.client.clone(); + let uri = self.target_uri(); + + Box::pin(async move { + let mut builder = Request::builder().method(http::Method::POST).uri(&uri); + + for (key, value) in &req.headers { + builder = builder.header(key.as_str(), value.as_str()); + } + + let forwarded = builder + .body(Body::from(req.body)) + .map_err(|e| Error::RequestBuild(e.to_string()))?; + + client.send(forwarded).await + }) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::*; + use crate::config::ProviderConfig; + + #[test] + fn target_uri_default() { + let svc = AnthropicPassthroughService::new( + Arc::new(crate::testing::StubClient), + ProviderConfig::new("key"), + ); + assert_eq!(svc.target_uri(), "https://api.anthropic.com/v1/messages"); + } + + #[test] + fn target_uri_custom_base_url() { + let svc = AnthropicPassthroughService::new( + Arc::new(crate::testing::StubClient), + ProviderConfig::new("key").with_base_url("http://localhost:8080"), + ); + assert_eq!(svc.target_uri(), "http://localhost:8080/v1/messages"); + } + + #[test] + fn strips_skipped_headers() { + let req = PassthroughRequest::new( + Bytes::from("{}"), + vec![ + ("content-type".into(), "application/json".into()), + // x-edgee-api-key should have been stripped by the caller; + // here we verify the service forwards what it receives as-is. + ("x-api-key".into(), "sk-ant-test".into()), + ], + ); + // The service itself does not filter — it trusts the caller. + assert_eq!(req.headers.len(), 2); + } +} diff --git a/crates/gateway-core/src/passthrough/mod.rs b/crates/gateway-core/src/passthrough/mod.rs new file mode 100644 index 0000000..9d54a91 --- /dev/null +++ b/crates/gateway-core/src/passthrough/mod.rs @@ -0,0 +1,19 @@ +pub mod anthropic; +pub mod openai; + +/// HTTP headers stripped from all outbound passthrough requests. +/// +/// These are either hop-by-hop headers that must not be forwarded, or +/// gateway-internal headers that must not leak to providers. +/// +/// The HTTP boundary layer above `gateway-core` should apply this list when +/// constructing a [`crate::PassthroughRequest`] from an incoming HTTP request. +pub const SKIP_HEADERS: &[&str] = &[ + "host", + "content-length", + "transfer-encoding", + "accept-encoding", + "connection", + // Gateway-internal auth / control headers + "x-edgee-api-key", +]; diff --git a/crates/gateway-core/src/passthrough/openai.rs b/crates/gateway-core/src/passthrough/openai.rs new file mode 100644 index 0000000..2d2db60 --- /dev/null +++ b/crates/gateway-core/src/passthrough/openai.rs @@ -0,0 +1,131 @@ +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use axum_core::body::Body; +use futures::future::BoxFuture; +use http::{Request, Response}; +use tower::Service; + +use crate::{ + PassthroughRequest, + backend::http::HttpClient, + config::ProviderConfig, + error::{Error, Result}, +}; + +/// OpenAI Responses API endpoint for requests authenticated with a project key +/// (`sk-proj-…`). These keys belong to the OpenAI Platform API. +const OPENAI_API_RESPONSES_URL: &str = "https://api.openai.com/v1/responses"; + +/// Default Responses API endpoint (ChatGPT backend, used by Codex CLI without +/// a project key). +const OPENAI_CHATGPT_RESPONSES_URL: &str = "https://chatgpt.com/backend-api/codex/responses"; + +/// Passthrough Tower service for the OpenAI Responses API. +/// +/// Forwards `POST /v1/responses` requests to OpenAI in their **native format** +/// without any translation. Headers supplied in the [`PassthroughRequest`] are +/// forwarded as-is (gateway-internal headers must already be stripped by the +/// caller — see [`crate::passthrough::SKIP_HEADERS`]). +/// +/// Endpoint selection (when `ProviderConfig::base_url` is `None`): +/// - `authorization: Bearer sk-proj-…` → `api.openai.com` (Platform API key) +/// - anything else → `chatgpt.com` backend (Codex CLI default) +/// +/// This is one of the two "distinct Tower `Service` implementations" for +/// passthrough described in the spec (§6 Milestone 1). +pub struct OpenAIPassthroughService { + client: Arc, + config: ProviderConfig, +} + +impl OpenAIPassthroughService { + pub fn new(client: Arc, config: ProviderConfig) -> Self { + Self { client, config } + } + + fn target_uri(&self, headers: &[(String, String)]) -> String { + // If an explicit override is set, use it directly. + if let Some(base) = &self.config.base_url { + return format!("{base}/v1/responses"); + } + + // Otherwise select by key prefix (matching reference gateway behaviour). + let is_proj_key = headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("authorization")) + .map(|(_, v)| v.starts_with("sk-proj-") || v.starts_with("Bearer sk-proj-")) + .unwrap_or(false); + + if is_proj_key { + OPENAI_API_RESPONSES_URL.to_owned() + } else { + OPENAI_CHATGPT_RESPONSES_URL.to_owned() + } + } +} + +impl Service for OpenAIPassthroughService { + type Response = Response; + type Error = Error; + type Future = BoxFuture<'static, Result>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: PassthroughRequest) -> Self::Future { + let client = self.client.clone(); + let uri = self.target_uri(&req.headers); + + Box::pin(async move { + let mut builder = Request::builder().method(http::Method::POST).uri(&uri); + + for (key, value) in &req.headers { + builder = builder.header(key.as_str(), value.as_str()); + } + + let forwarded = builder + .body(Body::from(req.body)) + .map_err(|e| Error::RequestBuild(e.to_string()))?; + + client.send(forwarded).await + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::ProviderConfig; + + fn make_svc(base_url: Option<&str>) -> OpenAIPassthroughService { + let mut config = ProviderConfig::new("key"); + if let Some(u) = base_url { + config = config.with_base_url(u); + } + OpenAIPassthroughService::new(Arc::new(crate::testing::StubClient), config) + } + + #[test] + fn routes_proj_key_to_api_openai() { + let svc = make_svc(None); + let headers = vec![("authorization".into(), "Bearer sk-proj-abc123".into())]; + assert_eq!(svc.target_uri(&headers), OPENAI_API_RESPONSES_URL); + } + + #[test] + fn routes_non_proj_key_to_chatgpt() { + let svc = make_svc(None); + let headers = vec![("authorization".into(), "Bearer sk-abc123".into())]; + assert_eq!(svc.target_uri(&headers), OPENAI_CHATGPT_RESPONSES_URL); + } + + #[test] + fn custom_base_url_overrides_selection() { + let svc = make_svc(Some("http://localhost:4000")); + assert_eq!(svc.target_uri(&[]), "http://localhost:4000/v1/responses"); + } +} diff --git a/crates/gateway-core/src/provider.rs b/crates/gateway-core/src/provider.rs new file mode 100644 index 0000000..4a6ca3a --- /dev/null +++ b/crates/gateway-core/src/provider.rs @@ -0,0 +1,37 @@ +use futures::stream::BoxStream; + +use crate::{ + config::ProviderConfig, + error::{Error, Result}, + types::{CompletionChunk, CompletionRequest, CompletionResponse}, +}; + +/// Core abstraction for an LLM provider. +/// +/// Implementations translate from the canonical [`CompletionRequest`] (OpenAI +/// Chat Completions format) into the provider's native API format, make the +/// HTTP call via the injected [`crate::http_client::HttpClient`], and parse +/// the response back into the canonical types. +/// +/// # Dyn compatibility +/// +/// The `complete_stream` method returns a [`BoxStream`] (not `impl Stream`) so +/// that `dyn Provider` is object-safe and can be stored in a `Vec` or `Arc`. +/// `async_trait` boxes the future returned by `complete` for the same reason. +#[async_trait::async_trait] +pub trait Provider: Send + Sync { + /// Perform a non-streaming completion. Waits for the full response. + async fn complete( + &self, + request: &CompletionRequest, + config: &ProviderConfig, + ) -> Result; + + /// Begin a streaming completion. Returns a lazy stream that starts the + /// HTTP request when first polled. + fn complete_stream( + &self, + request: CompletionRequest, + config: ProviderConfig, + ) -> BoxStream<'static, Result>; +} diff --git a/crates/gateway-core/src/service.rs b/crates/gateway-core/src/service.rs new file mode 100644 index 0000000..7746b91 --- /dev/null +++ b/crates/gateway-core/src/service.rs @@ -0,0 +1,130 @@ +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use futures::future::BoxFuture; +use tower::Service; + +use crate::{ + backend::http::HttpClient, + config::ProviderConfig, + error::{Error, Result}, + types::{CompletionRequest, GatewayResponse}, +}; + +/// The innermost Tower service in the core LLM pipeline. +/// +/// Routes a [`CompletionRequest`] (OpenAI-compatible canonical format) to the +/// appropriate provider implementation, which translates the request to the +/// provider's native format and calls the provider API. +/// +/// This is the innermost service; all middleware layers +/// (`tools-compression`, user-defined layers, etc.) wrap it: +/// +/// ```text +/// CompletionRequest +/// │ +/// v +/// ┌──────────────────┐ +/// │ [User layers] │ ← Any tower::Layer +/// └──────┬───────────┘ +/// │ +/// v +/// ┌──────────────────┐ +/// │ Provider │ ← This service +/// │ dispatch │ +/// └──────────────────┘ +/// │ +/// v +/// GatewayResponse +/// ``` +/// +/// # Construction +/// +/// ```rust,ignore +/// let service = ProviderDispatchService::new( +/// Arc::new(ReqwestHttpClient::new(client)), +/// anthropic_config, +/// openai_config, +/// ); +/// ``` +pub struct ProviderDispatchService { + _client: Arc, + _anthropic_config: ProviderConfig, + _openai_config: ProviderConfig, +} + +impl ProviderDispatchService { + pub fn new( + client: Arc, + anthropic_config: ProviderConfig, + openai_config: ProviderConfig, + ) -> Self { + Self { + _client: client, + _anthropic_config: anthropic_config, + _openai_config: openai_config, + } + } +} + +impl Service for ProviderDispatchService { + type Response = GatewayResponse; + type Error = Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: CompletionRequest) -> Self::Future { + Box::pin(async { + Err(Error::HttpClient( + "ProviderDispatchService: not yet implemented".into(), + )) + }) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use tower::ServiceExt as _; + + use super::*; + use crate::{ + testing::StubClient, + types::message::{Message, MessageContent, UserMessage}, + }; + + fn make_service() -> ProviderDispatchService { + ProviderDispatchService::new( + Arc::new(StubClient), + ProviderConfig::new("test-anthropic-key"), + ProviderConfig::new("test-openai-key"), + ) + } + + #[tokio::test] + async fn poll_ready_is_always_ready() { + let mut svc = make_service(); + let ready = std::future::poll_fn(|cx| svc.poll_ready(cx)).await; + assert!(ready.is_ok()); + } + + #[tokio::test] + async fn call_returns_error_for_unimplemented() { + let svc = make_service(); + let req = CompletionRequest::new( + "gpt-4o", + vec![Message::User(UserMessage { + name: None, + content: MessageContent::Text("test".into()), + cache_control: None, + })], + ); + let result = svc.oneshot(req).await; + assert!(result.is_err()); + } +} diff --git a/crates/gateway-core/src/testing.rs b/crates/gateway-core/src/testing.rs new file mode 100644 index 0000000..5ccda42 --- /dev/null +++ b/crates/gateway-core/src/testing.rs @@ -0,0 +1,21 @@ +//! Shared test utilities for `gateway-core` unit tests. +//! +//! This module is compiled only in test builds (`#[cfg(test)]`). + +use axum_core::body::Body; +use http::{Request, Response}; + +use crate::{Error, backend::http::HttpClient, error::Result}; + +/// A no-op [`HttpClient`] that always returns an error. +/// +/// Useful for tests that need to construct services without exercising the +/// HTTP transport layer. +pub struct StubClient; + +#[async_trait::async_trait] +impl HttpClient for StubClient { + async fn send(&self, _req: Request) -> Result> { + Err(Error::HttpClient("StubClient always fails".into())) + } +} diff --git a/crates/gateway-core/src/types/message.rs b/crates/gateway-core/src/types/message.rs new file mode 100644 index 0000000..2fa4eea --- /dev/null +++ b/crates/gateway-core/src/types/message.rs @@ -0,0 +1,271 @@ +use serde::{Deserialize, Serialize}; + +// ── Content ─────────────────────────────────────────────────────────────── + +/// A content part within a multi-part message. +/// +/// The `#[serde(other)]` catch-all preserves forward compatibility with +/// provider-specific content block types (e.g. Anthropic image blocks). +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentPart { + /// Standard text part. Also accepted as `input_text` (Responses API) or + /// `output_text` (Anthropic streaming). + #[serde(alias = "input_text", alias = "output_text")] + Text { text: String }, + #[serde(other)] + Unknown, +} + +/// The content of a message: either a plain string or an array of content parts. +/// +/// Providers accept both forms; `#[serde(untagged)]` handles the ambiguity. +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum MessageContent { + Text(String), + Parts(Vec), +} + +impl MessageContent { + /// Extract all text, joining multi-part content with double newlines. + pub fn as_text(&self) -> String { + match self { + MessageContent::Text(s) => s.clone(), + MessageContent::Parts(parts) => parts + .iter() + .filter_map(|p| match p { + ContentPart::Text { text } => Some(text.as_str()), + ContentPart::Unknown => None, + }) + .collect::>() + .join("\n\n"), + } + } + + pub fn is_empty(&self) -> bool { + match self { + MessageContent::Text(s) => s.is_empty(), + MessageContent::Parts(parts) => parts.is_empty(), + } + } +} + +impl Default for MessageContent { + fn default() -> Self { + MessageContent::Text(String::new()) + } +} + +impl From for MessageContent { + fn from(s: String) -> Self { + MessageContent::Text(s) + } +} + +impl From<&str> for MessageContent { + fn from(s: &str) -> Self { + MessageContent::Text(s.to_owned()) + } +} + +// ── Messages ────────────────────────────────────────────────────────────── + +/// A conversation message. The `role` field is used as the serde tag so +/// serialized JSON matches the OpenAI Chat Completions wire format exactly. +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "role", rename_all = "snake_case")] +pub enum Message { + /// OpenAI "developer" system prompt (treated as system by most providers). + Developer(DeveloperMessage), + System(SystemMessage), + User(UserMessage), + Assistant(AssistantMessage), + Tool(ToolMessage), +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct DeveloperMessage { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + pub content: MessageContent, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct SystemMessage { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + pub content: MessageContent, + /// Preserved for passthrough; Anthropic uses this for prompt caching. + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_control: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct UserMessage { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + pub content: MessageContent, + /// Preserved for passthrough; Anthropic uses this for prompt caching. + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_control: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct AssistantMessage { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub refusal: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + /// Preserved for passthrough; Anthropic uses this for prompt caching. + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_control: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolMessage { + pub content: MessageContent, + pub tool_call_id: String, +} + +// ── Tools ───────────────────────────────────────────────────────────────── + +/// A tool (function) the model may call. +/// +/// Custom `Deserialize` handles both the OpenAI nested format +/// (`{"type":"function","function":{...}}`) and the Anthropic flat format +/// (`{"name":"...","description":"...","input_schema":{...}}`). +#[derive(Debug, Clone)] +pub enum Tool { + Function { + function: FunctionDefinition, + }, + /// Unknown tool type — preserved opaquely for passthrough. + Unknown(serde_json::Value), +} + +impl<'de> Deserialize<'de> for Tool { + fn deserialize>(d: D) -> Result { + let v = serde_json::Value::deserialize(d)?; + if v.get("type").and_then(|t| t.as_str()) == Some("function") { + #[derive(Deserialize)] + struct FunctionTool { + function: FunctionDefinition, + } + serde_json::from_value::(v) + .map(|t| Tool::Function { + function: t.function, + }) + .map_err(serde::de::Error::custom) + } else { + Ok(Tool::Unknown(v)) + } + } +} + +impl Serialize for Tool { + fn serialize(&self, s: S) -> Result { + use serde::ser::SerializeMap as _; + match self { + Tool::Function { function } => { + let mut map = s.serialize_map(Some(2))?; + map.serialize_entry("type", "function")?; + map.serialize_entry("function", function)?; + map.end() + } + Tool::Unknown(v) => v.serialize(s), + } + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionDefinition { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// JSON Schema for the function's parameters. + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, +} + +/// Controls which tool (if any) the model calls. +/// +/// `#[serde(untagged)]` handles both string shortcuts (`"auto"`, `"required"`, +/// `"none"`) and the specific-function object `{"type":"function","function":{...}}`. +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ToolChoice { + /// Simple mode string: `"auto"`, `"required"`, or `"none"`. + Mode(String), + /// Force a specific function: `{"type":"function","function":{"name":"..."}}` + Specific { + #[serde(rename = "type")] + tool_type: String, + function: ToolChoiceFunction, + }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolChoiceFunction { + pub name: String, +} + +/// A tool call made by the assistant in a response. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type", default = "default_tool_type")] + pub tool_type: String, + pub function: FunctionCall, +} + +fn default_tool_type() -> String { + "function".to_string() +} + +/// The function invocation within a tool call. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionCall { + pub name: String, + /// JSON-encoded arguments string (as returned by the model). + pub arguments: String, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn message_content_text_roundtrip() { + let content = MessageContent::Text("hello".into()); + let json = serde_json::to_string(&content).unwrap(); + assert_eq!(json, r#""hello""#); + let back: MessageContent = serde_json::from_str(&json).unwrap(); + assert_eq!(back.as_text(), "hello"); + } + + #[test] + fn tool_deserializes_function_nested_format() { + let json = r#"{"type":"function","function":{"name":"get_weather","description":"Get weather","parameters":{}}}"#; + let tool: Tool = serde_json::from_str(json).unwrap(); + assert!(matches!(tool, Tool::Function { .. })); + } + + #[test] + fn tool_choice_string_mode() { + let json = r#""auto""#; + let tc: ToolChoice = serde_json::from_str(json).unwrap(); + assert!(matches!(tc, ToolChoice::Mode(s) if s == "auto")); + } + + #[test] + fn message_tagged_by_role() { + let json = r#"{"role":"user","content":"hello"}"#; + let msg: Message = serde_json::from_str(json).unwrap(); + assert!(matches!(msg, Message::User(_))); + } +} diff --git a/crates/gateway-core/src/types/mod.rs b/crates/gateway-core/src/types/mod.rs new file mode 100644 index 0000000..1e807a0 --- /dev/null +++ b/crates/gateway-core/src/types/mod.rs @@ -0,0 +1,16 @@ +pub mod message; +pub mod passthrough; +pub mod request; +pub mod response; + +pub use message::{ + AssistantMessage, ContentPart, DeveloperMessage, FunctionCall, FunctionDefinition, Message, + MessageContent, SystemMessage, Tool, ToolCall, ToolChoice, ToolChoiceFunction, ToolMessage, + UserMessage, +}; +pub use passthrough::PassthroughRequest; +pub use request::CompletionRequest; +pub use response::{ + Choice, ChunkChoice, CompletionChunk, CompletionResponse, CompletionTokensDetails, Delta, + DeltaFunction, DeltaToolCall, FinishReason, GatewayResponse, PromptTokensDetails, Usage, +}; diff --git a/crates/gateway-core/src/types/passthrough.rs b/crates/gateway-core/src/types/passthrough.rs new file mode 100644 index 0000000..513969d --- /dev/null +++ b/crates/gateway-core/src/types/passthrough.rs @@ -0,0 +1,30 @@ +use bytes::Bytes; + +/// A raw LLM request in a provider's native wire format, ready for passthrough. +/// +/// This is the pipeline-level input type for the passthrough Tower services +/// ([`crate::passthrough::anthropic::AnthropicPassthroughService`], +/// [`crate::passthrough::openai::OpenAIPassthroughService`]). +/// +/// The HTTP boundary layer above `gateway-core` is responsible for: +/// - Reading the raw request body into [`Bytes`]. +/// - Stripping gateway-internal headers (see [`crate::passthrough::SKIP_HEADERS`]). +/// - Constructing this type before handing the request to the pipeline. +/// +/// `http` types are used *internally* by the passthrough service implementations +/// only when building the outbound HTTP call to the provider — never in this +/// public interface. +#[derive(Debug, Clone)] +pub struct PassthroughRequest { + /// Raw serialized request body in the provider's native format. + pub body: Bytes, + /// Pre-filtered headers to forward (gateway-internal headers already stripped). + /// Each entry is a `(name, value)` pair as UTF-8 strings. + pub headers: Vec<(String, String)>, +} + +impl PassthroughRequest { + pub fn new(body: Bytes, headers: Vec<(String, String)>) -> Self { + Self { body, headers } + } +} diff --git a/crates/gateway-core/src/types/request.rs b/crates/gateway-core/src/types/request.rs new file mode 100644 index 0000000..44b6d75 --- /dev/null +++ b/crates/gateway-core/src/types/request.rs @@ -0,0 +1,92 @@ +use serde::{Deserialize, Serialize}; + +use super::message::{Message, Tool, ToolChoice}; + +/// A canonical LLM completion request in OpenAI Chat Completions format. +/// +/// This is the provider-agnostic entry point for the [`crate::service::ProviderDispatchService`]. +/// Provider implementations translate from this type to their native API format. +/// +/// The `messages` field also accepts the Responses API `input` alias so that +/// the same type can represent requests from both Chat Completions and Responses API +/// clients before they are normalised. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionRequest { + /// Model identifier (e.g. `"claude-opus-4-5"`, `"gpt-4o"`). + pub model: String, + + /// Conversation history. + /// + /// Accepts `"messages"` (Chat Completions) or `"input"` (Responses API). + #[serde(alias = "input")] + pub messages: Vec, + + /// Maximum tokens to generate in the response. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// Whether to stream the response as SSE chunks. + #[serde(default)] + pub stream: bool, + + /// Tools (functions) the model may call. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub tools: Vec, + + /// Controls which tool (if any) the model calls. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + + /// Sampling temperature (0–2). Higher = more random. + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// Nucleus sampling probability mass. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, +} + +impl CompletionRequest { + pub fn new(model: impl Into, messages: Vec) -> Self { + Self { + model: model.into(), + messages, + max_tokens: None, + stream: false, + tools: Vec::new(), + tool_choice: None, + temperature: None, + top_p: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::message::{MessageContent, UserMessage}; + + #[test] + fn completion_request_minimal_roundtrip() { + let req = CompletionRequest::new( + "gpt-4o", + vec![Message::User(UserMessage { + name: None, + content: MessageContent::Text("Hello".into()), + cache_control: None, + })], + ); + let json = serde_json::to_string(&req).unwrap(); + let back: CompletionRequest = serde_json::from_str(&json).unwrap(); + assert_eq!(back.model, "gpt-4o"); + assert!(!back.stream); + } + + #[test] + fn accepts_input_alias_for_messages() { + let json = r#"{"model":"claude-opus-4-5","input":[{"role":"user","content":"Hi"}]}"#; + let req: CompletionRequest = serde_json::from_str(json).unwrap(); + assert_eq!(req.model, "claude-opus-4-5"); + assert_eq!(req.messages.len(), 1); + } +} diff --git a/crates/gateway-core/src/types/response.rs b/crates/gateway-core/src/types/response.rs new file mode 100644 index 0000000..e796509 --- /dev/null +++ b/crates/gateway-core/src/types/response.rs @@ -0,0 +1,139 @@ +use futures::stream::BoxStream; +use serde::{Deserialize, Serialize}; + +use super::message::Message; +use crate::error::Error; + +// ── Usage ───────────────────────────────────────────────────────────────── + +/// Token-level usage details for a completion. +#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_tokens_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_tokens_details: Option, +} + +/// Breakdown of prompt token counts. +#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize)] +pub struct PromptTokensDetails { + #[serde(skip_serializing_if = "Option::is_none")] + pub cached_tokens: Option, + /// For Anthropic: tokens written into the prompt cache. + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_creation_tokens: Option, +} + +/// Breakdown of completion token counts. +#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize)] +pub struct CompletionTokensDetails { + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_tokens: Option, +} + +// ── Finish reason ───────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum FinishReason { + Stop, + Length, + ContentFilter, + ToolCalls, +} + +// ── Non-streaming response ──────────────────────────────────────────────── + +/// A complete (non-streaming) LLM response in OpenAI Chat Completions format. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionResponse { + pub id: String, + pub object: String, + pub created: i64, + pub model: String, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +/// A single choice within a [`CompletionResponse`]. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Choice { + pub index: u32, + pub message: Message, + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, +} + +// ── Streaming response ──────────────────────────────────────────────────── + +/// A single streaming chunk in OpenAI SSE format (`object: "chat.completion.chunk"`). +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionChunk { + pub id: String, + pub object: String, + pub created: i64, + pub model: String, + pub choices: Vec, + /// Only present in the final chunk when `stream_options.include_usage` is set. + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +/// A single choice within a [`CompletionChunk`]. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChunkChoice { + pub index: u32, + pub delta: Delta, + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, +} + +/// The incremental content delta for a streaming chunk. +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct Delta { + /// Present only in the first chunk for a given choice. + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + +/// An incremental tool call in a streaming delta. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct DeltaToolCall { + pub index: u32, + pub id: String, + #[serde(rename = "type")] + pub tool_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub function: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct DeltaFunction { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, +} + +// ── Unified service response ────────────────────────────────────────────── + +/// The response type produced by [`crate::service::ProviderDispatchService`]. +/// +/// Callers match on this enum to handle streaming and non-streaming responses +/// uniformly through the same Tower service interface. +pub enum GatewayResponse { + /// A complete, buffered response. + Complete(CompletionResponse), + /// A lazy stream of chunks. The HTTP request to the provider is not made + /// until the stream is first polled. + Stream(BoxStream<'static, Result>), +}