diff --git a/src/main.rs b/src/main.rs index 41f4e73e0..88c63c801 100644 --- a/src/main.rs +++ b/src/main.rs @@ -183,6 +183,95 @@ struct ActiveChannel { _outbound_handle: tokio::task::JoinHandle<()>, } +/// Key to uniquely identify a channel by (agent_id, conversation_id) pair. +/// This prevents message leakage between agents when channels share the same +/// conversation_id (e.g., DMs that may have the same ID across different agents). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ActiveChannelKey { + agent_id: String, + conversation_id: String, +} + +impl ActiveChannelKey { + fn new(agent_id: impl AsRef, conversation_id: impl Into) -> Self { + Self { + agent_id: agent_id.as_ref().to_string(), + conversation_id: conversation_id.into(), + } + } +} + +/// A deferred message waiting for its target channel to become active. +/// This ensures messages are delivered to their intended target only. +#[derive(Debug, Clone)] +struct DeferredMessage { + /// The exact key of the target channel where this message must be delivered. + target_key: ActiveChannelKey, + /// The message to deliver. + message: spacebot::InboundMessage, + /// When the message was deferred. + deferred_at: chrono::DateTime, +} + +/// Queue for deferred messages that will be delivered when their target +/// channel becomes active. Each message is bound to a specific +/// (agent_id, conversation_id) pair. +#[derive(Debug, Default)] +struct DeferredMessageQueue { + messages: Vec, +} + +impl DeferredMessageQueue { + fn new() -> Self { + Self { + messages: Vec::new(), + } + } + + /// Add a message to the queue bound to a specific target channel. + fn push(&mut self, target_key: ActiveChannelKey, message: spacebot::InboundMessage) { + self.messages.push(DeferredMessage { + target_key, + message, + deferred_at: chrono::Utc::now(), + }); + } + + /// Drain and return all messages intended for the given channel key. + fn drain_for(&mut self, key: &ActiveChannelKey) -> Vec { + let existing = std::mem::take(&mut self.messages); + let mut drained = Vec::new(); + let mut kept = Vec::with_capacity(existing.len()); + + for deferred in existing { + if &deferred.target_key == key { + drained.push(deferred.message); + } else { + kept.push(deferred); + } + } + + self.messages = kept; + drained + } + + /// Check if there are any deferred messages for a specific channel. + fn has_for(&self, key: &ActiveChannelKey) -> bool { + self.messages.iter().any(|m| &m.target_key == key) + } + + /// Get count of deferred messages. + fn len(&self) -> usize { + self.messages.len() + } + + /// Remove messages older than the given duration. + fn remove_expired(&mut self, max_age: chrono::Duration) { + let now = chrono::Utc::now(); + self.messages.retain(|m| now - m.deferred_at < max_age); + } +} + #[derive(Debug, serde::Serialize)] struct BackfillTranscriptEntry { role: String, @@ -1821,8 +1910,16 @@ async fn run( tracing::info!(pid = std::process::id(), "spacebot daemon started"); } - // Active conversation channels: conversation_id -> ActiveChannel - let mut active_channels: HashMap = HashMap::new(); + // Active conversation channels: (agent_id, conversation_id) -> ActiveChannel + // Uses ActiveChannelKey to ensure channels are uniquely identified by the + // exact (agent_id, conversation_id) pair, preventing message leakage between + // agents when channels share conversation_ids (e.g., DMs). + let mut active_channels: HashMap = HashMap::new(); + + // Queue for deferred messages when their target channel is not active. + // Messages are bound to their original target (agent_id, conversation_id) + // and will only be delivered to that exact channel when it becomes active. + let mut deferred_messages = DeferredMessageQueue::new(); // Resume idle interactive workers that survived the restart. // For each idle worker, pre-create the channel if needed and spawn @@ -1874,7 +1971,8 @@ async fn run( for (conversation_id, workers) in by_channel { // Ensure the channel exists. If it's already in active_channels // (unlikely at startup), use its state. Otherwise, pre-create it. - if !active_channels.contains_key(&conversation_id) { + let channel_key = ActiveChannelKey::new(agent_id.clone(), conversation_id.clone()); + if !active_channels.contains_key(&channel_key) { // First pass: retire any workers whose sessions can't be // reconnected. Only create the channel if at least one // worker has a chance of resuming. @@ -2106,14 +2204,40 @@ async fn run( } }); + let channel_key = + ActiveChannelKey::new(agent_id.clone(), conversation_id.clone()); active_channels.insert( - conversation_id.clone(), + channel_key.clone(), ActiveChannel { message_tx: channel_tx, _outbound_handle: outbound_handle, }, ); + // Deliver any deferred messages that were waiting for this channel + let deferred = deferred_messages.drain_for(&channel_key); + let deferred_count = deferred.len(); + if deferred_count > 0 { + if let Some(channel) = active_channels.get(&channel_key) { + for message in deferred { + if let Err(error) = channel.message_tx.send(message).await { + tracing::warn!( + %error, + conversation_id = %conversation_id, + agent_id = %agent_id, + "failed to deliver deferred message" + ); + } + } + tracing::info!( + conversation_id = %conversation_id, + agent_id = %agent_id, + count = deferred_count, + "delivered deferred messages to newly active channel" + ); + } + } + tracing::info!( conversation_id = %conversation_id, agent_id = %agent_id, @@ -2155,9 +2279,10 @@ async fn run( }; let conversation_id = message.conversation_id.clone(); + let channel_key = ActiveChannelKey::new(agent_id.clone(), conversation_id.clone()); // Find or create a channel for this conversation - if !active_channels.contains_key(&conversation_id) { + if !active_channels.contains_key(&channel_key) { let Some(agent) = agents.get(&agent_id) else { tracing::warn!( agent_id = %agent_id, @@ -2349,10 +2474,38 @@ async fn run( ); }); - active_channels.insert(conversation_id.clone(), ActiveChannel { - message_tx: channel_tx, - _outbound_handle: outbound_handle, - }); + let channel_key = ActiveChannelKey::new(agent_id.clone(), conversation_id.clone()); + active_channels.insert( + channel_key.clone(), + ActiveChannel { + message_tx: channel_tx, + _outbound_handle: outbound_handle, + }, + ); + + // Deliver any deferred messages that were waiting for this channel + let deferred = deferred_messages.drain_for(&channel_key); + let deferred_count = deferred.len(); + if deferred_count > 0 { + if let Some(channel) = active_channels.get(&channel_key) { + for message in deferred { + if let Err(error) = channel.message_tx.send(message).await { + tracing::warn!( + %error, + conversation_id = %conversation_id, + agent_id = %agent_id, + "failed to deliver deferred message" + ); + } + } + tracing::info!( + conversation_id = %conversation_id, + agent_id = %agent_id, + count = deferred_count, + "delivered deferred messages to newly active channel" + ); + } + } tracing::info!( conversation_id = %conversation_id, @@ -2362,7 +2515,7 @@ async fn run( } // Forward the message to the channel - if let Some(active) = active_channels.get(&conversation_id) { + if let Some(active) = active_channels.get(&channel_key) { // Emit inbound message to SSE clients let sender_name = message.formatted_author.clone().or_else(|| { message @@ -2385,7 +2538,7 @@ async fn run( %error, "failed to forward message to channel" ); - active_channels.remove(&conversation_id); + active_channels.remove(&channel_key); } } } @@ -2404,8 +2557,15 @@ async fn run( } // Cross-agent message injection (e.g. delegated task completion retrigger). // Forwards the injected message to the target channel if it exists. + // SECURITY FIX: Uses exact (agent_id, conversation_id) key to prevent + // message leakage to unintended channels. Deferred messages are queued + // and only delivered when the exact target channel becomes active. Some(injection) = injection_rx.recv() => { - if let Some(active) = active_channels.get(&injection.conversation_id) { + let target_key = ActiveChannelKey::new( + injection.agent_id.clone(), + injection.conversation_id.clone(), + ); + if let Some(active) = active_channels.get(&target_key) { if let Err(error) = active.message_tx.send(injection.message).await { tracing::warn!( %error, @@ -2421,10 +2581,15 @@ async fn run( ); } } else { + // SECURITY FIX: Queue the message for the exact target channel + // instead of delivering it to any active channel. This prevents + // cron output from leaking to unintended channels. + deferred_messages.push(target_key, injection.message); + deferred_messages.remove_expired(chrono::Duration::hours(24)); tracing::info!( conversation_id = %injection.conversation_id, agent_id = %injection.agent_id, - "injection target channel not active, notification will be delivered on next message" + "injection target channel not active, message queued for exact target" ); } } @@ -3695,7 +3860,7 @@ async fn initialize_agents( #[cfg(test)] mod tests { - use super::wait_for_startup_warmup_tasks; + use super::{ActiveChannelKey, DeferredMessageQueue, wait_for_startup_warmup_tasks}; use std::future::pending; use std::sync::Arc; use std::time::Duration; @@ -3758,4 +3923,105 @@ mod tests { "startup warmup timeout should return without waiting for non-cooperative task" ); } + + // ============================================================================ + // SECURITY REGRESSION TEST FOR ISSUE #498 + // Tests that deferred messages stay bound to their original target channel + // and are NOT delivered to unrelated active channels. + // ============================================================================ + + #[test] + fn active_channel_key_uniquely_identifies_agent_conversation_pairs() { + let key1 = ActiveChannelKey::new("agent1", "conv123"); + let key2 = ActiveChannelKey::new("agent1", "conv123"); + let key3 = ActiveChannelKey::new("agent2", "conv123"); + let key4 = ActiveChannelKey::new("agent1", "conv456"); + + // Same (agent_id, conversation_id) pairs are equal + assert_eq!(key1, key2); + + // Different agent_ids with same conversation_id are NOT equal + // (this is the core security fix - prevents cross-agent leakage) + assert_ne!(key1, key3); + + // Same agent_id with different conversation_ids are NOT equal + assert_ne!(key1, key4); + } + + #[test] + fn deferred_message_queue_binds_messages_to_target_key() { + let mut queue = DeferredMessageQueue::new(); + let target_key = ActiveChannelKey::new("agent1", "dm_channel_123"); + + // Create a test message + let message = spacebot::InboundMessage { + id: "test-msg-1".to_string(), + source: "test".into(), + adapter: None, + conversation_id: "dm_channel_123".to_string(), + sender_id: "system".into(), + agent_id: Some("agent1".into()), + content: spacebot::MessageContent::Text("test message".to_string()), + timestamp: chrono::Utc::now(), + metadata: std::collections::HashMap::new(), + formatted_author: None, + }; + + // Queue a message for agent1/dm_channel_123 + queue.push(target_key.clone(), message.clone()); + + // Verify message exists for the exact target + assert!(queue.has_for(&target_key)); + assert_eq!(queue.len(), 1); + + // Verify message does NOT exist for different agent with same conversation_id + let other_agent_key = ActiveChannelKey::new("agent2", "dm_channel_123"); + assert!(!queue.has_for(&other_agent_key)); + + // Verify message does NOT exist for same agent with different conversation_id + let other_conv_key = ActiveChannelKey::new("agent1", "public_channel_456"); + assert!(!queue.has_for(&other_conv_key)); + + // Drain for the correct key returns the message + let drained = queue.drain_for(&target_key); + assert_eq!(drained.len(), 1); + assert_eq!(drained[0].id, "test-msg-1"); + assert_eq!(queue.len(), 0); + + // Draining for wrong key returns empty + queue.push(target_key.clone(), message.clone()); + let wrong_drain = queue.drain_for(&other_agent_key); + assert!(wrong_drain.is_empty()); + assert_eq!(queue.len(), 1); // Message still in queue + } + + #[test] + fn deferred_message_queue_remove_expired_works() { + let mut queue = DeferredMessageQueue::new(); + let target_key = ActiveChannelKey::new("agent1", "dm_channel_123"); + + let message = spacebot::InboundMessage { + id: "test-msg-1".to_string(), + source: "test".into(), + adapter: None, + conversation_id: "dm_channel_123".to_string(), + sender_id: "system".into(), + agent_id: Some("agent1".into()), + content: spacebot::MessageContent::Text("test message".to_string()), + timestamp: chrono::Utc::now(), + metadata: std::collections::HashMap::new(), + formatted_author: None, + }; + + queue.push(target_key.clone(), message); + assert_eq!(queue.len(), 1); + + // Very short expiration should not remove fresh messages + queue.remove_expired(chrono::Duration::seconds(60)); + assert_eq!(queue.len(), 1); + + // Zero duration should remove all messages (they're at least 0 nanoseconds old) + queue.remove_expired(chrono::Duration::seconds(0)); + assert_eq!(queue.len(), 0); + } }