Skip to content

Commit

Permalink
Merge pull request stratum-mining#1228 from jbesraa/2024-10-18-allow-…
Browse files Browse the repository at this point in the history
…sniffer-to-alter-msgs

Allow altering messages exchanged by roles in test env
  • Loading branch information
plebhash authored Nov 21, 2024
2 parents e650561 + bd27d9b commit 67a3f00
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 15 deletions.
23 changes: 17 additions & 6 deletions roles/pool/src/lib/template_receiver/setup_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use super::super::{
};
use async_channel::{Receiver, Sender};
use roles_logic_sv2::{
common_messages_sv2::{Protocol, SetupConnection},
common_messages_sv2::{Protocol, SetupConnection, SetupConnectionError},
errors::Error,
handlers::common::{ParseUpstreamCommonMessages, SendTo},
parsers::PoolMessages,
parsers::{CommonMessages, PoolMessages},
routing_logic::{CommonRoutingLogic, NoRouting},
utils::Mutex,
};
Expand Down Expand Up @@ -79,12 +79,23 @@ impl ParseUpstreamCommonMessages<NoRouting> for SetupConnectionHandler {

fn handle_setup_connection_error(
&mut self,
_: roles_logic_sv2::common_messages_sv2::SetupConnectionError,
m: SetupConnectionError,
) -> Result<roles_logic_sv2::handlers::common::SendTo, Error> {
//return error result
todo!()
let flags = m.flags;
let message = SetupConnectionError {
flags,
// this error code is currently a hack because there is a lifetime problem with
// `error_code`.
error_code: "unsupported-feature-flags"
.to_string()
.into_bytes()
.try_into()
.unwrap(),
};
Ok(SendTo::RelayNewMessage(
CommonMessages::SetupConnectionError(message),
))
}

fn handle_channel_endpoint_changed(
&mut self,
_: roles_logic_sv2::common_messages_sv2::ChannelEndpointChanged,
Expand Down
11 changes: 10 additions & 1 deletion roles/tests-integration/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use once_cell::sync::Lazy;
use pool_sv2::PoolSv2;
use rand::{thread_rng, Rng};
use sniffer::Sniffer;
pub use sniffer::{InterceptMessage, MessageDirection};
use std::{
collections::HashSet,
convert::{TryFrom, TryInto},
Expand Down Expand Up @@ -199,8 +200,16 @@ pub async fn start_sniffer(
listening_address: SocketAddr,
upstream: SocketAddr,
check_on_drop: bool,
intercept_message: Option<Vec<InterceptMessage>>,
) -> Sniffer {
let sniffer = Sniffer::new(identifier, listening_address, upstream, check_on_drop).await;
let sniffer = Sniffer::new(
identifier,
listening_address,
upstream,
check_on_drop,
intercept_message,
)
.await;
let sniffer_clone = sniffer.clone();
tokio::spawn(async move {
sniffer_clone.start().await;
Expand Down
116 changes: 109 additions & 7 deletions roles/tests-integration/tests/common/sniffer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use async_channel::{Receiver, Sender};
use codec_sv2::{
framing_sv2::framing::Frame, HandshakeRole, Initiator, Responder, StandardEitherFrame,
framing_sv2::framing::Frame, HandshakeRole, Initiator, Responder, StandardEitherFrame, Sv2Frame,
};
use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey};
use network_helpers_sv2::noise_connection_tokio::Connection;
Expand All @@ -13,15 +13,16 @@ use roles_logic_sv2::{
IdentifyTransactionsSuccess, ProvideMissingTransactions,
ProvideMissingTransactionsSuccess, SubmitSolution,
},
TemplateDistribution,
TemplateDistribution::CoinbaseOutputDataSize,
PoolMessages,
TemplateDistribution::{self, CoinbaseOutputDataSize},
},
utils::Mutex,
};
use std::{collections::VecDeque, convert::TryInto, net::SocketAddr, sync::Arc};
use tokio::{
net::{TcpListener, TcpStream},
select,
time::sleep,
};
type MessageFrame = StandardEitherFrame<AnyMessage<'static>>;
type MsgType = u8;
Expand All @@ -30,6 +31,7 @@ type MsgType = u8;
enum SnifferError {
DownstreamClosed,
UpstreamClosed,
MessageInterrupted,
}

/// Allows to intercept messages sent between two roles.
Expand All @@ -42,7 +44,15 @@ enum SnifferError {
/// forwarded to the downstream role. Both `messages_from_downstream` and `messages_from_upstream`
/// can be accessed as FIFO queues.
///
/// It is useful for testing purposes, as it allows to assert that the roles have sent specific
/// In order to alter the messages sent between the roles, the [`Sniffer::intercept_messages`]
/// field can be used. It will look for the [`InterceptMessage::expected_message_type`] in the
/// specified [`InterceptMessage::direction`] and replace it with
/// [`InterceptMessage::response_message`].
///
/// If `break_on` is set to `true`, the [`Sniffer`] will stop the communication after sending the
/// response message.
///
/// Can be useful for testing purposes, as it allows to assert that the roles have sent specific
/// messages in a specific order and to inspect the messages details.
#[derive(Debug, Clone)]
pub struct Sniffer {
Expand All @@ -52,6 +62,40 @@ pub struct Sniffer {
messages_from_downstream: MessagesAggregator,
messages_from_upstream: MessagesAggregator,
check_on_drop: bool,
intercept_messages: Vec<InterceptMessage>,
}

#[derive(Debug, Clone)]
pub struct InterceptMessage {
direction: MessageDirection,
expected_message_type: MsgType,
response_message: PoolMessages<'static>,
response_message_type: MsgType,
break_on: bool,
}

impl InterceptMessage {
pub fn new(
direction: MessageDirection,
expected_message_type: MsgType,
response_message: PoolMessages<'static>,
response_message_type: MsgType,
break_on: bool,
) -> Self {
Self {
direction,
expected_message_type,
response_message,
response_message_type,
break_on,
}
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageDirection {
ToDownstream,
ToUpstream,
}

impl Sniffer {
Expand All @@ -62,6 +106,7 @@ impl Sniffer {
listening_address: SocketAddr,
upstream_address: SocketAddr,
check_on_drop: bool,
intercept_messages: Option<Vec<InterceptMessage>>,
) -> Self {
Self {
identifier,
Expand All @@ -70,6 +115,7 @@ impl Sniffer {
messages_from_downstream: MessagesAggregator::new(),
messages_from_upstream: MessagesAggregator::new(),
check_on_drop,
intercept_messages: intercept_messages.unwrap_or_default(),
}
}

Expand All @@ -91,10 +137,13 @@ impl Sniffer {
.expect("Failed to create upstream");
let downstream_messages = self.messages_from_downstream.clone();
let upstream_messages = self.messages_from_upstream.clone();
let intercept_messages = self.intercept_messages.clone();
let _ = select! {
r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages) => r,
r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages) => r,
r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages, intercept_messages.clone()) => r,
r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages, intercept_messages) => r,
};
// wait a bit so we dont drop the sniffer before the test has finished
sleep(std::time::Duration::from_secs(1)).await;
}

/// Returns the oldest message sent by downstream.
Expand Down Expand Up @@ -169,9 +218,36 @@ impl Sniffer {
recv: Receiver<MessageFrame>,
send: Sender<MessageFrame>,
downstream_messages: MessagesAggregator,
intercept_messages: Vec<InterceptMessage>,
) -> Result<(), SnifferError> {
while let Ok(mut frame) = recv.recv().await {
let (msg_type, msg) = Self::message_from_frame(&mut frame);
for intercept_message in intercept_messages.iter() {
if intercept_message.direction == MessageDirection::ToUpstream
&& intercept_message.expected_message_type == msg_type
{
let extension_type = 0;
let channel_msg = false;
let frame = StandardEitherFrame::<AnyMessage<'_>>::Sv2(
Sv2Frame::from_message(
intercept_message.response_message.clone(),
intercept_message.response_message_type,
extension_type,
channel_msg,
)
.expect("Failed to create the frame"),
);
downstream_messages
.add_message(msg_type, intercept_message.response_message.clone());
let _ = send.send(frame).await;
if intercept_message.break_on {
return Err(SnifferError::MessageInterrupted);
} else {
continue;
}
}
}

downstream_messages.add_message(msg_type, msg);
if send.send(frame).await.is_err() {
return Err(SnifferError::UpstreamClosed);
Expand All @@ -184,13 +260,39 @@ impl Sniffer {
recv: Receiver<MessageFrame>,
send: Sender<MessageFrame>,
upstream_messages: MessagesAggregator,
intercept_messages: Vec<InterceptMessage>,
) -> Result<(), SnifferError> {
while let Ok(mut frame) = recv.recv().await {
let (msg_type, msg) = Self::message_from_frame(&mut frame);
upstream_messages.add_message(msg_type, msg);
for intercept_message in intercept_messages.iter() {
if intercept_message.direction == MessageDirection::ToDownstream
&& intercept_message.expected_message_type == msg_type
{
let extension_type = 0;
let channel_msg = false;
let frame = StandardEitherFrame::<AnyMessage<'_>>::Sv2(
Sv2Frame::from_message(
intercept_message.response_message.clone(),
intercept_message.response_message_type,
extension_type,
channel_msg,
)
.expect("Failed to create the frame"),
);
upstream_messages
.add_message(msg_type, intercept_message.response_message.clone());
let _ = send.send(frame).await;
if intercept_message.break_on {
return Err(SnifferError::MessageInterrupted);
} else {
continue;
}
}
}
if send.send(frame).await.is_err() {
return Err(SnifferError::DownstreamClosed);
};
upstream_messages.add_message(msg_type, msg);
}
Err(SnifferError::UpstreamClosed)
}
Expand Down
43 changes: 42 additions & 1 deletion roles/tests-integration/tests/pool_integration.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
mod common;

use std::convert::TryInto;

use common::{InterceptMessage, MessageDirection};
use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_ERROR;
use roles_logic_sv2::{
common_messages_sv2::{Protocol, SetupConnection},
common_messages_sv2::{Protocol, SetupConnection, SetupConnectionError},
parsers::{CommonMessages, PoolMessages, TemplateDistribution},
};

Expand All @@ -23,6 +27,7 @@ async fn success_pool_template_provider_connection() {
sniffer_addr,
tp_addr,
sniffer_check_on_drop,
None,
)
.await;
let _ = common::start_pool(Some(pool_addr), Some(sniffer_addr)).await;
Expand Down Expand Up @@ -53,3 +58,39 @@ async fn success_pool_template_provider_connection() {
assert_tp_message!(&sniffer.next_message_from_upstream(), NewTemplate);
assert_tp_message!(sniffer.next_message_from_upstream(), SetNewPrevHash);
}

#[tokio::test]
async fn test_sniffer_interrupter() {
let sniffer_addr = common::get_available_address();
let tp_addr = common::get_available_address();
let pool_addr = common::get_available_address();
let _tp = common::start_template_provider(tp_addr.port()).await;
use const_sv2::MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS;
let message =
PoolMessages::Common(CommonMessages::SetupConnectionError(SetupConnectionError {
flags: 0,
error_code: "unsupported-feature-flags"
.to_string()
.into_bytes()
.try_into()
.unwrap(),
}));
let interrupt_msgs = InterceptMessage::new(
MessageDirection::ToDownstream,
MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS,
message,
MESSAGE_TYPE_SETUP_CONNECTION_ERROR,
true,
);
let sniffer = common::start_sniffer(
"1".to_string(),
sniffer_addr,
tp_addr,
false,
Some(vec![interrupt_msgs]),
)
.await;
let _ = common::start_pool(Some(pool_addr), Some(sniffer_addr)).await;
assert_common_message!(&sniffer.next_message_from_downstream(), SetupConnection);
assert_common_message!(&sniffer.next_message_from_upstream(), SetupConnectionError);
}

0 comments on commit 67a3f00

Please sign in to comment.