Skip to content

Commit

Permalink
feat(p2p_stream): add response timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
CHr15F0x committed Feb 19, 2025
1 parent 9809a93 commit 3e6ac7f
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 25 deletions.
1 change: 1 addition & 0 deletions crates/p2p/src/behaviour/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ impl Builder {

let p2p_stream_cfg = p2p_stream::Config::default()
.stream_timeout(cfg.stream_timeout)
.response_timeout(cfg.response_timeout)
.max_concurrent_streams(cfg.max_concurrent_streams);

let header_sync = header_sync
Expand Down
2 changes: 1 addition & 1 deletion crates/p2p_stream/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ async-trait = { workspace = true }
futures = { workspace = true }
futures-bounded = { workspace = true }
libp2p = { workspace = true, features = ["identify", "noise", "tcp", "tokio"] }
tokio = { workspace = true, features = ["macros", "time"] }
tracing = { workspace = true }
void = { workspace = true }

Expand All @@ -26,5 +27,4 @@ libp2p-plaintext = { workspace = true }
libp2p-swarm-test = { workspace = true }
rstest = { workspace = true }
test-log = { workspace = true, features = ["trace"] }
tokio = { workspace = true, features = ["macros", "time"] }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
58 changes: 52 additions & 6 deletions crates/p2p_stream/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ where
inbound_request_id: Arc<AtomicU64>,

worker_streams: futures_bounded::FuturesMap<RequestId, Result<Event<TCodec>, io::Error>>,

response_timeout: Duration,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
Expand All @@ -110,6 +112,7 @@ where
inbound_protocols: Vec<TCodec::Protocol>,
codec: TCodec,
substream_timeout: Duration,
response_timeout: Duration,
inbound_request_id: Arc<AtomicU64>,
max_concurrent_streams: usize,
) -> Self {
Expand All @@ -130,6 +133,7 @@ where
substream_timeout,
max_concurrent_streams,
),
response_timeout,
}
}

Expand All @@ -148,6 +152,7 @@ where
let mut codec = self.codec.clone();
let request_id = self.next_inbound_request_id();
let mut sender = self.inbound_sender.clone();
let response_timeout = self.response_timeout;

let recv_request_then_fwd_outgoing_responses = async move {
let (rs_send, mut rs_recv) = mpsc::channel(0);
Expand All @@ -163,8 +168,17 @@ where

// Keep on forwarding until the channel is closed
while let Some(response) = rs_recv.next().await {
let write = codec.write_response(&protocol, &mut stream, response);
write.await?;
tokio::time::timeout(
response_timeout,
codec.write_response(&protocol, &mut stream, response),
)
.await
.map_err(|_| {
io::Error::new(
io::ErrorKind::TimedOut,
format!("Timeout writing response to stream for request id {request_id}"),
)
})??;
}

stream.close().await?;
Expand Down Expand Up @@ -202,6 +216,7 @@ where
let (mut rs_send, rs_recv) = mpsc::channel(0);

let mut sender = self.outbound_sender.clone();
let response_timeout = self.response_timeout;

let send_req_then_fwd_incoming_responses = async move {
let write = codec.write_request(&protocol, &mut stream, message.request);
Expand All @@ -217,17 +232,34 @@ where

// Keep on forwarding until the channel is closed or error occurs
loop {
match codec.read_response(&protocol, &mut stream).await {
Ok(response) => {
match tokio::time::timeout(
response_timeout,
codec.read_response(&protocol, &mut stream),
)
.await
{
Err(_) => {
rs_send
.send(Err(io::Error::new(
io::ErrorKind::TimedOut,
format!(
"Timeout reading response from stream for request id \
{request_id}"
),
)))
.await
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
}
Ok(Ok(response)) => {
rs_send
.send(Ok(response))
.await
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
}
// The stream is closed, there's nothing more to receive
Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => break,
Ok(Err(error)) if error.kind() == io::ErrorKind::UnexpectedEof => break,
// An error occurred, propagate it
Err(error) => {
Ok(Err(error)) => {
let error_clone = io::Error::new(error.kind(), error.to_string());
rs_send
.send(Err(error_clone))
Expand Down Expand Up @@ -443,6 +475,13 @@ where
Poll::Ready((_, Ok(Ok(event)))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event));
}
Poll::Ready((RequestId::Inbound(id), Ok(Err(e))))
if e.kind() == std::io::ErrorKind::TimedOut =>
{
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::InboundTimeout(id),
));
}
Poll::Ready((RequestId::Inbound(id), Ok(Err(e)))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::InboundStreamFailed {
Expand All @@ -451,6 +490,13 @@ where
},
));
}
Poll::Ready((RequestId::Outbound(id), Ok(Err(e))))
if e.kind() == std::io::ErrorKind::TimedOut =>
{
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::OutboundTimeout(id),
));
}
Poll::Ready((RequestId::Outbound(id), Ok(Err(e)))) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::OutboundStreamFailed {
Expand Down
2 changes: 2 additions & 0 deletions crates/p2p_stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,7 @@ where
self.protocols.clone(),
self.codec.clone(),
self.config.stream_timeout,
self.config.response_timeout,
self.next_inbound_request_id.clone(),
self.config.max_concurrent_streams,
);
Expand Down Expand Up @@ -645,6 +646,7 @@ where
self.protocols.clone(),
self.codec.clone(),
self.config.stream_timeout,
self.config.response_timeout,
self.next_inbound_request_id.clone(),
self.config.max_concurrent_streams,
);
Expand Down
142 changes: 131 additions & 11 deletions crates/p2p_stream/tests/error_reporting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub mod utils;

