Skip to content

Commit eeba1f7

Browse files
committed
fix(pegboard-runner): clear terminal tunnel routes
1 parent 0aff89a commit eeba1f7

File tree

2 files changed

+446
-10
lines changed

2 files changed

+446
-10
lines changed

engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,9 @@ async fn handle_tunnel_message_mk2(
860860
authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>,
861861
msg: protocol::mk2::ToServerTunnelMessage,
862862
) -> Result<()> {
863+
let route = (msg.message_id.gateway_id, msg.message_id.request_id);
864+
let clear_route = should_clear_tunnel_route_mk2(&msg.message_kind);
865+
863866
// Extract inner data length before consuming msg
864867
let inner_data_len = tunnel_message_inner_data_len_mk2(&msg.message_kind);
865868

@@ -868,10 +871,7 @@ async fn handle_tunnel_message_mk2(
868871
return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build());
869872
}
870873

871-
if !authorized_tunnel_routes
872-
.contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id))
873-
.await
874-
{
874+
if !authorized_tunnel_routes.contains_async(&route).await {
875875
return Err(
876876
errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(),
877877
);
@@ -899,6 +899,10 @@ async fn handle_tunnel_message_mk2(
899899
)
900900
})?;
901901

902+
if clear_route {
903+
authorized_tunnel_routes.remove_async(&route).await;
904+
}
905+
902906
Ok(())
903907
}
904908

@@ -909,6 +913,9 @@ async fn handle_tunnel_message_mk1(
909913
authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>,
910914
msg: protocol::ToServerTunnelMessage,
911915
) -> Result<()> {
916+
let route = (msg.message_id.gateway_id, msg.message_id.request_id);
917+
let clear_route = should_clear_tunnel_route_mk1(&msg.message_kind);
918+
912919
// Ignore DeprecatedTunnelAck messages (used only for backwards compatibility)
913920
if matches!(
914921
msg.message_kind,
@@ -925,10 +932,7 @@ async fn handle_tunnel_message_mk1(
925932
return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build());
926933
}
927934

928-
if !authorized_tunnel_routes
929-
.contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id))
930-
.await
931-
{
935+
if !authorized_tunnel_routes.contains_async(&route).await {
932936
return Err(
933937
errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(),
934938
);
@@ -950,9 +954,35 @@ async fn handle_tunnel_message_mk1(
950954
)
951955
})?;
952956

957+
if clear_route {
958+
authorized_tunnel_routes.remove_async(&route).await;
959+
}
960+
953961
Ok(())
954962
}
955963

964+
fn should_clear_tunnel_route_mk2(msg_kind: &protocol::mk2::ToServerTunnelMessageKind) -> bool {
965+
match msg_kind {
966+
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseStart(response) => {
967+
!response.stream
968+
}
969+
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.finish,
970+
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort
971+
| protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketClose(_) => true,
972+
_ => false,
973+
}
974+
}
975+
976+
fn should_clear_tunnel_route_mk1(msg_kind: &protocol::ToServerTunnelMessageKind) -> bool {
977+
match msg_kind {
978+
protocol::ToServerTunnelMessageKind::ToServerResponseStart(response) => !response.stream,
979+
protocol::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.finish,
980+
protocol::ToServerTunnelMessageKind::ToServerResponseAbort
981+
| protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(_) => true,
982+
_ => false,
983+
}
984+
}
985+
956986
/// Returns the length of the inner data payload for a tunnel message kind.
957987
fn tunnel_message_inner_data_len_mk2(kind: &protocol::mk2::ToServerTunnelMessageKind) -> usize {
958988
use protocol::mk2::ToServerTunnelMessageKind;

0 commit comments

Comments
 (0)