diff --git a/engine/packages/pegboard-gateway/src/lib.rs b/engine/packages/pegboard-gateway/src/lib.rs index a340811ca2..6e0c073d44 100644 --- a/engine/packages/pegboard-gateway/src/lib.rs +++ b/engine/packages/pegboard-gateway/src/lib.rs @@ -27,7 +27,7 @@ use tokio_tungstenite::tungstenite::{ }; use universaldb::utils::IsolationLevel::*; -use crate::shared_state::{InFlightRequestHandle, SharedState}; +use crate::shared_state::{InFlightRequestHandle, InFlightRequestState, SharedState}; mod keepalive_task; mod metrics; @@ -178,7 +178,12 @@ impl PegboardGateway { .. } = self .shared_state - .start_in_flight_request(tunnel_subject, runner_protocol_version, request_id) + .start_in_flight_request( + tunnel_subject, + runner_protocol_version, + request_id, + InFlightRequestState::AwaitingHttpResponseStart, + ) .await; // Start request @@ -304,7 +309,16 @@ impl PegboardGateway { new, } = self .shared_state - .start_in_flight_request(tunnel_subject.clone(), runner_protocol_version, request_id) + .start_in_flight_request( + tunnel_subject.clone(), + runner_protocol_version, + request_id, + if after_hibernation { + InFlightRequestState::ActiveWebSocket + } else { + InFlightRequestState::AwaitingWebSocketOpen + }, + ) .await; ensure!( diff --git a/engine/packages/pegboard-gateway/src/shared_state.rs b/engine/packages/pegboard-gateway/src/shared_state.rs index f2ca560d93..d504f74961 100644 --- a/engine/packages/pegboard-gateway/src/shared_state.rs +++ b/engine/packages/pegboard-gateway/src/shared_state.rs @@ -26,10 +26,60 @@ pub struct InFlightRequestHandle { pub new: bool, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InFlightRequestState { + AwaitingHttpResponseStart, + AwaitingWebSocketOpen, + ActiveWebSocket, + Closed, +} + +impl InFlightRequestState { + fn accept_message(&mut self, message_kind: &protocol::mk2::ToServerTunnelMessageKind) -> bool { + use protocol::mk2::ToServerTunnelMessageKind; + + match (self, message_kind) { + ( + state @ InFlightRequestState::AwaitingHttpResponseStart, + ToServerTunnelMessageKind::ToServerResponseStart(_) + | ToServerTunnelMessageKind::ToServerResponseAbort, + ) => { + *state = InFlightRequestState::Closed; + true + } + ( + state @ InFlightRequestState::AwaitingWebSocketOpen, + ToServerTunnelMessageKind::ToServerWebSocketOpen(_), + ) => { + *state = InFlightRequestState::ActiveWebSocket; + true + } + ( + state @ InFlightRequestState::AwaitingWebSocketOpen, + ToServerTunnelMessageKind::ToServerWebSocketClose(_), + ) + | ( + state @ InFlightRequestState::ActiveWebSocket, + ToServerTunnelMessageKind::ToServerWebSocketClose(_), + ) => { + *state = InFlightRequestState::Closed; + true + } + ( + InFlightRequestState::ActiveWebSocket, + ToServerTunnelMessageKind::ToServerWebSocketMessage(_) + | ToServerTunnelMessageKind::ToServerWebSocketMessageAck(_), + ) => true, + _ => false, + } + } +} + struct InFlightRequest { /// UPS subject to send messages to for this request. receiver_subject: String, protocol_version: u16, + state: InFlightRequestState, /// Sender for incoming messages to this request. msg_tx: mpsc::Sender, /// Used to check if the request handler has been dropped. @@ -134,6 +184,7 @@ impl SharedState { receiver_subject: String, protocol_version: u16, request_id: protocol::mk2::RequestId, + state: InFlightRequestState, ) -> InFlightRequestHandle { let (msg_tx, msg_rx) = mpsc::channel(128); let (drop_tx, drop_rx) = watch::channel(None); @@ -143,6 +194,7 @@ impl SharedState { entry.insert_entry(InFlightRequest { receiver_subject, protocol_version, + state, msg_tx, drop_tx, opened: false, @@ -159,6 +211,7 @@ impl SharedState { entry.receiver_subject = receiver_subject; entry.msg_tx = msg_tx; entry.drop_tx = drop_tx; + entry.state = state; entry.opened = false; entry.last_pong = util::timestamp::now(); @@ -355,7 +408,7 @@ impl SharedState { Ok(protocol::mk2::ToGateway::ToServerTunnelMessage(msg)) => { let message_id = msg.message_id; - let Some(in_flight) = self + let Some(mut in_flight) = self .in_flight_requests .get_async(&message_id.request_id) .await @@ -369,6 +422,18 @@ impl SharedState { continue; }; + if !in_flight.state.accept_message(&msg.message_kind) { + tracing::warn!( + gateway_id=%protocol::util::id_to_string(&message_id.gateway_id), + request_id=%protocol::util::id_to_string(&message_id.request_id), + message_index=message_id.message_index, + state=?in_flight.state, + message_kind=?msg.message_kind, + "dropping invalid tunnel message for request state" + ); + continue; + } + // Send message to the request handler to emulate the real network action let inner_size = match &msg.message_kind { protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage( @@ -619,6 +684,65 @@ fn wrapping_gt(a: u16, b: u16) -> bool { a != b && a.wrapping_sub(b) < u16::MAX / 2 } +#[cfg(test)] +mod tests { + use super::InFlightRequestState; + use rivet_runner_protocol as protocol; + + #[test] + fn http_requests_only_accept_http_terminal_messages() { + let mut state = InFlightRequestState::AwaitingHttpResponseStart; + assert!( + state.accept_message(&protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort,) + ); + assert_eq!(state, InFlightRequestState::Closed); + + let mut state = InFlightRequestState::AwaitingHttpResponseStart; + assert!(!state.accept_message( + &protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage( + protocol::mk2::ToServerWebSocketMessage { + data: Vec::new(), + binary: false, + }, + ), + )); + assert_eq!(state, InFlightRequestState::AwaitingHttpResponseStart); + } + + #[test] + fn websockets_must_open_before_streaming() { + let mut state = InFlightRequestState::AwaitingWebSocketOpen; + assert!(!state.accept_message( + &protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage( + protocol::mk2::ToServerWebSocketMessage { + data: Vec::new(), + binary: false, + }, + ), + )); + assert_eq!(state, InFlightRequestState::AwaitingWebSocketOpen); + + assert!(state.accept_message( + &protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketOpen( + protocol::mk2::ToServerWebSocketOpen { + can_hibernate: false, + }, + ), + )); + assert_eq!(state, InFlightRequestState::ActiveWebSocket); + } + + #[test] + fn active_websockets_reject_http_messages() { + let mut state = InFlightRequestState::ActiveWebSocket; + assert!( + !state + .accept_message(&protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort,) + ); + assert_eq!(state, InFlightRequestState::ActiveWebSocket); + } +} + // fn wrapping_lt(a: u16, b: u16) -> bool { // b.wrapping_sub(a) < u16::MAX / 2 // }