Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions crates/gateway/src/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1480,20 +1480,24 @@ impl GatewayServices {
self
}

/// Create a [`Services`] bundle for sharing with the GraphQL schema.
/// Create a [`Services`] bundle with an injected `chat` and `system_info`.
///
/// Clones all service `Arc`s (cheap pointer bumps) into the shared bundle.
/// The `system_info` service is provided separately because it needs the
/// fully-constructed `GatewayState` which isn't available during
/// Clones all other service `Arc`s (cheap pointer bumps) into the shared
/// bundle. The `system_info` service is provided separately because it
/// needs the fully-constructed `GatewayState` which isn't available during
/// `GatewayServices` construction.
pub fn to_services(&self, system_info: Arc<dyn SystemInfoService>) -> Arc<Services> {
pub fn to_services_with_chat(
&self,
system_info: Arc<dyn SystemInfoService>,
chat: Arc<dyn ChatService>,
) -> Arc<Services> {
Arc::new(Services {
agent: self.agent.clone(),
session: self.session.clone(),
channel: self.channel.clone(),
config: self.config.clone(),
cron: self.cron.clone(),
chat: self.chat.clone(),
chat,
tts: self.tts.clone(),
stt: self.stt.clone(),
skills: self.skills.clone(),
Expand All @@ -1513,6 +1517,10 @@ impl GatewayServices {
system_info,
})
}

pub fn to_services(&self, system_info: Arc<dyn SystemInfoService>) -> Arc<Services> {
self.to_services_with_chat(system_info, self.chat.clone())
}
}

#[cfg(test)]
Expand Down
101 changes: 100 additions & 1 deletion crates/httpd/src/graphql_routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::sync::Arc;

use {
async_graphql::http::{GraphiQLPlugin, GraphiQLSource},
async_trait::async_trait,
axum::{
Json,
extract::{FromRequestParts, Request, State, WebSocketUpgrade},
Expand All @@ -17,7 +18,10 @@ use {
serde_json::Value,
};

use moltis_gateway::{services::ServiceResult, state::GatewayState};
use moltis_gateway::{
services::{ChatService, ServiceResult},
state::GatewayState,
};

use crate::server::AppState;

Expand All @@ -29,6 +33,101 @@ pub struct GatewaySystemInfoService {
pub state: Arc<GatewayState>,
}

/// GraphQL chat shim that resolves the live chat service at call time.
///
/// GraphQL schema construction happens before the late-bound chat service is
/// attached. Resolving through `GatewayState::chat()` keeps GraphQL aligned
/// with RPC/WebSocket behavior after the override is installed.
pub struct GraphqlChatServiceProxy {
pub state: Arc<GatewayState>,
}

#[async_trait]
impl ChatService for GraphqlChatServiceProxy {
async fn send(&self, params: Value) -> ServiceResult {
self.state.chat().await.send(params).await
}

async fn send_sync(&self, params: Value) -> ServiceResult {
self.state.chat().await.send_sync(params).await
}

async fn abort(&self, params: Value) -> ServiceResult {
self.state.chat().await.abort(params).await
}

async fn cancel_queued(&self, params: Value) -> ServiceResult {
self.state.chat().await.cancel_queued(params).await
}

async fn history(&self, params: Value) -> ServiceResult {
self.state.chat().await.history(params).await
}

async fn inject(&self, params: Value) -> ServiceResult {
self.state.chat().await.inject(params).await
}

async fn clear(&self, params: Value) -> ServiceResult {
self.state.chat().await.clear(params).await
}

async fn compact(&self, params: Value) -> ServiceResult {
self.state.chat().await.compact(params).await
}

async fn context(&self, params: Value) -> ServiceResult {
self.state.chat().await.context(params).await
}

async fn raw_prompt(&self, params: Value) -> ServiceResult {
self.state.chat().await.raw_prompt(params).await
}

async fn full_context(&self, params: Value) -> ServiceResult {
self.state.chat().await.full_context(params).await
}

async fn active(&self, params: Value) -> ServiceResult {
self.state.chat().await.active(params).await
}

async fn active_session_keys(&self) -> Vec<String> {
self.state.chat().await.active_session_keys().await
}

async fn active_thinking_text(&self, session_key: &str) -> Option<String> {
self.state
.chat()
.await
.active_thinking_text(session_key)
.await
}

async fn active_voice_pending(&self, session_key: &str) -> bool {
self.state
.chat()
.await
.active_voice_pending(session_key)
.await
}

async fn peek(&self, params: Value) -> ServiceResult {
self.state.chat().await.peek(params).await
}
}

pub fn build_graphql_schema(state: Arc<GatewayState>) -> moltis_graphql::MoltisSchema {
let system_info = Arc::new(GatewaySystemInfoService {
state: Arc::clone(&state),
});
let chat = Arc::new(GraphqlChatServiceProxy {
state: Arc::clone(&state),
});
let services = state.services.to_services_with_chat(system_info, chat);
moltis_graphql::build_schema(services, state.graphql_broadcast.clone())
}

#[async_trait::async_trait]
impl moltis_service_traits::SystemInfoService for GatewaySystemInfoService {
async fn health(&self) -> ServiceResult {
Expand Down
16 changes: 2 additions & 14 deletions crates/httpd/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,7 @@ pub fn build_gateway_base(
}

#[cfg(feature = "graphql")]
let graphql_schema = {
let system_info = Arc::new(crate::graphql_routes::GatewaySystemInfoService {
state: Arc::clone(&state),
});
let services = state.services.to_services(system_info);
moltis_graphql::build_schema(services, state.graphql_broadcast.clone())
};
let graphql_schema = crate::graphql_routes::build_graphql_schema(Arc::clone(&state));

let app_state = AppState {
gateway: state,
Expand Down Expand Up @@ -349,13 +343,7 @@ pub fn build_gateway_base(
}

#[cfg(feature = "graphql")]
let graphql_schema = {
let system_info = Arc::new(crate::graphql_routes::GatewaySystemInfoService {
state: Arc::clone(&state),
});
let services = state.services.to_services(system_info);
moltis_graphql::build_schema(services, state.graphql_broadcast.clone())
};
let graphql_schema = crate::graphql_routes::build_graphql_schema(Arc::clone(&state));

let app_state = AppState {
gateway: state,
Expand Down
162 changes: 162 additions & 0 deletions crates/httpd/tests/graphql_chat_binding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#![allow(clippy::unwrap_used, clippy::expect_used)]
#![cfg(feature = "graphql")]

use std::{
net::SocketAddr,
sync::{Arc, Mutex},
};

use {async_trait::async_trait, tokio::net::TcpListener};

use {
moltis_gateway::{
auth,
methods::MethodRegistry,
services::{ChatService, GatewayServices, ServiceResult},
state::GatewayState,
},
moltis_httpd::server::{build_gateway_base, finalize_gateway_app},
serde_json::{Value, json},
};

#[derive(Default)]
struct RecordingChatService {
calls: Mutex<Vec<String>>,
}

impl RecordingChatService {
fn record(&self, method: &str) {
self.calls
.lock()
.unwrap_or_else(|e| e.into_inner())
.push(method.to_string());
}

fn calls(&self) -> Vec<String> {
self.calls.lock().unwrap_or_else(|e| e.into_inner()).clone()
}
}

#[async_trait]
impl ChatService for RecordingChatService {
async fn send(&self, params: Value) -> ServiceResult {
self.record("send");
assert_eq!(params["message"], "Hello");
Ok(json!({ "ok": true }))
}

async fn abort(&self, _params: Value) -> ServiceResult {
Ok(json!({ "ok": true }))
}

async fn cancel_queued(&self, _params: Value) -> ServiceResult {
Ok(json!({ "cleared": 0 }))
}

async fn history(&self, _params: Value) -> ServiceResult {
Ok(json!([]))
}

async fn inject(&self, _params: Value) -> ServiceResult {
Ok(json!({ "ok": true }))
}

async fn clear(&self, _params: Value) -> ServiceResult {
Ok(json!({ "ok": true }))
}

async fn compact(&self, _params: Value) -> ServiceResult {
Ok(json!({ "ok": true }))
}

async fn context(&self, _params: Value) -> ServiceResult {
Ok(json!({}))
}

async fn raw_prompt(&self, _params: Value) -> ServiceResult {
Ok(json!({ "text": "prompt" }))
}

async fn full_context(&self, _params: Value) -> ServiceResult {
Ok(json!([]))
}

async fn active(&self, params: Value) -> ServiceResult {
self.record("active");
assert_eq!(params["sessionKey"], "sess1");
Ok(json!({ "active": true }))
}
}

async fn start_graphql_server() -> (SocketAddr, Arc<GatewayState>, tempfile::TempDir) {
let tmp = tempfile::tempdir().unwrap();
moltis_config::set_config_dir(tmp.path().to_path_buf());
moltis_config::set_data_dir(tmp.path().to_path_buf());

let state = GatewayState::new(auth::resolve_auth(None, None), GatewayServices::noop());
let state_clone = Arc::clone(&state);
let methods = Arc::new(MethodRegistry::new());

#[cfg(feature = "push-notifications")]
let (router, app_state) = build_gateway_base(state, methods, None, None);
#[cfg(not(feature = "push-notifications"))]
let (router, app_state) = build_gateway_base(state, methods, None);

let app = finalize_gateway_app(router, app_state, false);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.unwrap();
});

(addr, state_clone, tmp)
}

#[tokio::test]
async fn graphql_chat_uses_late_bound_override_after_schema_build() {
let (addr, state, _tmp) = start_graphql_server().await;

let chat = Arc::new(RecordingChatService::default());
state
.set_chat(Arc::clone(&chat) as Arc<dyn ChatService>)
.await;

let client = reqwest::Client::new();

let send_response: Value = client
.post(format!("http://{addr}/graphql"))
.json(&json!({
"query": r#"mutation { chat { send(message: "Hello") { ok } } }"#,
}))
.send()
.await
.unwrap()
.json()
.await
.unwrap();

assert_eq!(send_response["data"]["chat"]["send"]["ok"], true);

let active_response: Value = client
.post(format!("http://{addr}/graphql"))
.json(&json!({
"query": r#"query { sessions { active(sessionKey: "sess1") { active } } }"#,
}))
.send()
.await
.unwrap()
.json()
.await
.unwrap();

assert_eq!(
active_response["data"]["sessions"]["active"]["active"],
true
);
assert_eq!(chat.calls(), vec!["send", "active"]);
}
Loading