Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions engine/packages/pegboard-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use tokio_tungstenite::tungstenite::{
};
use universaldb::utils::IsolationLevel::*;

use crate::shared_state::{InFlightRequestHandle, SharedState};
use crate::shared_state::{InFlightRequestHandle, InFlightRequestState, SharedState};

mod keepalive_task;
mod metrics;
Expand Down Expand Up @@ -178,7 +178,12 @@ impl PegboardGateway {
..
} = self
.shared_state
.start_in_flight_request(tunnel_subject, runner_protocol_version, request_id)
.start_in_flight_request(
tunnel_subject,
runner_protocol_version,
request_id,
InFlightRequestState::AwaitingHttpResponseStart,
)
.await;

// Start request
Expand Down Expand Up @@ -304,7 +309,16 @@ impl PegboardGateway {
new,
} = self
.shared_state
.start_in_flight_request(tunnel_subject.clone(), runner_protocol_version, request_id)
.start_in_flight_request(
tunnel_subject.clone(),
runner_protocol_version,
request_id,
if after_hibernation {
InFlightRequestState::ActiveWebSocket
} else {
InFlightRequestState::AwaitingWebSocketOpen
},
)
.await;

ensure!(
Expand Down
126 changes: 125 additions & 1 deletion engine/packages/pegboard-gateway/src/shared_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,60 @@ pub struct InFlightRequestHandle {
pub new: bool,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InFlightRequestState {
AwaitingHttpResponseStart,
AwaitingWebSocketOpen,
ActiveWebSocket,
Closed,
}

impl InFlightRequestState {
fn accept_message(&mut self, message_kind: &protocol::mk2::ToServerTunnelMessageKind) -> bool {
use protocol::mk2::ToServerTunnelMessageKind;

match (self, message_kind) {
(
state @ InFlightRequestState::AwaitingHttpResponseStart,
ToServerTunnelMessageKind::ToServerResponseStart(_)
| ToServerTunnelMessageKind::ToServerResponseAbort,
) => {
*state = InFlightRequestState::Closed;
true
}
(
state @ InFlightRequestState::AwaitingWebSocketOpen,
ToServerTunnelMessageKind::ToServerWebSocketOpen(_),
) => {
*state = InFlightRequestState::ActiveWebSocket;
true
}
(
state @ InFlightRequestState::AwaitingWebSocketOpen,
ToServerTunnelMessageKind::ToServerWebSocketClose(_),
)
| (
state @ InFlightRequestState::ActiveWebSocket,
ToServerTunnelMessageKind::ToServerWebSocketClose(_),
) => {
*state = InFlightRequestState::Closed;
true
}
(
InFlightRequestState::ActiveWebSocket,
ToServerTunnelMessageKind::ToServerWebSocketMessage(_)
| ToServerTunnelMessageKind::ToServerWebSocketMessageAck(_),
) => true,
_ => false,
}
}
}

struct InFlightRequest {
/// UPS subject to send messages to for this request.
receiver_subject: String,
protocol_version: u16,
state: InFlightRequestState,
/// Sender for incoming messages to this request.
msg_tx: mpsc::Sender<protocol::mk2::ToServerTunnelMessageKind>,
/// Used to check if the request handler has been dropped.
Expand Down Expand Up @@ -134,6 +184,7 @@ impl SharedState {
receiver_subject: String,
protocol_version: u16,
request_id: protocol::mk2::RequestId,
state: InFlightRequestState,
) -> InFlightRequestHandle {
let (msg_tx, msg_rx) = mpsc::channel(128);
let (drop_tx, drop_rx) = watch::channel(None);
Expand All @@ -143,6 +194,7 @@ impl SharedState {
entry.insert_entry(InFlightRequest {
receiver_subject,
protocol_version,
state,
msg_tx,
drop_tx,
opened: false,
Expand All @@ -159,6 +211,7 @@ impl SharedState {
entry.receiver_subject = receiver_subject;
entry.msg_tx = msg_tx;
entry.drop_tx = drop_tx;
entry.state = state;
entry.opened = false;
entry.last_pong = util::timestamp::now();

Expand Down Expand Up @@ -355,7 +408,7 @@ impl SharedState {
Ok(protocol::mk2::ToGateway::ToServerTunnelMessage(msg)) => {
let message_id = msg.message_id;

let Some(in_flight) = self
let Some(mut in_flight) = self
.in_flight_requests
.get_async(&message_id.request_id)
.await
Expand All @@ -369,6 +422,18 @@ impl SharedState {
continue;
};

if !in_flight.state.accept_message(&msg.message_kind) {
tracing::warn!(
gateway_id=%protocol::util::id_to_string(&message_id.gateway_id),
request_id=%protocol::util::id_to_string(&message_id.request_id),
message_index=message_id.message_index,
state=?in_flight.state,
message_kind=?msg.message_kind,
"dropping invalid tunnel message for request state"
);
continue;
}

// Send message to the request handler to emulate the real network action
let inner_size = match &msg.message_kind {
protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage(
Expand Down Expand Up @@ -619,6 +684,65 @@ fn wrapping_gt(a: u16, b: u16) -> bool {
a != b && a.wrapping_sub(b) < u16::MAX / 2
}

#[cfg(test)]
mod tests {
use super::InFlightRequestState;
use rivet_runner_protocol as protocol;

#[test]
fn http_requests_only_accept_http_terminal_messages() {
let mut state = InFlightRequestState::AwaitingHttpResponseStart;
assert!(
state.accept_message(&protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort,)
);
assert_eq!(state, InFlightRequestState::Closed);

let mut state = InFlightRequestState::AwaitingHttpResponseStart;
assert!(!state.accept_message(
&protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage(
protocol::mk2::ToServerWebSocketMessage {
data: Vec::new(),
binary: false,
},
),
));
assert_eq!(state, InFlightRequestState::AwaitingHttpResponseStart);
}

#[test]
fn websockets_must_open_before_streaming() {
let mut state = InFlightRequestState::AwaitingWebSocketOpen;
assert!(!state.accept_message(
&protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage(
protocol::mk2::ToServerWebSocketMessage {
data: Vec::new(),
binary: false,
},
),
));
assert_eq!(state, InFlightRequestState::AwaitingWebSocketOpen);

assert!(state.accept_message(
&protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketOpen(
protocol::mk2::ToServerWebSocketOpen {
can_hibernate: false,
},
),
));
assert_eq!(state, InFlightRequestState::ActiveWebSocket);
}

#[test]
fn active_websockets_reject_http_messages() {
let mut state = InFlightRequestState::ActiveWebSocket;
assert!(
!state
.accept_message(&protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort,)
);
assert_eq!(state, InFlightRequestState::ActiveWebSocket);
}
}

// fn wrapping_lt(a: u16, b: u16) -> bool {
// b.wrapping_sub(a) < u16::MAX / 2
// }
Loading