Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,9 @@ async fn handle_tunnel_message_mk2(
authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>,
msg: protocol::mk2::ToServerTunnelMessage,
) -> Result<()> {
let route = (msg.message_id.gateway_id, msg.message_id.request_id);
let clear_route = should_clear_tunnel_route_mk2(&msg.message_kind);

// Extract inner data length before consuming msg
let inner_data_len = tunnel_message_inner_data_len_mk2(&msg.message_kind);

Expand All @@ -868,10 +871,7 @@ async fn handle_tunnel_message_mk2(
return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build());
}

if !authorized_tunnel_routes
.contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id))
.await
{
if !authorized_tunnel_routes.contains_async(&route).await {
return Err(
errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(),
);
Expand Down Expand Up @@ -899,6 +899,10 @@ async fn handle_tunnel_message_mk2(
)
})?;

if clear_route {
authorized_tunnel_routes.remove_async(&route).await;
}

Ok(())
}

Expand All @@ -909,6 +913,9 @@ async fn handle_tunnel_message_mk1(
authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>,
msg: protocol::ToServerTunnelMessage,
) -> Result<()> {
let route = (msg.message_id.gateway_id, msg.message_id.request_id);
let clear_route = should_clear_tunnel_route_mk1(&msg.message_kind);

// Ignore DeprecatedTunnelAck messages (used only for backwards compatibility)
if matches!(
msg.message_kind,
Expand All @@ -925,10 +932,7 @@ async fn handle_tunnel_message_mk1(
return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build());
}

if !authorized_tunnel_routes
.contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id))
.await
{
if !authorized_tunnel_routes.contains_async(&route).await {
return Err(
errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(),
);
Expand All @@ -950,9 +954,35 @@ async fn handle_tunnel_message_mk1(
)
})?;

if clear_route {
authorized_tunnel_routes.remove_async(&route).await;
}

Ok(())
}

fn should_clear_tunnel_route_mk2(msg_kind: &protocol::mk2::ToServerTunnelMessageKind) -> bool {
match msg_kind {
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseStart(response) => {
!response.stream
}
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.finish,
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort
| protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketClose(_) => true,
_ => false,
}
}

fn should_clear_tunnel_route_mk1(msg_kind: &protocol::ToServerTunnelMessageKind) -> bool {
match msg_kind {
protocol::ToServerTunnelMessageKind::ToServerResponseStart(response) => !response.stream,
protocol::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.finish,
protocol::ToServerTunnelMessageKind::ToServerResponseAbort
| protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(_) => true,
_ => false,
}
}

/// Returns the length of the inner data payload for a tunnel message kind.
fn tunnel_message_inner_data_len_mk2(kind: &protocol::mk2::ToServerTunnelMessageKind) -> usize {
use protocol::mk2::ToServerTunnelMessageKind;
Expand Down
150 changes: 148 additions & 2 deletions engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,74 @@ fn response_abort_message_mk2(
}
}

fn response_start_message_mk2(
gateway_id: protocol::mk2::GatewayId,
request_id: protocol::mk2::RequestId,
) -> protocol::mk2::ToServerTunnelMessage {
response_start_message_mk2_with_stream(gateway_id, request_id, false)
}

fn response_start_message_mk2_with_stream(
gateway_id: protocol::mk2::GatewayId,
request_id: protocol::mk2::RequestId,
stream: bool,
) -> protocol::mk2::ToServerTunnelMessage {
protocol::mk2::ToServerTunnelMessage {
message_id: protocol::mk2::MessageId {
gateway_id,
request_id,
message_index: 0,
},
message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerResponseStart(
protocol::mk2::ToServerResponseStart {
status: 200,
headers: Default::default(),
body: None,
stream,
},
),
}
}

fn response_chunk_message_mk2(
gateway_id: protocol::mk2::GatewayId,
request_id: protocol::mk2::RequestId,
finish: bool,
) -> protocol::mk2::ToServerTunnelMessage {
protocol::mk2::ToServerTunnelMessage {
message_id: protocol::mk2::MessageId {
gateway_id,
request_id,
message_index: 0,
},
message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerResponseChunk(
protocol::mk2::ToServerResponseChunk {
body: b"chunk".to_vec(),
finish,
},
),
}
}