use utils::{
new_swarm,
new_swarm_with_timeout,
new_swarm_with_timeouts,
wait_inbound_failure,
wait_inbound_request,
wait_inbound_response_stream_closed,
Expand Down Expand Up @@ -123,10 +123,14 @@ async fn report_outbound_failure_on_write_request_failure() {
}

#[test_log::test(tokio::test)]
async fn report_outbound_timeout_on_read_response_timeout() {
// `swarm1` needs to have a bigger timeout to avoid racing
let (peer1_id, mut swarm1) = new_swarm_with_timeout(Duration::from_millis(200));
let (peer2_id, mut swarm2) = new_swarm_with_timeout(Duration::from_millis(100));
async fn report_outbound_timeout_on_response_stream_timeout() {
// `swarm1` needs to have a bigger stream timeout to avoid racing.
// Response timeouts are set to a larger value to make sure we trigger a
// stream timeout.
let (peer1_id, mut swarm1) =
new_swarm_with_timeouts(Duration::from_millis(200), Duration::from_millis(500));
let (peer2_id, mut swarm2) =
new_swarm_with_timeouts(Duration::from_millis(100), Duration::from_millis(500));

swarm1.listen().with_memory_addr_external().await;
swarm2.connect(&mut swarm1).await;
Expand Down Expand Up @@ -168,6 +172,58 @@ async fn report_outbound_timeout_on_read_response_timeout() {
tokio::join!(server_task, client_task);
}

// Checks that a response timeout hit while the client is reading a response is
// reported as a timeout on both sides: the server observes
// `InboundFailure::Timeout` and the client observes a `TimedOut` I/O error on
// the response channel followed by `OutboundFailure::Timeout`.
#[test_log::test(tokio::test)]
async fn report_outbound_timeout_on_read_response_timeout() {
    // Stream timeouts are kept well above the response timeouts so that the
    // response timeout is the one that fires.
    let stream_timeout = Duration::from_millis(500);
    // The server gets the bigger response timeout to avoid racing with the
    // client's.
    let server_response_timeout = Duration::from_millis(200);
    let client_response_timeout = Duration::from_millis(100);

    let (server_id, mut server) = new_swarm_with_timeouts(stream_timeout, server_response_timeout);
    let (client_id, mut client) = new_swarm_with_timeouts(stream_timeout, client_response_timeout);

    server.listen().with_memory_addr_external().await;
    client.connect(&mut server).await;

    let server_task = async move {
        // Receive the request and make sure it is the one the client sent.
        let inbound = wait_inbound_request(&mut server).await;
        let (from, inbound_id, requested, mut responses) = inbound.unwrap();
        assert_eq!(from, client_id);
        assert_eq!(requested, Action::TimeoutOnReadResponse);

        // Answer with the same action that was requested.
        let sent = responses.send(Action::TimeoutOnReadResponse).await;
        sent.unwrap();

        // The inbound side must end with a timeout for that very request.
        let failure = wait_inbound_failure(&mut server).await;
        let (from, failed_id, reason) = failure.unwrap();
        assert_eq!(from, client_id);
        assert_eq!(failed_id, inbound_id);
        assert!(matches!(reason, InboundFailure::Timeout));
    };

    let client_task = async move {
        let outbound_id = client
            .behaviour_mut()
            .send_request(&server_id, Action::TimeoutOnReadResponse);

        let sent = wait_outbound_request_sent_awaiting_responses(&mut client).await;
        let (to, sent_id, mut responses) = sent.unwrap();
        assert_eq!(to, server_id);
        assert_eq!(sent_id, outbound_id);

        // The first item on the response channel must be a timeout error.
        let first = responses.next().await;
        assert!(matches!(first, Some(Err(error)) if error.kind() == io::ErrorKind::TimedOut));

        // The outbound side must then report a timeout for the same request.
        let failure = wait_outbound_failure(&mut client).await;
        let (to, failed_id, reason) = failure.unwrap();
        assert_eq!(to, server_id);
        assert_eq!(failed_id, outbound_id);
        assert!(matches!(reason, OutboundFailure::Timeout));
    };

    // Drive both sides until each of them has finished.
    tokio::join!(server_task, client_task);
}

#[test_log::test(tokio::test)]
async fn report_inbound_closure_on_read_request_failure() {
let (peer1_id, mut swarm1) = new_swarm();
Expand Down Expand Up @@ -267,11 +323,71 @@ async fn report_inbound_failure_on_write_response_failure() {
tokio::join!(client_task, server_task);
}

// Checks that a stream timeout hit while the server is writing a response is
// reported on the server as `InboundFailure::Timeout`, while the client sees
// its response channel end and the inbound response stream reported as closed.
#[test_log::test(tokio::test)]
async fn report_inbound_timeout_on_response_stream_timeout() {
    // Response timeouts are kept well above the stream timeouts so that the
    // stream timeout is the one that fires.
    let response_timeout = Duration::from_millis(500);
    // The client gets the bigger stream timeout to avoid racing with the
    // server's.
    let server_stream_timeout = Duration::from_millis(100);
    let client_stream_timeout = Duration::from_millis(200);

    let (server_id, mut server) = new_swarm_with_timeouts(server_stream_timeout, response_timeout);
    let (client_id, mut client) = new_swarm_with_timeouts(client_stream_timeout, response_timeout);

    server.listen().with_memory_addr_external().await;
    client.connect(&mut server).await;

    let server_task = async move {
        // Receive the request and make sure it is the one the client sent.
        let inbound = wait_inbound_request(&mut server).await;
        let (from, inbound_id, requested, mut responses) = inbound.unwrap();
        assert_eq!(from, client_id);
        assert_eq!(requested, Action::TimeoutOnWriteResponse);

        // Answer with the same action that was requested.
        let sent = responses.send(Action::TimeoutOnWriteResponse).await;
        sent.unwrap();

        // The inbound side must end with a timeout for that very request.
        let failure = wait_inbound_failure(&mut server).await;
        let (from, failed_id, reason) = failure.unwrap();
        assert_eq!(from, client_id);
        assert_eq!(failed_id, inbound_id);
        assert!(matches!(reason, InboundFailure::Timeout));
    };

    let client_task = async move {
        let outbound_id = client
            .behaviour_mut()
            .send_request(&server_id, Action::TimeoutOnWriteResponse);

        let sent = wait_outbound_request_sent_awaiting_responses(&mut client).await;
        let (to, sent_id, mut responses) = sent.unwrap();
        assert_eq!(to, server_id);
        assert_eq!(sent_id, outbound_id);

        // No response is expected to arrive: the channel just ends.
        let first = responses.next().await;
        assert!(first.is_none());

        // The client is then told that the response stream has been closed.
        let closed = wait_inbound_response_stream_closed(&mut client).await;
        let (to, closed_id) = closed.unwrap();
        assert_eq!(to, server_id);
        assert_eq!(closed_id, outbound_id);
    };

    // Drive both sides until each of them has finished.
    tokio::join!(client_task, server_task);
}

#[test_log::test(tokio::test)]
async fn report_inbound_timeout_on_write_response_timeout() {
// `swarm2` needs to have a bigger timeout to avoid racing
let (peer1_id, mut swarm1) = new_swarm_with_timeout(Duration::from_millis(100));
let (peer2_id, mut swarm2) = new_swarm_with_timeout(Duration::from_millis(200));
// stream timeouts are set to a larger value to make sure we trigger a
// response timeout
let (peer1_id, mut swarm1) =
new_swarm_with_timeouts(Duration::from_millis(500), Duration::from_millis(100));
let (peer2_id, mut swarm2) =
new_swarm_with_timeouts(Duration::from_millis(500), Duration::from_millis(200));

swarm1.listen().with_memory_addr_external().await;
swarm2.connect(&mut swarm1).await;
Expand Down Expand Up @@ -322,8 +438,10 @@ async fn report_inbound_timeout_on_write_response_timeout() {
#[test_log::test(tokio::test)]
async fn report_outbound_timeout_on_write_request_timeout() {
// `swarm1` needs to have a bigger timeout to avoid racing
let (peer1_id, mut swarm1) = new_swarm_with_timeout(Duration::from_millis(200));
let (_peer2_id, mut swarm2) = new_swarm_with_timeout(Duration::from_millis(100));
let (peer1_id, mut swarm1) =
new_swarm_with_timeouts(Duration::from_millis(200), Duration::from_millis(200));
let (_peer2_id, mut swarm2) =
new_swarm_with_timeouts(Duration::from_millis(100), Duration::from_millis(100));

swarm1.listen().with_memory_addr_external().await;
swarm2.connect(&mut swarm1).await;
Expand Down Expand Up @@ -357,8 +475,10 @@ async fn report_outbound_timeout_on_write_request_timeout() {
#[test_log::test(tokio::test)]
async fn report_outbound_timeout_on_read_request_timeout() {
// `swarm1` needs to have a bigger timeout to avoid racing
let (peer1_id, mut swarm1) = new_swarm_with_timeout(Duration::from_millis(200));
let (_peer2_id, mut swarm2) = new_swarm_with_timeout(Duration::from_millis(100));
let (peer1_id, mut swarm1) =
new_swarm_with_timeouts(Duration::from_millis(200), Duration::from_millis(200));
let (_peer2_id, mut swarm2) =
new_swarm_with_timeouts(Duration::from_millis(100), Duration::from_millis(100));

swarm1.listen().with_memory_addr_external().await;
swarm2.connect(&mut swarm1).await;
Expand Down
8 changes: 5 additions & 3 deletions crates/p2p_stream/tests/sanity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rstest::rstest;
pub mod utils;

use utils::{
new_swarm_with_timeout,
new_swarm_with_timeouts,
wait_inbound_request,
wait_inbound_response_stream_closed,
wait_outbound_request_sent_awaiting_responses,
Expand All @@ -35,8 +35,10 @@ struct Scenario {

// peer1 is the server, peer2 is the client
async fn setup() -> (PeerId, TestSwarm, PeerId, TestSwarm) {
let (srv_peer_id, mut srv_swarm) = new_swarm_with_timeout(Duration::from_secs(10));
let (cli_peer_id, mut cli_swarm) = new_swarm_with_timeout(Duration::from_secs(10));
let (srv_peer_id, mut srv_swarm) =
new_swarm_with_timeouts(Duration::from_secs(10), Duration::from_secs(10));
let (cli_peer_id, mut cli_swarm) =
new_swarm_with_timeouts(Duration::from_secs(10), Duration::from_secs(10));

srv_swarm.listen().with_memory_addr_external().await;
cli_swarm.connect(&mut srv_swarm).await;
Expand Down
Loading

0 comments on commit 3e6ac7f

Please sign in to comment.