diff --git a/crates/providers/src/github_copilot.rs b/crates/providers/src/github_copilot.rs index c3327301d..b379968f9 100644 --- a/crates/providers/src/github_copilot.rs +++ b/crates/providers/src/github_copilot.rs @@ -22,7 +22,9 @@ use { SseLineResult, StreamingToolState, finalize_stream, parse_openai_compat_usage_from_payload, parse_tool_calls, process_openai_sse_line, to_openai_tools, }, - moltis_agents::model::{ChatMessage, CompletionResponse, LlmProvider, StreamEvent}, + moltis_agents::model::{ + ChatMessage, CompletionResponse, LlmProvider, StreamEvent, ToolCall, Usage, + }, }; // ── Constants ──────────────────────────────────────────────────────────────── @@ -34,7 +36,6 @@ const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code"; const GITHUB_TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; const COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token"; const COPILOT_API_BASE: &str = "https://api.individual.githubcopilot.com"; -const COPILOT_MODELS_ENDPOINT: &str = "https://api.individual.githubcopilot.com/models"; const PROVIDER_NAME: &str = "github-copilot"; @@ -63,6 +64,22 @@ struct GithubTokenResponse { struct CopilotTokenResponse { token: String, expires_at: u64, + /// Enterprise accounts return a proxy endpoint hostname (e.g. + /// `proxy.enterprise.githubcopilot.com`). When present, all API + /// requests must be routed through `https://{proxy_ep}/…` and chat + /// completions must use `stream: true`. + #[serde(rename = "proxy-ep")] + proxy_ep: Option, +} + +/// Resolved authentication: a valid Copilot API token plus the base URL to +/// use for API requests (may differ for enterprise vs individual accounts). +struct CopilotAuth { + token: String, + base_url: String, + /// `true` when the endpoint is an enterprise proxy that only supports + /// streaming chat completions. + is_enterprise: bool, } // ── Provider ───────────────────────────────────────────────────────────────── @@ -139,9 +156,9 @@ impl GitHubCopilotProvider { } } - /// Get a valid Copilot API token, exchanging the GitHub token if needed. - async fn get_valid_copilot_token(&self) -> anyhow::Result { - fetch_valid_copilot_token_with_fallback(self.client, &self.token_store).await + /// Get a valid Copilot API token and resolved base URL. + async fn get_copilot_auth(&self) -> anyhow::Result { + fetch_copilot_auth_with_fallback(self.client, &self.token_store).await } } @@ -193,16 +210,37 @@ pub const COPILOT_MODELS: &[(&str, &str)] = &[ ("gemini-2.0-flash", "Gemini 2.0 Flash (Copilot)"), ]; -async fn fetch_valid_copilot_token( +/// Build a [`CopilotAuth`] from an `account_id` value that may contain a +/// proxy-ep hostname persisted from a previous token exchange. +fn copilot_auth_from_parts(token: String, proxy_ep: Option<&str>) -> CopilotAuth { + match proxy_ep.filter(|s| !s.is_empty()) { + Some(ep) => { + debug!(proxy_ep = %ep, "using enterprise proxy endpoint"); + CopilotAuth { + token, + base_url: format!("https://{ep}"), + is_enterprise: true, + } + }, + None => CopilotAuth { + token, + base_url: COPILOT_API_BASE.to_string(), + is_enterprise: false, + }, + } +} + +async fn fetch_copilot_auth( client: &reqwest::Client, token_store: &TokenStore, -) -> anyhow::Result { +) -> anyhow::Result { let tokens = token_store.load(PROVIDER_NAME).ok_or_else(|| { anyhow::anyhow!("not logged in to github-copilot — run OAuth device flow first") })?; // The `access_token` stored is the GitHub user token. // We exchange it for a short-lived Copilot API token and cache it. + // The proxy-ep (if any) is persisted in the `account_id` field. if let Some(copilot_tokens) = token_store.load("github-copilot-api") && let Some(expires_at) = copilot_tokens.expires_at { @@ -211,7 +249,9 @@ async fn fetch_valid_copilot_token( .unwrap_or_default() .as_secs(); if now + 60 < expires_at { - return Ok(copilot_tokens.access_token.expose_secret().clone()); + let token = copilot_tokens.access_token.expose_secret().clone(); + let proxy_ep = copilot_tokens.account_id.as_deref(); + return Ok(copilot_auth_from_parts(token, proxy_ep)); } } @@ -239,21 +279,26 @@ async fn fetch_valid_copilot_token( access_token: Secret::new(copilot_resp.token.clone()), refresh_token: None, id_token: None, - account_id: None, + // Persist the enterprise proxy-ep hostname (if any) so we can + // reconstruct the correct base URL from the cache. + account_id: copilot_resp.proxy_ep.clone(), expires_at: Some(copilot_resp.expires_at), }); - Ok(copilot_resp.token) + Ok(copilot_auth_from_parts( + copilot_resp.token, + copilot_resp.proxy_ep.as_deref(), + )) } -async fn fetch_valid_copilot_token_with_fallback( +async fn fetch_copilot_auth_with_fallback( client: &reqwest::Client, primary_store: &TokenStore, -) -> anyhow::Result { +) -> anyhow::Result { let Some(token_store) = token_store_with_provider_tokens(primary_store) else { anyhow::bail!("not logged in to github-copilot — run OAuth device flow first"); }; - fetch_valid_copilot_token(client, &token_store).await + fetch_copilot_auth(client, &token_store).await } pub fn default_model_catalog() -> Vec { @@ -359,11 +404,11 @@ fn parse_models_payload(value: &serde_json::Value) -> Vec anyhow::Result> { let response = client - .get(COPILOT_MODELS_ENDPOINT) - .header("Authorization", format!("Bearer {access_token}")) + .get(format!("{}/models", auth.base_url)) + .header("Authorization", format!("Bearer {}", auth.token)) .header("Accept", "application/json") .header("Editor-Version", EDITOR_VERSION) .header("User-Agent", COPILOT_USER_AGENT) @@ -397,9 +442,8 @@ pub fn start_model_discovery() -> mpsc::Receiver Vec { super::merge_discovered_with_fallback_catalog(discovered, fallback) } +// ── Enterprise streaming-to-sync bridge ────────────────────────────────────── + +/// Send a streaming chat completion request and collect the SSE events into a +/// single [`CompletionResponse`]. Used for enterprise proxy endpoints that +/// reject non-streaming requests. +async fn collect_streamed_completion( + client: &reqwest::Client, + auth: &CopilotAuth, + model: &str, + messages: &[ChatMessage], + tools: &[serde_json::Value], +) -> anyhow::Result { + let openai_messages: Vec = + messages.iter().map(ChatMessage::to_openai_value).collect(); + let mut body = serde_json::json!({ + "model": model, + "messages": openai_messages, + "stream": true, + "stream_options": { "include_usage": true }, + }); + + if !tools.is_empty() { + body["tools"] = serde_json::Value::Array(to_openai_tools(tools)); + } + + debug!( + model = %model, + messages_count = messages.len(), + tools_count = tools.len(), + "github-copilot enterprise complete (streaming) request" + ); + trace!(body = %serde_json::to_string(&body).unwrap_or_default(), "github-copilot enterprise request body"); + + let http_resp = client + .post(format!("{}/chat/completions", auth.base_url)) + .header("Authorization", format!("Bearer {}", auth.token)) + .header("content-type", "application/json") + .header("Editor-Version", EDITOR_VERSION) + .header("User-Agent", COPILOT_USER_AGENT) + .json(&body) + .send() + .await?; + + let status = http_resp.status(); + if !status.is_success() { + let retry_after_ms = super::retry_after_ms_from_headers(http_resp.headers()); + let body_text = http_resp.text().await.unwrap_or_default(); + warn!(status = %status, body = %body_text, "github-copilot enterprise API error"); + anyhow::bail!( + "{}", + super::with_retry_after_marker( + format!("GitHub Copilot API error HTTP {status}: {body_text}"), + retry_after_ms, + ) + ); + } + + // Parse the SSE stream into events, then assemble a CompletionResponse. + let mut byte_stream = http_resp.bytes_stream(); + let mut buf = String::new(); + let mut state = StreamingToolState::default(); + let mut events: Vec = Vec::new(); + + while let Some(chunk) = byte_stream.next().await { + let chunk = chunk?; + buf.push_str(&String::from_utf8_lossy(&chunk)); + + while let Some(pos) = buf.find('\n') { + let line = buf[..pos].trim().to_string(); + buf = buf[pos + 1..].to_string(); + + if line.is_empty() { + continue; + } + let Some(data) = line + .strip_prefix("data: ") + .or_else(|| line.strip_prefix("data:")) + else { + continue; + }; + + match process_openai_sse_line(data, &mut state) { + SseLineResult::Done => { + events.extend(finalize_stream(&mut state)); + return Ok(stream_events_to_completion(events)); + }, + SseLineResult::Events(evts) => events.extend(evts), + SseLineResult::Skip => {}, + } + } + } + + // Process any trailing data in the buffer. + let line = buf.trim().to_string(); + if !line.is_empty() { + if let Some(data) = line + .strip_prefix("data: ") + .or_else(|| line.strip_prefix("data:")) + { + match process_openai_sse_line(data, &mut state) { + SseLineResult::Done | SseLineResult::Events(_) | SseLineResult::Skip => {}, + } + } + } + events.extend(finalize_stream(&mut state)); + Ok(stream_events_to_completion(events)) +} + +/// Collapse a collected list of [`StreamEvent`]s into a [`CompletionResponse`]. +fn stream_events_to_completion(events: Vec) -> CompletionResponse { + let mut text_parts: Vec = Vec::new(); + let mut tool_calls: Vec = Vec::new(); + let mut usage = Usage::default(); + + // Track in-progress tool calls by index. + let mut pending_tools: Vec<(String, String, String)> = Vec::new(); // (id, name, args) + + for event in events { + match event { + StreamEvent::Delta(s) => text_parts.push(s), + StreamEvent::ToolCallStart { id, name, index } => { + while pending_tools.len() <= index { + pending_tools.push((String::new(), String::new(), String::new())); + } + pending_tools[index].0 = id; + pending_tools[index].1 = name; + }, + StreamEvent::ToolCallArgumentsDelta { index, delta } => { + if let Some(entry) = pending_tools.get_mut(index) { + entry.2.push_str(&delta); + } + }, + StreamEvent::ToolCallComplete { index } => { + if let Some(entry) = pending_tools.get(index) { + let arguments: serde_json::Value = + serde_json::from_str(&entry.2).unwrap_or_default(); + tool_calls.push(ToolCall { + id: entry.0.clone(), + name: entry.1.clone(), + arguments, + }); + } + }, + StreamEvent::Done(u) => usage = u, + StreamEvent::Error(_) + | StreamEvent::ProviderRaw(_) + | StreamEvent::ReasoningDelta(_) => {}, + } + } + + let text = if text_parts.is_empty() { + None + } else { + Some(text_parts.join("")) + }; + + CompletionResponse { + text, + tool_calls, + usage, + } +} + // ── LlmProvider impl ──────────────────────────────────────────────────────── #[async_trait] @@ -461,7 +668,14 @@ impl LlmProvider for GitHubCopilotProvider { messages: &[ChatMessage], tools: &[serde_json::Value], ) -> anyhow::Result { - let token = self.get_valid_copilot_token().await?; + let auth = self.get_copilot_auth().await?; + + // Enterprise proxy only supports streaming — delegate to the + // streaming path and collect the result. + if auth.is_enterprise { + return collect_streamed_completion(self.client, &auth, &self.model, messages, tools) + .await; + } let openai_messages: Vec = messages.iter().map(ChatMessage::to_openai_value).collect(); @@ -484,8 +698,8 @@ impl LlmProvider for GitHubCopilotProvider { let http_resp = self .client - .post(format!("{COPILOT_API_BASE}/chat/completions")) - .header("Authorization", format!("Bearer {token}")) + .post(format!("{}/chat/completions", auth.base_url)) + .header("Authorization", format!("Bearer {}", auth.token)) .header("content-type", "application/json") .header("Editor-Version", EDITOR_VERSION) .header("User-Agent", COPILOT_USER_AGENT) @@ -539,8 +753,8 @@ impl LlmProvider for GitHubCopilotProvider { tools: Vec, ) -> Pin + Send + '_>> { Box::pin(async_stream::stream! { - let token = match self.get_valid_copilot_token().await { - Ok(t) => t, + let auth = match self.get_copilot_auth().await { + Ok(a) => a, Err(e) => { yield StreamEvent::Error(e.to_string()); return; @@ -570,8 +784,8 @@ impl LlmProvider for GitHubCopilotProvider { let resp = match self .client - .post(format!("{COPILOT_API_BASE}/chat/completions")) - .header("Authorization", format!("Bearer {token}")) + .post(format!("{}/chat/completions", auth.base_url)) + .header("Authorization", format!("Bearer {}", auth.token)) .header("content-type", "application/json") .header("Editor-Version", EDITOR_VERSION) .header("User-Agent", COPILOT_USER_AGENT) @@ -1067,4 +1281,196 @@ mod tests { "copilot-integration-id header should NOT be sent" ); } + + // ── Enterprise proxy tests ────────────────────────────────────────────── + + #[test] + fn copilot_token_response_deserializes_proxy_ep() { + let json = r#"{ + "token": "tok_abc", + "expires_at": 1700000000, + "proxy-ep": "proxy.enterprise.githubcopilot.com" + }"#; + let resp: CopilotTokenResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.token, "tok_abc"); + assert_eq!(resp.expires_at, 1700000000); + assert_eq!( + resp.proxy_ep.as_deref(), + Some("proxy.enterprise.githubcopilot.com") + ); + } + + #[test] + fn copilot_token_response_without_proxy_ep() { + let json = r#"{"token": "tok_abc", "expires_at": 1700000000}"#; + let resp: CopilotTokenResponse = serde_json::from_str(json).unwrap(); + assert!(resp.proxy_ep.is_none()); + } + + #[test] + fn copilot_auth_from_parts_individual() { + let auth = copilot_auth_from_parts("tok".into(), None); + assert_eq!(auth.base_url, COPILOT_API_BASE); + assert!(!auth.is_enterprise); + } + + #[test] + fn copilot_auth_from_parts_enterprise() { + let auth = + copilot_auth_from_parts("tok".into(), Some("proxy.enterprise.githubcopilot.com")); + assert_eq!(auth.base_url, "https://proxy.enterprise.githubcopilot.com"); + assert!(auth.is_enterprise); + } + + #[test] + fn copilot_auth_from_parts_empty_proxy_ep() { + let auth = copilot_auth_from_parts("tok".into(), Some("")); + assert_eq!(auth.base_url, COPILOT_API_BASE); + assert!(!auth.is_enterprise); + } + + /// Helper: start a mock server that returns SSE streaming responses. + async fn start_streaming_mock_with_capture( + sse_body: String, + ) -> (String, Arc>>) { + let captured: Arc>> = Arc::new(Mutex::new(Vec::new())); + let captured_clone = captured.clone(); + + let app = Router::new().route( + "/chat/completions", + post(move |req: Request| { + let cap = captured_clone.clone(); + let body_data = sse_body.clone(); + async move { + let headers: Vec<(String, String)> = req + .headers() + .iter() + .map(|(k, v)| { + (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()) + }) + .collect(); + + let body_bytes = axum::body::to_bytes(req.into_body(), 1024 * 1024) + .await + .unwrap_or_default(); + let body: Option = serde_json::from_slice(&body_bytes).ok(); + + cap.lock().unwrap().push(CapturedRequest { headers, body }); + + ( + [( + http::header::CONTENT_TYPE, + "text/event-stream; charset=utf-8", + )], + body_data, + ) + } + }), + ); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + (format!("http://{addr}"), captured) + } + + fn mock_streaming_sse() -> String { + [ + r#"data: {"choices":[{"delta":{"role":"assistant","content":"Hello"}}]}"#, + r#"data: {"choices":[{"delta":{"content":" world"}}]}"#, + r#"data: {"choices":[],"usage":{"prompt_tokens":10,"completion_tokens":5}}"#, + "data: [DONE]", + "", + ] + .join("\n\n") + } + + #[tokio::test] + async fn enterprise_complete_uses_streaming_and_collects() { + let sse = mock_streaming_sse(); + let (base_url, captured) = start_streaming_mock_with_capture(sse).await; + + let auth = CopilotAuth { + token: "ent-token".into(), + base_url, + is_enterprise: true, + }; + + let client = reqwest::Client::new(); + let messages = vec![ChatMessage::user("hi")]; + let result = collect_streamed_completion(&client, &auth, "gpt-4o", &messages, &[]).await; + assert!(result.is_ok(), "expected Ok, got: {result:?}"); + + let resp = result.unwrap(); + assert_eq!(resp.text.as_deref(), Some("Hello world")); + assert!(resp.tool_calls.is_empty()); + assert_eq!(resp.usage.input_tokens, 10); + assert_eq!(resp.usage.output_tokens, 5); + + // Verify request had stream: true + let reqs = captured.lock().unwrap(); + let body = reqs[0].body.as_ref().unwrap(); + assert_eq!(body["stream"], true); + } + + #[tokio::test] + async fn enterprise_complete_collects_tool_calls() { + let sse = [ + r#"data: {"choices":[{"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","function":{"name":"read_file","arguments":""}}]}}]}"#, + r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"/tmp"}}]}}]}"#, + r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"/test.txt\"}"}}]}}]}"#, + r#"data: {"choices":[],"usage":{"prompt_tokens":20,"completion_tokens":10}}"#, + "data: [DONE]", + "", + ] + .join("\n\n"); + + let (base_url, _) = start_streaming_mock_with_capture(sse).await; + + let auth = CopilotAuth { + token: "ent-token".into(), + base_url, + is_enterprise: true, + }; + + let client = reqwest::Client::new(); + let messages = vec![ChatMessage::user("read file")]; + let resp = collect_streamed_completion(&client, &auth, "gpt-4o", &messages, &[]) + .await + .unwrap(); + + assert!(resp.text.is_none() || resp.text.as_deref() == Some("")); + assert_eq!(resp.tool_calls.len(), 1); + assert_eq!(resp.tool_calls[0].id, "call_1"); + assert_eq!(resp.tool_calls[0].name, "read_file"); + assert_eq!(resp.tool_calls[0].arguments["path"], "/tmp/test.txt"); + } + + #[test] + fn stream_events_to_completion_text_only() { + let events = vec![ + StreamEvent::Delta("Hello ".into()), + StreamEvent::Delta("world".into()), + StreamEvent::Done(Usage { + input_tokens: 5, + output_tokens: 2, + ..Default::default() + }), + ]; + let resp = stream_events_to_completion(events); + assert_eq!(resp.text.as_deref(), Some("Hello world")); + assert!(resp.tool_calls.is_empty()); + assert_eq!(resp.usage.input_tokens, 5); + } + + #[test] + fn stream_events_to_completion_empty() { + let events = vec![StreamEvent::Done(Usage::default())]; + let resp = stream_events_to_completion(events); + assert!(resp.text.is_none()); + assert!(resp.tool_calls.is_empty()); + } }