diff --git a/crates/gateway/src/services.rs b/crates/gateway/src/services.rs index 6d982b6a5..1d0321140 100644 --- a/crates/gateway/src/services.rs +++ b/crates/gateway/src/services.rs @@ -1480,20 +1480,24 @@ impl GatewayServices { self } - /// Create a [`Services`] bundle for sharing with the GraphQL schema. + /// Create a [`Services`] bundle with an injected `chat` and `system_info`. /// - /// Clones all service `Arc`s (cheap pointer bumps) into the shared bundle. - /// The `system_info` service is provided separately because it needs the - /// fully-constructed `GatewayState` which isn't available during + /// Clones all other service `Arc`s (cheap pointer bumps) into the shared + /// bundle. The `system_info` service is provided separately because it + /// needs the fully-constructed `GatewayState` which isn't available during /// `GatewayServices` construction. - pub fn to_services(&self, system_info: Arc) -> Arc { + pub fn to_services_with_chat( + &self, + system_info: Arc, + chat: Arc, + ) -> Arc { Arc::new(Services { agent: self.agent.clone(), session: self.session.clone(), channel: self.channel.clone(), config: self.config.clone(), cron: self.cron.clone(), - chat: self.chat.clone(), + chat, tts: self.tts.clone(), stt: self.stt.clone(), skills: self.skills.clone(), @@ -1513,6 +1517,10 @@ impl GatewayServices { system_info, }) } + + pub fn to_services(&self, system_info: Arc) -> Arc { + self.to_services_with_chat(system_info, self.chat.clone()) + } } #[cfg(test)] diff --git a/crates/httpd/src/graphql_routes.rs b/crates/httpd/src/graphql_routes.rs index 3b0667b35..11c2c8c58 100644 --- a/crates/httpd/src/graphql_routes.rs +++ b/crates/httpd/src/graphql_routes.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use { async_graphql::http::{GraphiQLPlugin, GraphiQLSource}, + async_trait::async_trait, axum::{ Json, extract::{FromRequestParts, Request, State, WebSocketUpgrade}, @@ -17,7 +18,10 @@ use { serde_json::Value, }; -use moltis_gateway::{services::ServiceResult, state::GatewayState}; +use moltis_gateway::{ + services::{ChatService, ServiceResult}, + state::GatewayState, +}; use crate::server::AppState; @@ -29,6 +33,101 @@ pub struct GatewaySystemInfoService { pub state: Arc, } +/// GraphQL chat shim that resolves the live chat service at call time. +/// +/// GraphQL schema construction happens before the late-bound chat service is +/// attached. Resolving through `GatewayState::chat()` keeps GraphQL aligned +/// with RPC/WebSocket behavior after the override is installed. +pub struct GraphqlChatServiceProxy { + pub state: Arc, +} + +#[async_trait] +impl ChatService for GraphqlChatServiceProxy { + async fn send(&self, params: Value) -> ServiceResult { + self.state.chat().await.send(params).await + } + + async fn send_sync(&self, params: Value) -> ServiceResult { + self.state.chat().await.send_sync(params).await + } + + async fn abort(&self, params: Value) -> ServiceResult { + self.state.chat().await.abort(params).await + } + + async fn cancel_queued(&self, params: Value) -> ServiceResult { + self.state.chat().await.cancel_queued(params).await + } + + async fn history(&self, params: Value) -> ServiceResult { + self.state.chat().await.history(params).await + } + + async fn inject(&self, params: Value) -> ServiceResult { + self.state.chat().await.inject(params).await + } + + async fn clear(&self, params: Value) -> ServiceResult { + self.state.chat().await.clear(params).await + } + + async fn compact(&self, params: Value) -> ServiceResult { + self.state.chat().await.compact(params).await + } + + async fn context(&self, params: Value) -> ServiceResult { + self.state.chat().await.context(params).await + } + + async fn raw_prompt(&self, params: Value) -> ServiceResult { + self.state.chat().await.raw_prompt(params).await + } + + async fn full_context(&self, params: Value) -> ServiceResult { + self.state.chat().await.full_context(params).await + } + + async fn active(&self, params: Value) -> ServiceResult { + self.state.chat().await.active(params).await + } + + async fn active_session_keys(&self) -> Vec { + self.state.chat().await.active_session_keys().await + } + + async fn active_thinking_text(&self, session_key: &str) -> Option { + self.state + .chat() + .await + .active_thinking_text(session_key) + .await + } + + async fn active_voice_pending(&self, session_key: &str) -> bool { + self.state + .chat() + .await + .active_voice_pending(session_key) + .await + } + + async fn peek(&self, params: Value) -> ServiceResult { + self.state.chat().await.peek(params).await + } +} + +pub fn build_graphql_schema(state: Arc) -> moltis_graphql::MoltisSchema { + let system_info = Arc::new(GatewaySystemInfoService { + state: Arc::clone(&state), + }); + let chat = Arc::new(GraphqlChatServiceProxy { + state: Arc::clone(&state), + }); + let services = state.services.to_services_with_chat(system_info, chat); + moltis_graphql::build_schema(services, state.graphql_broadcast.clone()) +} + #[async_trait::async_trait] impl moltis_service_traits::SystemInfoService for GatewaySystemInfoService { async fn health(&self) -> ServiceResult { diff --git a/crates/httpd/src/server.rs b/crates/httpd/src/server.rs index 7ea70d935..aa06f04d1 100644 --- a/crates/httpd/src/server.rs +++ b/crates/httpd/src/server.rs @@ -283,13 +283,7 @@ pub fn build_gateway_base( } #[cfg(feature = "graphql")] - let graphql_schema = { - let system_info = Arc::new(crate::graphql_routes::GatewaySystemInfoService { - state: Arc::clone(&state), - }); - let services = state.services.to_services(system_info); - moltis_graphql::build_schema(services, state.graphql_broadcast.clone()) - }; + let graphql_schema = crate::graphql_routes::build_graphql_schema(Arc::clone(&state)); let app_state = AppState { gateway: state, @@ -349,13 +343,7 @@ pub fn build_gateway_base( } #[cfg(feature = "graphql")] - let graphql_schema = { - let system_info = Arc::new(crate::graphql_routes::GatewaySystemInfoService { - state: Arc::clone(&state), - }); - let services = state.services.to_services(system_info); - moltis_graphql::build_schema(services, state.graphql_broadcast.clone()) - }; + let graphql_schema = crate::graphql_routes::build_graphql_schema(Arc::clone(&state)); let app_state = AppState { gateway: state, diff --git a/crates/httpd/tests/graphql_chat_binding.rs b/crates/httpd/tests/graphql_chat_binding.rs new file mode 100644 index 000000000..4b26597ae --- /dev/null +++ b/crates/httpd/tests/graphql_chat_binding.rs @@ -0,0 +1,162 @@ +#![allow(clippy::unwrap_used, clippy::expect_used)] +#![cfg(feature = "graphql")] + +use std::{ + net::SocketAddr, + sync::{Arc, Mutex}, +}; + +use {async_trait::async_trait, tokio::net::TcpListener}; + +use { + moltis_gateway::{ + auth, + methods::MethodRegistry, + services::{ChatService, GatewayServices, ServiceResult}, + state::GatewayState, + }, + moltis_httpd::server::{build_gateway_base, finalize_gateway_app}, + serde_json::{Value, json}, +}; + +#[derive(Default)] +struct RecordingChatService { + calls: Mutex>, +} + +impl RecordingChatService { + fn record(&self, method: &str) { + self.calls + .lock() + .unwrap_or_else(|e| e.into_inner()) + .push(method.to_string()); + } + + fn calls(&self) -> Vec { + self.calls.lock().unwrap_or_else(|e| e.into_inner()).clone() + } +} + +#[async_trait] +impl ChatService for RecordingChatService { + async fn send(&self, params: Value) -> ServiceResult { + self.record("send"); + assert_eq!(params["message"], "Hello"); + Ok(json!({ "ok": true })) + } + + async fn abort(&self, _params: Value) -> ServiceResult { + Ok(json!({ "ok": true })) + } + + async fn cancel_queued(&self, _params: Value) -> ServiceResult { + Ok(json!({ "cleared": 0 })) + } + + async fn history(&self, _params: Value) -> ServiceResult { + Ok(json!([])) + } + + async fn inject(&self, _params: Value) -> ServiceResult { + Ok(json!({ "ok": true })) + } + + async fn clear(&self, _params: Value) -> ServiceResult { + Ok(json!({ "ok": true })) + } + + async fn compact(&self, _params: Value) -> ServiceResult { + Ok(json!({ "ok": true })) + } + + async fn context(&self, _params: Value) -> ServiceResult { + Ok(json!({})) + } + + async fn raw_prompt(&self, _params: Value) -> ServiceResult { + Ok(json!({ "text": "prompt" })) + } + + async fn full_context(&self, _params: Value) -> ServiceResult { + Ok(json!([])) + } + + async fn active(&self, params: Value) -> ServiceResult { + self.record("active"); + assert_eq!(params["sessionKey"], "sess1"); + Ok(json!({ "active": true })) + } +} + +async fn start_graphql_server() -> (SocketAddr, Arc, tempfile::TempDir) { + let tmp = tempfile::tempdir().unwrap(); + moltis_config::set_config_dir(tmp.path().to_path_buf()); + moltis_config::set_data_dir(tmp.path().to_path_buf()); + + let state = GatewayState::new(auth::resolve_auth(None, None), GatewayServices::noop()); + let state_clone = Arc::clone(&state); + let methods = Arc::new(MethodRegistry::new()); + + #[cfg(feature = "push-notifications")] + let (router, app_state) = build_gateway_base(state, methods, None, None); + #[cfg(not(feature = "push-notifications"))] + let (router, app_state) = build_gateway_base(state, methods, None); + + let app = finalize_gateway_app(router, app_state, false); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .await + .unwrap(); + }); + + (addr, state_clone, tmp) +} + +#[tokio::test] +async fn graphql_chat_uses_late_bound_override_after_schema_build() { + let (addr, state, _tmp) = start_graphql_server().await; + + let chat = Arc::new(RecordingChatService::default()); + state + .set_chat(Arc::clone(&chat) as Arc) + .await; + + let client = reqwest::Client::new(); + + let send_response: Value = client + .post(format!("http://{addr}/graphql")) + .json(&json!({ + "query": r#"mutation { chat { send(message: "Hello") { ok } } }"#, + })) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + assert_eq!(send_response["data"]["chat"]["send"]["ok"], true); + + let active_response: Value = client + .post(format!("http://{addr}/graphql")) + .json(&json!({ + "query": r#"query { sessions { active(sessionKey: "sess1") { active } } }"#, + })) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + assert_eq!( + active_response["data"]["sessions"]["active"]["active"], + true + ); + assert_eq!(chat.calls(), vec!["send", "active"]); +}