@@ -860,6 +860,9 @@ async fn handle_tunnel_message_mk2(
860860 authorized_tunnel_routes : & HashMap < ( protocol:: mk2:: GatewayId , protocol:: mk2:: RequestId ) , ( ) > ,
861861 msg : protocol:: mk2:: ToServerTunnelMessage ,
862862) -> Result < ( ) > {
863+ let route = ( msg. message_id . gateway_id , msg. message_id . request_id ) ;
864+ let clear_route = should_clear_tunnel_route_mk2 ( & msg. message_kind ) ;
865+
863866 // Extract inner data length before consuming msg
864867 let inner_data_len = tunnel_message_inner_data_len_mk2 ( & msg. message_kind ) ;
865868
@@ -868,10 +871,7 @@ async fn handle_tunnel_message_mk2(
868871 return Err ( errors:: WsError :: InvalidPacket ( "payload too large" . to_string ( ) ) . build ( ) ) ;
869872 }
870873
871- if !authorized_tunnel_routes
872- . contains_async ( & ( msg. message_id . gateway_id , msg. message_id . request_id ) )
873- . await
874- {
874+ if !authorized_tunnel_routes. contains_async ( & route) . await {
875875 return Err (
876876 errors:: WsError :: InvalidPacket ( "unauthorized tunnel message" . to_string ( ) ) . build ( ) ,
877877 ) ;
@@ -899,6 +899,10 @@ async fn handle_tunnel_message_mk2(
899899 )
900900 } ) ?;
901901
902+ if clear_route {
903+ authorized_tunnel_routes. remove_async ( & route) . await ;
904+ }
905+
902906 Ok ( ( ) )
903907}
904908
@@ -909,6 +913,9 @@ async fn handle_tunnel_message_mk1(
909913 authorized_tunnel_routes : & HashMap < ( protocol:: mk2:: GatewayId , protocol:: mk2:: RequestId ) , ( ) > ,
910914 msg : protocol:: ToServerTunnelMessage ,
911915) -> Result < ( ) > {
916+ let route = ( msg. message_id . gateway_id , msg. message_id . request_id ) ;
917+ let clear_route = should_clear_tunnel_route_mk1 ( & msg. message_kind ) ;
918+
912919 // Ignore DeprecatedTunnelAck messages (used only for backwards compatibility)
913920 if matches ! (
914921 msg. message_kind,
@@ -925,10 +932,7 @@ async fn handle_tunnel_message_mk1(
925932 return Err ( errors:: WsError :: InvalidPacket ( "payload too large" . to_string ( ) ) . build ( ) ) ;
926933 }
927934
928- if !authorized_tunnel_routes
929- . contains_async ( & ( msg. message_id . gateway_id , msg. message_id . request_id ) )
930- . await
931- {
935+ if !authorized_tunnel_routes. contains_async ( & route) . await {
932936 return Err (
933937 errors:: WsError :: InvalidPacket ( "unauthorized tunnel message" . to_string ( ) ) . build ( ) ,
934938 ) ;
@@ -950,9 +954,35 @@ async fn handle_tunnel_message_mk1(
950954 )
951955 } ) ?;
952956
957+ if clear_route {
958+ authorized_tunnel_routes. remove_async ( & route) . await ;
959+ }
960+
953961 Ok ( ( ) )
954962}
955963
964+ fn should_clear_tunnel_route_mk2 ( msg_kind : & protocol:: mk2:: ToServerTunnelMessageKind ) -> bool {
965+ match msg_kind {
966+ protocol:: mk2:: ToServerTunnelMessageKind :: ToServerResponseStart ( response) => {
967+ !response. stream
968+ }
969+ protocol:: mk2:: ToServerTunnelMessageKind :: ToServerResponseChunk ( chunk) => chunk. finish ,
970+ protocol:: mk2:: ToServerTunnelMessageKind :: ToServerResponseAbort
971+ | protocol:: mk2:: ToServerTunnelMessageKind :: ToServerWebSocketClose ( _) => true ,
972+ _ => false ,
973+ }
974+ }
975+
976+ fn should_clear_tunnel_route_mk1 ( msg_kind : & protocol:: ToServerTunnelMessageKind ) -> bool {
977+ match msg_kind {
978+ protocol:: ToServerTunnelMessageKind :: ToServerResponseStart ( response) => !response. stream ,
979+ protocol:: ToServerTunnelMessageKind :: ToServerResponseChunk ( chunk) => chunk. finish ,
980+ protocol:: ToServerTunnelMessageKind :: ToServerResponseAbort
981+ | protocol:: ToServerTunnelMessageKind :: ToServerWebSocketClose ( _) => true ,
982+ _ => false ,
983+ }
984+ }
985+
956986/// Returns the length of the inner data payload for a tunnel message kind.
957987fn tunnel_message_inner_data_len_mk2 ( kind : & protocol:: mk2:: ToServerTunnelMessageKind ) -> usize {
958988 use protocol:: mk2:: ToServerTunnelMessageKind ;
0 commit comments