From bd27d9bb849bde7f44d60aaac46a0257b95d48f3 Mon Sep 17 00:00:00 2001 From: jbesraa Date: Fri, 18 Oct 2024 12:45:42 +0300 Subject: [PATCH] Allow sniffer to mutate ongoing messages ..Add new `InterceptMessage` property to allow the sniffer to mutate a message before it sent to downstream/upstream. --- roles/tests-integration/tests/common/mod.rs | 11 +- .../tests-integration/tests/common/sniffer.rs | 116 ++++++++++++++++-- .../tests/pool_integration.rs | 43 ++++++- 3 files changed, 161 insertions(+), 9 deletions(-) diff --git a/roles/tests-integration/tests/common/mod.rs b/roles/tests-integration/tests/common/mod.rs index 87b5a3096..782b56ad5 100644 --- a/roles/tests-integration/tests/common/mod.rs +++ b/roles/tests-integration/tests/common/mod.rs @@ -7,6 +7,7 @@ use once_cell::sync::Lazy; use pool_sv2::PoolSv2; use rand::{thread_rng, Rng}; use sniffer::Sniffer; +pub use sniffer::{InterceptMessage, MessageDirection}; use std::{ collections::HashSet, convert::{TryFrom, TryInto}, @@ -199,8 +200,16 @@ pub async fn start_sniffer( listening_address: SocketAddr, upstream: SocketAddr, check_on_drop: bool, + intercept_message: Option>, ) -> Sniffer { - let sniffer = Sniffer::new(identifier, listening_address, upstream, check_on_drop).await; + let sniffer = Sniffer::new( + identifier, + listening_address, + upstream, + check_on_drop, + intercept_message, + ) + .await; let sniffer_clone = sniffer.clone(); tokio::spawn(async move { sniffer_clone.start().await; diff --git a/roles/tests-integration/tests/common/sniffer.rs b/roles/tests-integration/tests/common/sniffer.rs index 9c6cd112b..985f1ba92 100644 --- a/roles/tests-integration/tests/common/sniffer.rs +++ b/roles/tests-integration/tests/common/sniffer.rs @@ -1,6 +1,6 @@ use async_channel::{Receiver, Sender}; use codec_sv2::{ - framing_sv2::framing::Frame, HandshakeRole, Initiator, Responder, StandardEitherFrame, + framing_sv2::framing::Frame, HandshakeRole, Initiator, Responder, StandardEitherFrame, Sv2Frame, }; use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey}; use network_helpers_sv2::noise_connection_tokio::Connection; @@ -13,8 +13,8 @@ use roles_logic_sv2::{ IdentifyTransactionsSuccess, ProvideMissingTransactions, ProvideMissingTransactionsSuccess, SubmitSolution, }, - TemplateDistribution, - TemplateDistribution::CoinbaseOutputDataSize, + PoolMessages, + TemplateDistribution::{self, CoinbaseOutputDataSize}, }, utils::Mutex, }; @@ -22,6 +22,7 @@ use std::{collections::VecDeque, convert::TryInto, net::SocketAddr, sync::Arc}; use tokio::{ net::{TcpListener, TcpStream}, select, + time::sleep, }; type MessageFrame = StandardEitherFrame>; type MsgType = u8; @@ -30,6 +31,7 @@ type MsgType = u8; enum SnifferError { DownstreamClosed, UpstreamClosed, + MessageInterrupted, } /// Allows to intercept messages sent between two roles. @@ -42,7 +44,15 @@ enum SnifferError { /// forwarded to the downstream role. Both `messages_from_downstream` and `messages_from_upstream` /// can be accessed as FIFO queues. /// -/// It is useful for testing purposes, as it allows to assert that the roles have sent specific +/// In order to alter the messages sent between the roles, the [`Sniffer::intercept_messages`] +/// field can be used. It will look for the [`InterceptMessage::expected_message_type`] in the +/// specified [`InterceptMessage::direction`] and replace it with +/// [`InterceptMessage::response_message`]. +/// +/// If `break_on` is set to `true`, the [`Sniffer`] will stop the communication after sending the +/// response message. +/// +/// Can be useful for testing purposes, as it allows to assert that the roles have sent specific /// messages in a specific order and to inspect the messages details. #[derive(Debug, Clone)] pub struct Sniffer { @@ -52,6 +62,40 @@ pub struct Sniffer { messages_from_downstream: MessagesAggregator, messages_from_upstream: MessagesAggregator, check_on_drop: bool, + intercept_messages: Vec, +} + +#[derive(Debug, Clone)] +pub struct InterceptMessage { + direction: MessageDirection, + expected_message_type: MsgType, + response_message: PoolMessages<'static>, + response_message_type: MsgType, + break_on: bool, +} + +impl InterceptMessage { + pub fn new( + direction: MessageDirection, + expected_message_type: MsgType, + response_message: PoolMessages<'static>, + response_message_type: MsgType, + break_on: bool, + ) -> Self { + Self { + direction, + expected_message_type, + response_message, + response_message_type, + break_on, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum MessageDirection { + ToDownstream, + ToUpstream, } impl Sniffer { @@ -62,6 +106,7 @@ impl Sniffer { listening_address: SocketAddr, upstream_address: SocketAddr, check_on_drop: bool, + intercept_messages: Option>, ) -> Self { Self { identifier, @@ -70,6 +115,7 @@ impl Sniffer { messages_from_downstream: MessagesAggregator::new(), messages_from_upstream: MessagesAggregator::new(), check_on_drop, + intercept_messages: intercept_messages.unwrap_or_default(), } } @@ -91,10 +137,13 @@ impl Sniffer { .expect("Failed to create upstream"); let downstream_messages = self.messages_from_downstream.clone(); let upstream_messages = self.messages_from_upstream.clone(); + let intercept_messages = self.intercept_messages.clone(); let _ = select! { - r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages) => r, - r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages) => r, + r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages, intercept_messages.clone()) => r, + r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages, intercept_messages) => r, }; + // wait a bit so we dont drop the sniffer before the test has finished + sleep(std::time::Duration::from_secs(1)).await; } /// Returns the oldest message sent by downstream. @@ -169,9 +218,36 @@ impl Sniffer { recv: Receiver, send: Sender, downstream_messages: MessagesAggregator, + intercept_messages: Vec, ) -> Result<(), SnifferError> { while let Ok(mut frame) = recv.recv().await { let (msg_type, msg) = Self::message_from_frame(&mut frame); + for intercept_message in intercept_messages.iter() { + if intercept_message.direction == MessageDirection::ToUpstream + && intercept_message.expected_message_type == msg_type + { + let extension_type = 0; + let channel_msg = false; + let frame = StandardEitherFrame::>::Sv2( + Sv2Frame::from_message( + intercept_message.response_message.clone(), + intercept_message.response_message_type, + extension_type, + channel_msg, + ) + .expect("Failed to create the frame"), + ); + downstream_messages + .add_message(msg_type, intercept_message.response_message.clone()); + let _ = send.send(frame).await; + if intercept_message.break_on { + return Err(SnifferError::MessageInterrupted); + } else { + continue; + } + } + } + downstream_messages.add_message(msg_type, msg); if send.send(frame).await.is_err() { return Err(SnifferError::UpstreamClosed); @@ -184,13 +260,39 @@ impl Sniffer { recv: Receiver, send: Sender, upstream_messages: MessagesAggregator, + intercept_messages: Vec, ) -> Result<(), SnifferError> { while let Ok(mut frame) = recv.recv().await { let (msg_type, msg) = Self::message_from_frame(&mut frame); - upstream_messages.add_message(msg_type, msg); + for intercept_message in intercept_messages.iter() { + if intercept_message.direction == MessageDirection::ToDownstream + && intercept_message.expected_message_type == msg_type + { + let extension_type = 0; + let channel_msg = false; + let frame = StandardEitherFrame::>::Sv2( + Sv2Frame::from_message( + intercept_message.response_message.clone(), + intercept_message.response_message_type, + extension_type, + channel_msg, + ) + .expect("Failed to create the frame"), + ); + upstream_messages + .add_message(msg_type, intercept_message.response_message.clone()); + let _ = send.send(frame).await; + if intercept_message.break_on { + return Err(SnifferError::MessageInterrupted); + } else { + continue; + } + } + } if send.send(frame).await.is_err() { return Err(SnifferError::DownstreamClosed); }; + upstream_messages.add_message(msg_type, msg); } Err(SnifferError::UpstreamClosed) } diff --git a/roles/tests-integration/tests/pool_integration.rs b/roles/tests-integration/tests/pool_integration.rs index 942ca1c50..6230d4324 100644 --- a/roles/tests-integration/tests/pool_integration.rs +++ b/roles/tests-integration/tests/pool_integration.rs @@ -1,7 +1,11 @@ mod common; +use std::convert::TryInto; + +use common::{InterceptMessage, MessageDirection}; +use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_ERROR; use roles_logic_sv2::{ - common_messages_sv2::{Protocol, SetupConnection}, + common_messages_sv2::{Protocol, SetupConnection, SetupConnectionError}, parsers::{CommonMessages, PoolMessages, TemplateDistribution}, }; @@ -23,6 +27,7 @@ async fn success_pool_template_provider_connection() { sniffer_addr, tp_addr, sniffer_check_on_drop, + None, ) .await; let _ = common::start_pool(Some(pool_addr), Some(sniffer_addr)).await; @@ -53,3 +58,39 @@ async fn success_pool_template_provider_connection() { assert_tp_message!(&sniffer.next_message_from_upstream(), NewTemplate); assert_tp_message!(sniffer.next_message_from_upstream(), SetNewPrevHash); } + +#[tokio::test] +async fn test_sniffer_interrupter() { + let sniffer_addr = common::get_available_address(); + let tp_addr = common::get_available_address(); + let pool_addr = common::get_available_address(); + let _tp = common::start_template_provider(tp_addr.port()).await; + use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS; + let message = + PoolMessages::Common(CommonMessages::SetupConnectionError(SetupConnectionError { + flags: 0, + error_code: "unsupported-feature-flags" + .to_string() + .into_bytes() + .try_into() + .unwrap(), + })); + let interrupt_msgs = InterceptMessage::new( + MessageDirection::ToDownstream, + MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS, + message, + MESSAGE_TYPE_SETUP_CONNECTION_ERROR, + true, + ); + let sniffer = common::start_sniffer( + "1".to_string(), + sniffer_addr, + tp_addr, + false, + Some(vec![interrupt_msgs]), + ) + .await; + let _ = common::start_pool(Some(pool_addr), Some(sniffer_addr)).await; + assert_common_message!(&sniffer.next_message_from_downstream(), SetupConnection); + assert_common_message!(&sniffer.next_message_from_upstream(), SetupConnectionError); +}