fn websocket_message_mk2(
gateway_id: protocol::mk2::GatewayId,
request_id: protocol::mk2::RequestId,
) -> protocol::mk2::ToServerTunnelMessage {
protocol::mk2::ToServerTunnelMessage {
message_id: protocol::mk2::MessageId {
gateway_id,
request_id,
message_index: 0,
},
message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage(
protocol::mk2::ToServerWebSocketMessage {
data: b"ping".to_vec(),
binary: false,
},
),
}
}

fn response_abort_message_mk1(
gateway_id: protocol::mk2::GatewayId,
request_id: protocol::mk2::RequestId,
Expand All @@ -39,6 +107,74 @@ fn response_abort_message_mk1(
}
}

fn websocket_message_mk1(
gateway_id: protocol::mk2::GatewayId,
request_id: protocol::mk2::RequestId,
) -> protocol::ToServerTunnelMessage {
protocol::ToServerTunnelMessage {
message_id: protocol::MessageId {
gateway_id,
request_id,
message_index: 0,
},
message_kind: protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage(
protocol::ToServerWebSocketMessage {
data: b"ping".to_vec(),
binary: false,
},
),
}
}

fn response_start_message_mk1(
gateway_id: protocol::mk2::GatewayId,
request_id: protocol::mk2::RequestId,
) -> protocol::ToServerTunnelMessage {
response_start_message_mk1_with_stream(gateway_id, request_id, false)
}

fn response_start_message_mk1_with_stream(
gateway_id: protocol::mk2::GatewayId,
request_id: protocol::mk2::RequestId,
stream: bool,
) -> protocol::ToServerTunnelMessage {
protocol::ToServerTunnelMessage {
message_id: protocol::MessageId {
gateway_id,
request_id,
message_index: 0,
},
message_kind: protocol::ToServerTunnelMessageKind::ToServerResponseStart(
protocol::ToServerResponseStart {
status: 200,
headers: Default::default(),
body: None,
stream,
},
),
}
}

fn response_chunk_message_mk1(
gateway_id: protocol::mk2::GatewayId,
request_id: protocol::mk2::RequestId,
finish: bool,
) -> protocol::ToServerTunnelMessage {
protocol::ToServerTunnelMessage {
message_id: protocol::MessageId {
gateway_id,
request_id,
message_index: 0,
},
message_kind: protocol::ToServerTunnelMessageKind::ToServerResponseChunk(
protocol::ToServerResponseChunk {
body: b"chunk".to_vec(),
finish,
},
),
}
}

#[tokio::test]
async fn rejects_unissued_mk2_tunnel_message_pairs() {
let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-reject-mk2");
Expand Down Expand Up @@ -82,7 +218,7 @@ async fn republishes_issued_mk2_tunnel_message_pairs() {
&pubsub,
1024,
&authorized_tunnel_routes,
response_abort_message_mk2(gateway_id, request_id),
websocket_message_mk2(gateway_id, request_id),
)
.await
.unwrap();
Expand All @@ -92,6 +228,11 @@ async fn republishes_issued_mk2_tunnel_message_pairs() {
.unwrap()
.unwrap();
assert!(matches!(msg, NextOutput::Message(_)));
assert!(
authorized_tunnel_routes
.contains_async(&(gateway_id, request_id))
.await
);
}

#[tokio::test]
Expand Down Expand Up @@ -137,7 +278,7 @@ async fn republishes_issued_mk1_tunnel_message_pairs() {
&pubsub,
1024,
&authorized_tunnel_routes,
response_abort_message_mk1(gateway_id, request_id),
websocket_message_mk1(gateway_id, request_id),
)
.await
.unwrap();
Expand All @@ -147,4 +288,9 @@ async fn republishes_issued_mk1_tunnel_message_pairs() {
.unwrap()
.unwrap();
assert!(matches!(msg, NextOutput::Message(_)));
assert!(
authorized_tunnel_routes
.contains_async(&(gateway_id, request_id))
.await
);
}
Loading