From dc70505d7ef9dc609d07afd946aa528e6e14454f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 6 Mar 2025 01:07:54 -0500 Subject: [PATCH 001/206] Add initial setup for noise wrapper --- src/crypto/handshake.rs | 5 +++-- src/lib.rs | 1 + src/protocol.rs | 7 ++++--- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 64db407..ee2d941 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -69,6 +69,7 @@ impl HandshakeResult { } } +#[derive(Debug)] pub(crate) struct Handshake { result: HandshakeResult, state: HandshakeState, @@ -170,11 +171,11 @@ impl Handshake { Ok(tx_buf) } - pub(crate) fn into_result(self) -> Result { + pub(crate) fn into_result(&self) -> Result<&HandshakeResult> { if !self.complete() { Err(Error::new(ErrorKind::Other, "Handshake is not complete")) } else { - Ok(self.result) + Ok(&self.result) } } } diff --git a/src/lib.rs b/src/lib.rs index 531a068..0e5f037 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -123,6 +123,7 @@ mod constants; mod crypto; mod duplex; mod message; +mod noise; mod protocol; mod reader; mod util; diff --git a/src/protocol.rs b/src/protocol.rs index 7b8d468..3e8a2b5 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -10,6 +10,7 @@ use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; +use tracing::trace; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; @@ -286,8 +287,8 @@ where } fn init(&mut self) -> Result<()> { - tracing::debug!( - "protocol init, state {:?}, options {:?}", + trace!( + "protocol Init, state {:?}, options {:?}", self.state, self.options ); @@ -479,7 +480,7 @@ where self.state = State::Established; } // Store handshake result - self.handshake = Some(handshake_result); + self.handshake = Some(handshake_result.clone()); } Ok(()) } From 640efc9944c93dc5f7a6e0930c066ee20a6d7e74 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 7 Mar 2025 15:01:28 -0500 Subject: [PATCH 002/206] use futures --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index d77679f..a5ac273 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ futures-lite = "1" sha2 = "0.10" curve25519-dalek = "4" crypto_secretstream = "0.2" +futures = "0.3.13" [dependencies.hypercore] version = "0.14.0" From 11afd0be5df208a061426dbb7dab5ce048952ff9 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 7 Mar 2025 15:01:47 -0500 Subject: [PATCH 003/206] fix logger --- examples/replication.rs | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/examples/replication.rs b/examples/replication.rs index bf65b72..459df9f 100644 --- a/examples/replication.rs +++ b/examples/replication.rs @@ -3,23 +3,23 @@ use async_std::net::{TcpListener, TcpStream}; use async_std::prelude::*; use async_std::sync::{Arc, Mutex}; use async_std::task; -use env_logger::Env; use futures_lite::stream::StreamExt; use hypercore::{ Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, VerifyingKey, }; -use log::*; use std::collections::HashMap; use std::convert::TryInto; use std::env; use std::fmt::Debug; +use std::sync::OnceLock; +use tracing::{error, info}; use hypercore_protocol::schema::*; use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; fn main() { - init_logger(); + log(); if env::args().count() < 3 { usage(); } @@ -93,8 +93,8 @@ async fn onconnection( let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); info!("protocol created, polling for next()"); while let Some(event) = protocol.next().await { - let event = event?; info!("protocol event {:?}", event); + let event = event?; match event { Event::Handshake(_) => { if is_initiator { @@ -414,9 +414,21 @@ async fn onmessage( Ok(()) } -/// Init EnvLogger, logging info, warn and error messages to stdout. -pub fn init_logger() { - env_logger::from_env(Env::default().default_filter_or("info")).init(); +#[allow(unused)] +pub fn log() { + use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + static START_LOGS: OnceLock<()> = OnceLock::new(); + START_LOGS.get_or_init(|| { + tracing_subscriber::fmt() + .with_target(true) + .with_line_number(true) + // print when instrumented funtion enters + .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) + .with_file(true) + .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable + .without_time() + .init(); + }); } /// Log a result if it's an error. From 5f7d10b9dc8908d952bb44081909cc84ae7912a7 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 7 Mar 2025 15:02:09 -0500 Subject: [PATCH 004/206] Add test_utils --- src/lib.rs | 2 ++ src/test_utils.rs | 91 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 src/test_utils.rs diff --git a/src/lib.rs b/src/lib.rs index 0e5f037..f12bcdb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,6 +126,8 @@ mod message; mod noise; mod protocol; mod reader; +#[cfg(test)] +mod test_utils; mod util; mod writer; diff --git a/src/test_utils.rs b/src/test_utils.rs new file mode 100644 index 0000000..9eb986c --- /dev/null +++ b/src/test_utils.rs @@ -0,0 +1,91 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use async_channel::{unbounded, Receiver, SendError, Sender}; +use futures::{Sink, SinkExt, Stream, StreamExt}; + +#[derive(Debug)] +pub struct Io { + receiver: Receiver>, + sender: Sender>, +} + +impl Default for Io { + fn default() -> Self { + let (sender, receiver) = unbounded(); + Self { sender, receiver } + } +} + +impl Stream for Io { + type Item = Vec; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.receiver).poll_next(cx) + } +} + +impl Sink> for Io { + type Error = SendError>; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + let _ = self.sender.try_send(item); + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + todo!() + } +} + +#[derive(Default, Debug)] +pub struct TwoWay { + l_to_r: Io, + r_to_l: Io, +} + +impl TwoWay { + fn split_sides(self) -> (Io, Io) { + let left = Io { + sender: self.l_to_r.sender, + receiver: self.r_to_l.receiver, + }; + let right = Io { + sender: self.r_to_l.sender, + receiver: self.l_to_r.receiver, + }; + (left, right) + } +} + +pub fn create_connected() -> (Io, Io) { + TwoWay::default().split_sides() +} +#[tokio::test] +async fn way_one() { + let mut a = Io::default(); + let _ = a.send(b"hello".into()).await; + let Some(res) = a.next().await else { panic!() }; + assert_eq!(res, b"hello"); +} + +#[tokio::test] +async fn split() { + let (mut left, mut right) = (TwoWay::default()).split_sides(); + + left.send(b"hello".to_vec()).await; + let Some(res) = right.next().await else { + panic!(); + }; + assert_eq!(res, b"hello"); +} From 4ee231c741f64bc67f5b7837b84d799f620d383e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 7 Mar 2025 15:02:34 -0500 Subject: [PATCH 005/206] Add standalone noise wrapper --- src/lib.rs | 1 + src/noise.rs | 186 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 src/noise.rs diff --git a/src/lib.rs b/src/lib.rs index f12bcdb..646a93e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -136,6 +136,7 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; +pub use noise::Encrypted; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, diff --git a/src/noise.rs b/src/noise.rs new file mode 100644 index 0000000..c0efe4f --- /dev/null +++ b/src/noise.rs @@ -0,0 +1,186 @@ +use futures::{Sink, Stream}; +use std::{collections::VecDeque, io::Result, mem::replace, pin::Pin, task::Poll}; +use tracing::{error, trace, warn}; + +use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; + +#[derive(Debug)] +pub(crate) enum Step { + NotInitialized, + Handshake(Box), + SecretStream((EncryptCipher, HandshakeResult)), + Established((EncryptCipher, DecryptCipher, HandshakeResult)), +} + +/// Wrap a stream with encryption +#[derive(Debug)] +pub struct Encrypted { + io: IO, + step: Step, + is_initiator: bool, + encrypted_tx: VecDeque>, + encrypted_rx: VecDeque>, + plain_tx: VecDeque>, + plain_rx: VecDeque>, +} + +impl Encrypted +where + IO: Stream + Sink> + Send + Unpin + 'static, +{ + /// Create [`Self`] from a Stream/Sink + pub fn new(is_initiator: bool, io: IO) -> Self { + Self { + io, + is_initiator, + step: Step::NotInitialized, + encrypted_tx: Default::default(), + encrypted_rx: Default::default(), + plain_tx: Default::default(), + plain_rx: Default::default(), + } + } +} + +impl> + Send + Unpin + 'static> Stream for Encrypted { + type Item = Vec; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let Encrypted { + io, + step, + is_initiator, + encrypted_tx, + encrypted_rx, + plain_tx, + plain_rx, + .. + } = self.get_mut(); + + if let Step::Established((encryptor, decryptor, ..)) = step { + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { + if let Some(encrypted_out) = encrypted_tx.pop_front() { + let _todo = Sink::start_send(Pin::new(io), encrypted_out); + } else { + break; + } + } + + // decrypt any incromming encrypted messages + while let Some(incoming_msg) = encrypted_rx.pop_front() { + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => plain_rx.push_back(plain_msg), + Err(e) => error!("RX message failed to decrypt: {e:?}"), + } + } + + // encrypt any pending plaintext outgoinng messages + while let Some(mut plain_out) = plain_tx.pop_front() { + // it encrypts in-place?? + if let Err(_e) = encryptor.encrypt(&mut plain_out) { + todo!("We failed to encrypt our own message...?"); + } + encrypted_tx.push_back(plain_out); + } + + // emit any messages that are ready + if let Some(msg) = plain_rx.pop_front() { + Poll::Ready(Some(msg)) + } else { + Poll::Pending + } + } else { + // Still setting up + if let Ok(Some(msg)) = init(step, *is_initiator) { + // queue the init message to send first + encrypted_tx.push_front(msg); + } + while let Some(incoming_msg) = encrypted_rx.pop_front() { + if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator) { + for msg in msgs { + encrypted_tx.push_back(msg); + } + } + } + Poll::Pending + } + } +} + +fn init(step: &mut Step, is_initiator: bool) -> Result>> { + if !matches!(step, Step::NotInitialized) { + return Ok(None); + } + trace!( + "protocol Init, state {:?}, is_initiator {:?}", + step, + is_initiator + ); + let mut handshake = Handshake::new(is_initiator)?; + let out = handshake.start()?; + // next up is handshaking + *step = Step::Handshake(Box::new(handshake)); + Ok(out) +} + +fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { + match &step { + Step::NotInitialized => { + warn!("Encrypted state was reset"); + let mut handshake = Handshake::new(is_initiator)?; + let start_msg = handshake.start()?; + *step = Step::Handshake(Box::new(handshake)); + + Ok(start_msg.map(|x| vec![x]).unwrap_or(vec![])) + } + Step::Handshake(_) => { + let mut out = vec![]; + if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { + if let Some(response) = handshake.read(msg)? { + out.push(response); + } + + if handshake.complete() { + let handshake_result = handshake.into_result()?; + // The cipher will be put to use to the writer only after the peer's answer has come + let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; + out.push(init_msg); + *step = Step::SecretStream((cipher, handshake_result.clone())); + } else { + *step = Step::Handshake(handshake); + } + } + Ok(out) + } + Step::SecretStream(_) => { + if let Step::SecretStream((enc_cipher, hs_result)) = replace(step, Step::NotInitialized) + { + let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; + *step = Step::Established((enc_cipher, dec_cipher, hs_result)); + } + Ok(vec![]) + } + Step::Established((..)) => todo!(), + } +} + +#[cfg(test)] +mod tset { + use crate::test_utils::create_connected; + + use super::*; + use futures::{SinkExt, StreamExt}; + + #[tokio::test] + async fn test_encrypted() -> Result<()> { + let (left, right) = create_connected(); + let left = Encrypted::new(true, left); + let right = Encrypted::new(true, right); + //left.send(b"hello").await?; + todo!() + } +} From ca061fe7fab492b41b6641eb7e419c5c639a76ce Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 7 Mar 2025 15:04:43 -0500 Subject: [PATCH 006/206] lint test_utils --- src/test_utils.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index 9eb986c..273935f 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -7,7 +7,7 @@ use async_channel::{unbounded, Receiver, SendError, Sender}; use futures::{Sink, SinkExt, Stream, StreamExt}; #[derive(Debug)] -pub struct Io { +pub(crate) struct Io { receiver: Receiver>, sender: Sender>, } @@ -49,7 +49,7 @@ impl Sink> for Io { } #[derive(Default, Debug)] -pub struct TwoWay { +pub(crate) struct TwoWay { l_to_r: Io, r_to_l: Io, } @@ -68,7 +68,7 @@ impl TwoWay { } } -pub fn create_connected() -> (Io, Io) { +pub(crate) fn create_connected() -> (Io, Io) { TwoWay::default().split_sides() } #[tokio::test] @@ -83,7 +83,7 @@ async fn way_one() { async fn split() { let (mut left, mut right) = (TwoWay::default()).split_sides(); - left.send(b"hello".to_vec()).await; + left.send(b"hello".to_vec()).await.unwrap(); let Some(res) = right.next().await else { panic!(); }; From 3e82e0f5f9e787a2b5b780bb818af6fb5a16945a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 9 Mar 2025 00:21:12 -0500 Subject: [PATCH 007/206] wip noise --- src/noise.rs | 343 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 312 insertions(+), 31 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index c0efe4f..a28c2c8 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -1,6 +1,13 @@ use futures::{Sink, Stream}; -use std::{collections::VecDeque, io::Result, mem::replace, pin::Pin, task::Poll}; -use tracing::{error, trace, warn}; +use std::{ + collections::VecDeque, + fmt::Debug, + io::Result, + mem::replace, + pin::Pin, + task::{Context, Poll, Waker}, +}; +use tracing::{error, info, instrument, trace, warn}; use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; @@ -22,14 +29,17 @@ pub struct Encrypted { encrypted_rx: VecDeque>, plain_tx: VecDeque>, plain_rx: VecDeque>, + flush: bool, + count: usize, + name: String, } impl Encrypted where - IO: Stream + Sink> + Send + Unpin + 'static, + IO: Stream> + Sink> + Send + Unpin + 'static, { /// Create [`Self`] from a Stream/Sink - pub fn new(is_initiator: bool, io: IO) -> Self { + pub fn new(is_initiator: bool, io: IO, name: &str) -> Self { Self { io, is_initiator, @@ -38,17 +48,35 @@ where encrypted_rx: Default::default(), plain_tx: Default::default(), plain_rx: Default::default(), + flush: false, + count: 0, + name: name.to_string(), } } } -impl> + Send + Unpin + 'static> Stream for Encrypted { - type Item = Vec; +impl> + Sink> + Send + Unpin + Debug + 'static> Sink> + for Encrypted +{ + type Error = (); + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { + trace!("{} add plain tx", self.name); + self.plain_tx.push_back(item); + Ok(()) + } - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { let Encrypted { io, step, @@ -57,24 +85,233 @@ impl> + Send + Unpin + 'static> Stream for Encrypted 200 { + //panic!(); + } + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { + if let Some(encrypted_out) = encrypted_tx.pop_front() { + trace!( + "{name} enc tx send msg + {encrypted_out:?} +" + ); + let _todo = Sink::start_send(Pin::new(io), encrypted_out); + *flush = true; + } else { + break; + } + } + if *flush { + // confusing docs related to start send + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send + // First part says: + // "you must use **poll_flush** ... inorder to garuntee + // completions of send" + // Then it says: + // " It is only necessary to call poll_flush if you need to guarantee that all + // of the items placed into the Sink have been sent" + // + // So do I need to do it or not? + // must `poll_flush` be called for **anything** to send? + match Sink::poll_flush(Pin::new(io), cx) { + Poll::Ready(Ok(())) => { + *flush = false; + trace!("{name} flushed good"); + } + Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), + Poll::Pending => { + // More confusing docs + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush + // It says: + // "Returns Poll::Pending if there is more work left to do, in which case the + // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when + // poll_flush should be called again." + // Does this mean, each time this task wakes up again from this code path that + // I must trigger another poll_flush? But how would I know i need more + // flushing? + *flush = true; + } + } + } + + // pull in any incomming encrypted messages + loop { + match Stream::poll_next(Pin::new(io), cx) { + Poll::Pending => break, + Poll::Ready(None) => todo!(), + Poll::Ready(Some(encrypted_msg)) => { + trace!( + "{name} enc rx queue + {encrypted_msg:?} + ); +" + ); + encrypted_rx.push_back(encrypted_msg); + } + } + } + if let Step::Established((encryptor, decryptor, ..)) = step { - // send any pending outgoing messages - while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { - if let Some(encrypted_out) = encrypted_tx.pop_front() { - let _todo = Sink::start_send(Pin::new(io), encrypted_out); - } else { - break; + // decrypt any incromming encrypted messages + while let Some(incoming_msg) = encrypted_rx.pop_front() { + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => { + trace!("{name} plain rx queue"); + plain_rx.push_back(plain_msg); + } + Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), + } + } + + // encrypt any pending plaintext outgoinng messages + while let Some(mut plain_out) = plain_tx.pop_front() { + // it encrypts in-place?? + if let Err(_e) = encryptor.encrypt(&mut plain_out) { + todo!("{name} We failed to encrypt our own message...?"); + } + trace!("{name} enc tx queue"); + encrypted_tx.push_back(plain_out); + } + + if *flush { + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } else { + trace!("{name} doing setup"); + // Still setting up + if let Ok(Some(msg)) = init(step, *is_initiator) { + // queue the init message to send first + trace!("{name} queue initial msg"); + encrypted_tx.push_front(msg); + } + while let Some(incoming_msg) = encrypted_rx.pop_front() { + trace!("{name} recieved setup msg"); + if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator, &name) { + for msg in msgs { + trace!("{name} queue more setup msg"); + encrypted_tx.push_front(msg); + } } } + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + todo!() + } +} +impl> + Sink> + Send + Unpin + Debug + 'static> Stream + for Encrypted +{ + type Item = Vec; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Encrypted { + io, + step, + is_initiator, + encrypted_tx, + encrypted_rx, + plain_tx, + plain_rx, + flush, + count, + name, + .. + } = self.get_mut(); + + *count += 1; + if *count > 200 { + //panic!(); + } + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { + if let Some(encrypted_out) = encrypted_tx.pop_front() { + trace!( + "{name} enc tx send msg + {encrypted_out:?} +" + ); + let _todo = Sink::start_send(Pin::new(io), encrypted_out); + *flush = true; + } else { + break; + } + } + if *flush { + // confusing docs related to start send + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send + // First part says: + // "you must use **poll_flush** ... inorder to garuntee + // completions of send" + // Then it says: + // " It is only necessary to call poll_flush if you need to guarantee that all + // of the items placed into the Sink have been sent" + // + // So do I need to do it or not? + // must `poll_flush` be called for **anything** to send? + match Sink::poll_flush(Pin::new(io), cx) { + Poll::Ready(Ok(())) => { + *flush = false; + trace!("{name} flushed good"); + } + Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), + Poll::Pending => { + // More confusing docs + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush + // It says: + // "Returns Poll::Pending if there is more work left to do, in which case the + // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when + // poll_flush should be called again." + // Does this mean, each time this task wakes up again from this code path that + // I must trigger another poll_flush? But how would I know i need more + // flushing? + *flush = true; + } + } + } + + // pull in any incomming encrypted messages + loop { + match Stream::poll_next(Pin::new(io), cx) { + Poll::Pending => break, + Poll::Ready(None) => todo!(), + Poll::Ready(Some(encrypted_msg)) => { + trace!( + "{name} enc rx queue + {encrypted_msg:?} + ); +" + ); + encrypted_rx.push_back(encrypted_msg); + } + } + } + + if let Step::Established((encryptor, decryptor, ..)) = step { // decrypt any incromming encrypted messages while let Some(incoming_msg) = encrypted_rx.pop_front() { match decryptor.decrypt_buf(&incoming_msg) { - Ok((plain_msg, _tag)) => plain_rx.push_back(plain_msg), - Err(e) => error!("RX message failed to decrypt: {e:?}"), + Ok((plain_msg, _tag)) => { + trace!("{name} plain rx queue"); + plain_rx.push_back(plain_msg); + } + Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), } } @@ -82,35 +319,49 @@ impl> + Send + Unpin + 'static> Stream for Encrypted> + Sink> + Send + Unpin + 'static>( + encrypted: &mut Encrypted, + cx: &mut Context<'_>, +) { + todo!() +} + fn init(step: &mut Step, is_initiator: bool) -> Result>> { if !matches!(step, Step::NotInitialized) { return Ok(None); @@ -122,15 +373,19 @@ fn init(step: &mut Step, is_initiator: bool) -> Result>> { ); let mut handshake = Handshake::new(is_initiator)?; let out = handshake.start()?; - // next up is handshaking *step = Step::Handshake(Box::new(handshake)); Ok(out) } -fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { +fn handle_setup_message( + step: &mut Step, + msg: &[u8], + is_initiator: bool, + name: &str, +) -> Result>> { match &step { Step::NotInitialized => { - warn!("Encrypted state was reset"); + warn!("{name} Encrypted state was reset"); let mut handshake = Handshake::new(is_initiator)?; let start_msg = handshake.start()?; *step = Step::Handshake(Box::new(handshake)); @@ -140,14 +395,33 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu Step::Handshake(_) => { let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { - if let Some(response) = handshake.read(msg)? { + if let Some(response) = match handshake.read(msg) { + Ok(x) => x, + Err(e) => { + error!("error in handshake.read(msg) {e:?}"); + return Err(e); + } + } { out.push(response); } if handshake.complete() { - let handshake_result = handshake.into_result()?; + let handshake_result = match handshake.into_result() { + Ok(x) => x, + Err(e) => { + error!("into-result error {e:?}"); + return Err(e); + } + }; // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; + let (cipher, init_msg) = + match EncryptCipher::from_handshake_tx(handshake_result) { + Ok(x) => x, + Err(e) => { + error!("from_handshake_tx error {e:?}"); + return Err(e); + } + }; out.push(init_msg); *step = Step::SecretStream((cipher, handshake_result.clone())); } else { @@ -170,17 +444,24 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu #[cfg(test)] mod tset { - use crate::test_utils::create_connected; + use crate::test_utils::{create_connected, log, Io}; use super::*; use futures::{SinkExt, StreamExt}; #[tokio::test] async fn test_encrypted() -> Result<()> { + log(); let (left, right) = create_connected(); - let left = Encrypted::new(true, left); - let right = Encrypted::new(true, right); - //left.send(b"hello").await?; + let mut left = Encrypted::new(true, left, "left"); + let mut right = Encrypted::new(true, right, "right"); + tokio::task::spawn(async move { + left.send(b"hello".into()).await.unwrap(); + }); + //tokio::task::spawn(async move { + // let x = left.next().await; + //}); + dbg!(right.next().await); todo!() } } From 241622f36c65e7043935e989ad41b7c7dc67c247 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 9 Mar 2025 23:10:52 -0400 Subject: [PATCH 008/206] wip poll impl --- src/noise.rs | 134 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 133 insertions(+), 1 deletion(-) diff --git a/src/noise.rs b/src/noise.rs index a28c2c8..e977315 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -359,7 +359,139 @@ fn poll> + Sink> + Send + Unpin + 'static>( encrypted: &mut Encrypted, cx: &mut Context<'_>, ) { - todo!() + let Encrypted { + io, + step, + is_initiator, + encrypted_tx, + encrypted_rx, + plain_tx, + plain_rx, + flush, + count, + name, + .. + } = encrypted; + + *count += 1; + if *count > 200 { + //panic!(); + } + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { + if let Some(encrypted_out) = encrypted_tx.pop_front() { + trace!( + "{name} enc tx send msg + {encrypted_out:?} +" + ); + let _todo = Sink::start_send(Pin::new(io), encrypted_out); + *flush = true; + } else { + break; + } + } + if *flush { + // confusing docs related to start send + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send + // First part says: + // "you must use **poll_flush** ... inorder to garuntee + // completions of send" + // Then it says: + // " It is only necessary to call poll_flush if you need to guarantee that all + // of the items placed into the Sink have been sent" + // + // So do I need to do it or not? + // must `poll_flush` be called for **anything** to send? + match Sink::poll_flush(Pin::new(io), cx) { + Poll::Ready(Ok(())) => { + *flush = false; + trace!("{name} flushed good"); + } + Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), + Poll::Pending => { + // More confusing docs + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush + // It says: + // "Returns Poll::Pending if there is more work left to do, in which case the + // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when + // poll_flush should be called again." + // Does this mean, each time this task wakes up again from this code path that + // I must trigger another poll_flush? But how would I know i need more + // flushing? + *flush = true; + } + } + } + + // pull in any incomming encrypted messages + loop { + match Stream::poll_next(Pin::new(io), cx) { + Poll::Pending => break, + Poll::Ready(None) => todo!(), + Poll::Ready(Some(encrypted_msg)) => { + trace!( + "{name} enc rx queue + {encrypted_msg:?} + ); +" + ); + encrypted_rx.push_back(encrypted_msg); + } + } + } + + if let Step::Established((encryptor, decryptor, ..)) = step { + // decrypt any incromming encrypted messages + while let Some(incoming_msg) = encrypted_rx.pop_front() { + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => { + trace!("{name} plain rx queue"); + plain_rx.push_back(plain_msg); + } + Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), + } + } + + // encrypt any pending plaintext outgoinng messages + while let Some(mut plain_out) = plain_tx.pop_front() { + // it encrypts in-place?? + if let Err(_e) = encryptor.encrypt(&mut plain_out) { + todo!("{name} We failed to encrypt our own message...?"); + } + trace!("{name} enc tx queue"); + encrypted_tx.push_back(plain_out); + } + + // emit any messages that are ready + if let Some(msg) = plain_rx.pop_front() { + trace!("{name} plain rx emit"); + //Poll::Ready(Some(msg)) + todo!() + } else { + //Poll::Pending + todo!() + } + } else { + trace!("{name} doing setup"); + // Still setting up + if let Ok(Some(msg)) = init(step, *is_initiator) { + // queue the init message to send first + trace!("{name} queue initial msg"); + encrypted_tx.push_front(msg); + } + while let Some(incoming_msg) = encrypted_rx.pop_front() { + trace!("{name} recieved setup msg"); + if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator, &name) { + for msg in msgs { + trace!("{name} queue more setup msg"); + encrypted_tx.push_front(msg); + } + } + } + cx.waker().wake_by_ref(); + todo!() + } } fn init(step: &mut Step, is_initiator: bool) -> Result>> { From 2591d2889eab4dfb358bde75a7ef61337d271146 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 10 Mar 2025 16:51:07 -0400 Subject: [PATCH 009/206] Add start_raw and read_raw --- src/crypto/handshake.rs | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index ee2d941..21c5442 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -101,15 +101,17 @@ impl Handshake { }) } - pub(crate) fn start(&mut self) -> Result>> { + pub(crate) fn start_raw(&mut self) -> Result>> { if self.is_initiator() { let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); - Ok(Some(wrapped)) + Ok(Some(self.tx_buf[..tx_len].to_vec())) } else { Ok(None) } } + pub(crate) fn start(&mut self) -> Result>> { + Ok(self.start_raw()?.map(|x| wrap_uint24_le(&x))) + } pub(crate) fn complete(&self) -> bool { self.complete @@ -124,13 +126,13 @@ impl Handshake { .read_message(msg, &mut self.rx_buf) .map_err(map_err) } - fn send(&mut self) -> Result { + pub(crate) fn send(&mut self) -> Result { self.state .write_message(&self.payload, &mut self.tx_buf) .map_err(map_err) } - pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { + pub(crate) fn read_raw(&mut self, msg: &[u8]) -> Result>> { // eprintln!("hs read len {}", msg.len()); if self.complete() { return Err(Error::new(ErrorKind::Other, "Handshake read after finish")); @@ -138,16 +140,17 @@ impl Handshake { let _rx_len = self.recv(msg)?; + // first non-init if !self.is_initiator() && !self.did_receive { self.did_receive = true; let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); + let wrapped = self.tx_buf[..tx_len].to_vec(); return Ok(Some(wrapped)); } let tx_buf = if self.is_initiator() { let tx_len = self.send()?; - let wrapped = wrap_uint24_le(&self.tx_buf[..tx_len].to_vec()); + let wrapped = self.tx_buf[..tx_len].to_vec(); Some(wrapped) } else { None @@ -170,6 +173,10 @@ impl Handshake { self.complete = true; Ok(tx_buf) } + // reads in `msg` without framing bytes, but emits msg WITH framing bytes + pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { + Ok(self.read_raw(msg)?.map(|x| wrap_uint24_le(&x))) + } pub(crate) fn into_result(&self) -> Result<&HandshakeResult> { if !self.complete() { From 400d57d354e02e0bd0b3efdc3fe5f958ff1bd73a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 11 Mar 2025 01:30:52 -0400 Subject: [PATCH 010/206] lint --- src/noise.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index e977315..86de5b3 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -5,9 +5,9 @@ use std::{ io::Result, mem::replace, pin::Pin, - task::{Context, Poll, Waker}, + task::{Context, Poll}, }; -use tracing::{error, info, instrument, trace, warn}; +use tracing::{error, info, trace, warn}; use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; @@ -210,7 +210,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static fn poll_close( self: Pin<&mut Self>, - cx: &mut Context<'_>, + _cx: &mut Context<'_>, ) -> Poll> { todo!() } @@ -355,7 +355,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } -fn poll> + Sink> + Send + Unpin + 'static>( +fn _poll> + Sink> + Send + Unpin + 'static>( encrypted: &mut Encrypted, cx: &mut Context<'_>, ) { @@ -464,7 +464,7 @@ fn poll> + Sink> + Send + Unpin + 'static>( } // emit any messages that are ready - if let Some(msg) = plain_rx.pop_front() { + if let Some(_msg) = plain_rx.pop_front() { trace!("{name} plain rx emit"); //Poll::Ready(Some(msg)) todo!() From 7c9e2f53ba1aa026bacf00e9dc2584b363b65820 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 11 Mar 2025 01:31:37 -0400 Subject: [PATCH 011/206] use read_raw & start_raw --- src/noise.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 86de5b3..467008e 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -504,7 +504,7 @@ fn init(step: &mut Step, is_initiator: bool) -> Result>> { is_initiator ); let mut handshake = Handshake::new(is_initiator)?; - let out = handshake.start()?; + let out = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); Ok(out) } @@ -519,7 +519,7 @@ fn handle_setup_message( Step::NotInitialized => { warn!("{name} Encrypted state was reset"); let mut handshake = Handshake::new(is_initiator)?; - let start_msg = handshake.start()?; + let start_msg = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); Ok(start_msg.map(|x| vec![x]).unwrap_or(vec![])) @@ -527,10 +527,10 @@ fn handle_setup_message( Step::Handshake(_) => { let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { - if let Some(response) = match handshake.read(msg) { + if let Some(response) = match handshake.read_raw(msg) { Ok(x) => x, Err(e) => { - error!("error in handshake.read(msg) {e:?}"); + error!("error in handshake.read_raw(msg) {e:?}"); return Err(e); } } { From 20302eb4a01a8c238e4db89345f4768c749379fb Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 11 Mar 2025 17:23:52 -0400 Subject: [PATCH 012/206] Encrypted stream is now working! --- src/noise.rs | 354 +++++++++++++++++++-------------------------------- 1 file changed, 128 insertions(+), 226 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 467008e..c22c1fd 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -7,16 +7,26 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tracing::{error, info, trace, warn}; +use tracing::{error, info, instrument, trace, warn}; -use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; +use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult, RawEncryptCipher}; + +macro_rules! name { + ($name:tt) => {{ + if $name { + "initiator" + } else { + "other" + } + }}; +} #[derive(Debug)] pub(crate) enum Step { NotInitialized, Handshake(Box), - SecretStream((EncryptCipher, HandshakeResult)), - Established((EncryptCipher, DecryptCipher, HandshakeResult)), + SecretStream((RawEncryptCipher, HandshakeResult)), + Established((RawEncryptCipher, DecryptCipher, HandshakeResult)), } /// Wrap a stream with encryption @@ -31,15 +41,22 @@ pub struct Encrypted { plain_rx: VecDeque>, flush: bool, count: usize, - name: String, } +fn ename(is_initiator: bool) -> String { + if is_initiator { + "initiator".to_string() + } else { + "other".to_string() + } +} impl Encrypted where - IO: Stream> + Sink> + Send + Unpin + 'static, + IO: Stream> + Sink> + Send + Unpin + Debug + 'static, { /// Create [`Self`] from a Stream/Sink - pub fn new(is_initiator: bool, io: IO, name: &str) -> Self { + #[instrument(skip_all, fields(name = %ename(is_initiator)))] + pub fn new(is_initiator: bool, io: IO) -> Self { Self { io, is_initiator, @@ -50,7 +67,6 @@ where plain_rx: Default::default(), flush: false, count: 0, - name: name.to_string(), } } } @@ -67,12 +83,14 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Ready(Ok(())) } + #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { - trace!("{} add plain tx", self.name); + trace!("add plain tx"); self.plain_tx.push_back(item); Ok(()) } + #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -87,7 +105,6 @@ impl> + Sink> + Send + Unpin + Debug + 'static plain_rx, flush, count, - name, .. } = self.get_mut(); @@ -99,9 +116,8 @@ impl> + Sink> + Send + Unpin + Debug + 'static while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { trace!( - "{name} enc tx send msg - {encrypted_out:?} -" + name = %ename(*is_initiator), + "enc tx send msg\n{encrypted_out:?}" ); let _todo = Sink::start_send(Pin::new(io), encrypted_out); *flush = true; @@ -124,9 +140,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - trace!("{name} flushed good"); + trace!(name = %ename(*is_initiator), "flushed good"); } - Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), + Poll::Ready(Err(_e)) => error!( + name = %ename(*is_initiator), + "Error sending encrypted msg"), Poll::Pending => { // More confusing docs // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush @@ -148,12 +166,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Pending => break, Poll::Ready(None) => todo!(), Poll::Ready(Some(encrypted_msg)) => { - trace!( - "{name} enc rx queue - {encrypted_msg:?} - ); -" - ); + trace!(name = %ename(*is_initiator), "enc rx queue\n{encrypted_msg:?}"); encrypted_rx.push_back(encrypted_msg); } } @@ -164,41 +177,45 @@ impl> + Sink> + Send + Unpin + Debug + 'static while let Some(incoming_msg) = encrypted_rx.pop_front() { match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!("{name} plain rx queue"); + trace!(name = %ename(*is_initiator), "plain rx queue"); plain_rx.push_back(plain_msg); } - Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), + Err(e) => { + error!(name = %ename(*is_initiator), "RX message failed to decrypt: {e:?}") + } } } // encrypt any pending plaintext outgoinng messages while let Some(mut plain_out) = plain_tx.pop_front() { // it encrypts in-place?? - if let Err(_e) = encryptor.encrypt(&mut plain_out) { - todo!("{name} We failed to encrypt our own message...?"); - } - trace!("{name} enc tx queue"); - encrypted_tx.push_back(plain_out); + let enc_out = match encryptor.encrypt(&mut plain_out) { + Ok(x) => x, + Err(_e) => todo!("We failed to encrypt our own message...?"), + }; + trace!(name = %ename(*is_initiator), "enc from plain tx queue\n{enc_out:?}"); + encrypted_tx.push_back(enc_out); + *flush = true; } if *flush { + cx.waker().wake_by_ref(); Poll::Pending } else { Poll::Ready(Ok(())) } } else { - trace!("{name} doing setup"); // Still setting up - if let Ok(Some(msg)) = init(step, *is_initiator) { + if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { // queue the init message to send first - trace!("{name} queue initial msg"); + trace!(name = %ename(*is_initiator),"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!("{name} recieved setup msg"); - if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator, &name) { - for msg in msgs { - trace!("{name} queue more setup msg"); + trace!(name = %ename(*is_initiator),"recieved setup msg"); + if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator) { + for msg in msgs.into_iter().rev() { + trace!(name = %ename(*is_initiator),"queue more setup msg\n{msg:?}"); encrypted_tx.push_front(msg); } } @@ -208,6 +225,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } + #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] fn poll_close( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -220,6 +238,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static { type Item = Vec; + #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -231,7 +250,6 @@ impl> + Sink> + Send + Unpin + Debug + 'static plain_rx, flush, count, - name, .. } = self.get_mut(); @@ -242,8 +260,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!( - "{name} enc tx send msg + trace!(name = %ename(*is_initiator), "enc tx send msg {encrypted_out:?} " ); @@ -268,9 +285,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - trace!("{name} flushed good"); + trace!(name = %ename(*is_initiator), "flushed good"); + } + Poll::Ready(Err(_e)) => { + error!(name = %ename(*is_initiator), "Error sending encrypted msg") } - Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), Poll::Pending => { // More confusing docs // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush @@ -290,14 +309,9 @@ impl> + Sink> + Send + Unpin + Debug + 'static loop { match Stream::poll_next(Pin::new(io), cx) { Poll::Pending => break, - Poll::Ready(None) => todo!(), + Poll::Ready(None) => break, Poll::Ready(Some(encrypted_msg)) => { - trace!( - "{name} enc rx queue - {encrypted_msg:?} - ); -" - ); + trace!(name = %ename(*is_initiator), "enc rx queue\n{encrypted_msg:?}"); encrypted_rx.push_back(encrypted_msg); } } @@ -308,43 +322,50 @@ impl> + Sink> + Send + Unpin + Debug + 'static while let Some(incoming_msg) = encrypted_rx.pop_front() { match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!("{name} plain rx queue"); + trace!(name = %ename(*is_initiator), "plain rx queue"); plain_rx.push_back(plain_msg); } - Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), + Err(e) => { + error!(name = %ename(*is_initiator),"RX message failed to decrypt: {e:?}") + } } } // encrypt any pending plaintext outgoinng messages while let Some(mut plain_out) = plain_tx.pop_front() { - // it encrypts in-place?? - if let Err(_e) = encryptor.encrypt(&mut plain_out) { - todo!("{name} We failed to encrypt our own message...?"); - } - trace!("{name} enc tx queue"); - encrypted_tx.push_back(plain_out); + let enc_out = match encryptor.encrypt(&mut plain_out) { + Ok(x) => x, + Err(_e) => todo!("We failed to encrypt our own message...?"), + }; + trace!(name = %ename(*is_initiator), "enc from plain tx queue\n{enc_out:?}"); + encrypted_tx.push_back(enc_out); } // emit any messages that are ready if let Some(msg) = plain_rx.pop_front() { - trace!("{name} plain rx emit"); + trace!(name = %ename(*is_initiator), "plain rx emit"); Poll::Ready(Some(msg)) } else { Poll::Pending } } else { - trace!("{name} doing setup"); // Still setting up - if let Ok(Some(msg)) = init(step, *is_initiator) { + if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { // queue the init message to send first - trace!("{name} queue initial msg"); + trace!(name = %ename(*is_initiator),"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!("{name} recieved setup msg"); - if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator, &name) { - for msg in msgs { - trace!("{name} queue more setup msg"); + trace!(name = %ename(*is_initiator), "recieved setup msg"); + if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, *is_initiator) { + Ok(x) => Ok(x), + Err(e) => { + error!("handle_setup_message error: {e:?}"); + Err(e) + } + } { + for msg in msgs.into_iter().rev() { + trace!(name = %ename(*is_initiator),"queue more setup msg\n{msg:?}"); encrypted_tx.push_front(msg); } } @@ -355,169 +376,22 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } -fn _poll> + Sink> + Send + Unpin + 'static>( - encrypted: &mut Encrypted, - cx: &mut Context<'_>, -) { - let Encrypted { - io, - step, - is_initiator, - encrypted_tx, - encrypted_rx, - plain_tx, - plain_rx, - flush, - count, - name, - .. - } = encrypted; - - *count += 1; - if *count > 200 { - //panic!(); - } - // send any pending outgoing messages - while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { - if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!( - "{name} enc tx send msg - {encrypted_out:?} -" - ); - let _todo = Sink::start_send(Pin::new(io), encrypted_out); - *flush = true; - } else { - break; - } - } - if *flush { - // confusing docs related to start send - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send - // First part says: - // "you must use **poll_flush** ... inorder to garuntee - // completions of send" - // Then it says: - // " It is only necessary to call poll_flush if you need to guarantee that all - // of the items placed into the Sink have been sent" - // - // So do I need to do it or not? - // must `poll_flush` be called for **anything** to send? - match Sink::poll_flush(Pin::new(io), cx) { - Poll::Ready(Ok(())) => { - *flush = false; - trace!("{name} flushed good"); - } - Poll::Ready(Err(_e)) => error!("{name} Error sending encrypted msg"), - Poll::Pending => { - // More confusing docs - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush - // It says: - // "Returns Poll::Pending if there is more work left to do, in which case the - // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when - // poll_flush should be called again." - // Does this mean, each time this task wakes up again from this code path that - // I must trigger another poll_flush? But how would I know i need more - // flushing? - *flush = true; - } - } - } - - // pull in any incomming encrypted messages - loop { - match Stream::poll_next(Pin::new(io), cx) { - Poll::Pending => break, - Poll::Ready(None) => todo!(), - Poll::Ready(Some(encrypted_msg)) => { - trace!( - "{name} enc rx queue - {encrypted_msg:?} - ); -" - ); - encrypted_rx.push_back(encrypted_msg); - } - } - } - - if let Step::Established((encryptor, decryptor, ..)) = step { - // decrypt any incromming encrypted messages - while let Some(incoming_msg) = encrypted_rx.pop_front() { - match decryptor.decrypt_buf(&incoming_msg) { - Ok((plain_msg, _tag)) => { - trace!("{name} plain rx queue"); - plain_rx.push_back(plain_msg); - } - Err(e) => error!("{name} RX message failed to decrypt: {e:?}"), - } - } - - // encrypt any pending plaintext outgoinng messages - while let Some(mut plain_out) = plain_tx.pop_front() { - // it encrypts in-place?? - if let Err(_e) = encryptor.encrypt(&mut plain_out) { - todo!("{name} We failed to encrypt our own message...?"); - } - trace!("{name} enc tx queue"); - encrypted_tx.push_back(plain_out); - } - - // emit any messages that are ready - if let Some(_msg) = plain_rx.pop_front() { - trace!("{name} plain rx emit"); - //Poll::Ready(Some(msg)) - todo!() - } else { - //Poll::Pending - todo!() - } - } else { - trace!("{name} doing setup"); - // Still setting up - if let Ok(Some(msg)) = init(step, *is_initiator) { - // queue the init message to send first - trace!("{name} queue initial msg"); - encrypted_tx.push_front(msg); - } - while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!("{name} recieved setup msg"); - if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator, &name) { - for msg in msgs { - trace!("{name} queue more setup msg"); - encrypted_tx.push_front(msg); - } - } - } - cx.waker().wake_by_ref(); - todo!() - } -} - -fn init(step: &mut Step, is_initiator: bool) -> Result>> { +fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { if !matches!(step, Step::NotInitialized) { return Ok(None); } - trace!( - "protocol Init, state {:?}, is_initiator {:?}", - step, - is_initiator - ); + trace!(name = %ename(is_initiator), "Init, state {step:?}"); let mut handshake = Handshake::new(is_initiator)?; let out = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); Ok(out) } -fn handle_setup_message( - step: &mut Step, - msg: &[u8], - is_initiator: bool, - name: &str, -) -> Result>> { +#[instrument(skip_all, fields(name = %ename(is_initiator)))] +fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { match &step { Step::NotInitialized => { - warn!("{name} Encrypted state was reset"); + warn!("{} Encrypted state was reset", name!(is_initiator)); let mut handshake = Handshake::new(is_initiator)?; let start_msg = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); @@ -527,17 +401,23 @@ fn handle_setup_message( Step::Handshake(_) => { let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { + trace!("Read in handshake msg\n{msg:?}"); if let Some(response) = match handshake.read_raw(msg) { Ok(x) => x, Err(e) => { - error!("error in handshake.read_raw(msg) {e:?}"); + panic!("error in handshake.read_raw(msg) {e:?}"); return Err(e); } } { + info!( + "{} read message and emitting response {response:?}", + name!(is_initiator) + ); out.push(response); } if handshake.complete() { + info!("{} HS complete. Making result", name!(is_initiator)); let handshake_result = match handshake.into_result() { Ok(x) => x, Err(e) => { @@ -547,13 +427,14 @@ fn handle_setup_message( }; // The cipher will be put to use to the writer only after the peer's answer has come let (cipher, init_msg) = - match EncryptCipher::from_handshake_tx(handshake_result) { + match RawEncryptCipher::from_handshake_tx(handshake_result) { Ok(x) => x, Err(e) => { error!("from_handshake_tx error {e:?}"); return Err(e); } }; + info!("{} made enc cipher", name!(is_initiator)); out.push(init_msg); *step = Step::SecretStream((cipher, handshake_result.clone())); } else { @@ -563,6 +444,7 @@ fn handle_setup_message( Ok(out) } Step::SecretStream(_) => { + info!("E're a secret stream now!!!!!"); if let Step::SecretStream((enc_cipher, hs_result)) = replace(step, Step::NotInitialized) { let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; @@ -576,24 +458,44 @@ fn handle_setup_message( #[cfg(test)] mod tset { - use crate::test_utils::{create_connected, log, Io}; + use crate::test_utils::{create_connected, log}; use super::*; use futures::{SinkExt, StreamExt}; + #[tokio::test] + async fn steps() -> Result<()> { + // figure out handshake problem + let mut left_hs = Handshake::new(true)?; + let s1 = left_hs.start_raw()?.unwrap(); + + println!("s1 {s1:?}"); + let mut right_hs = Handshake::new(false)?; + + let s2 = right_hs.read_raw(&s1)?.unwrap(); + println!("s2 {s2:?}"); + + let s3 = left_hs.read_raw(&s2)?.unwrap(); + println!("s3 {s3:?}"); + + let s4 = right_hs.read_raw(&s3)?; + + println!("s4 {s4:?}"); + Ok(()) + } + #[tokio::test] async fn test_encrypted() -> Result<()> { log(); + let expected = b"hello"; let (left, right) = create_connected(); - let mut left = Encrypted::new(true, left, "left"); - let mut right = Encrypted::new(true, right, "right"); + let mut left = Encrypted::new(true, left); + let mut right = Encrypted::new(false, right); tokio::task::spawn(async move { - left.send(b"hello".into()).await.unwrap(); + left.send(expected.into()).await.unwrap(); }); - //tokio::task::spawn(async move { - // let x = left.next().await; - //}); - dbg!(right.next().await); - todo!() + let result = right.next().await.unwrap(); + assert_eq!(result, expected); + Ok(()) } } From 792358dfa1eab32e9aed497bd7e96a289a6c8d5e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 11 Mar 2025 17:29:41 -0400 Subject: [PATCH 013/206] Add RawEncrytpCipher --- src/crypto/cipher.rs | 78 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index c0e54a9..28ff1f4 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -184,3 +184,81 @@ fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { let result = result.as_slice(); out.copy_from_slice(result); } + +//NB "raw" here means UN-framed. No frame header. +const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; + +pub(crate) struct RawDecryptCipher { + pull_stream: PullStream, +} + +pub(crate) struct RawEncryptCipher { + push_stream: PushStream, +} +impl std::fmt::Debug for RawDecryptCipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DecryptCipher(crypto_secretstream)") + } +} + +impl std::fmt::Debug for RawEncryptCipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "EncryptCipher(crypto_secretstream)") + } +} + +impl RawEncryptCipher { + pub(crate) fn from_handshake_tx( + handshake_result: &HandshakeResult, + ) -> std::io::Result<(Self, Vec)> { + let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] + .try_into() + .expect("split_tx with incorrect length"); + let key = Key::from(key); + + let mut header_message: [u8; RAW_HEADER_MSG_LEN] = [0; RAW_HEADER_MSG_LEN]; + + write_stream_id( + &handshake_result.handshake_hash, + handshake_result.is_initiator, + &mut header_message[..STREAM_ID_LENGTH], + ); + + let (header, push_stream) = PushStream::init(OsRng, &key); + let header = header.as_ref(); + header_message[STREAM_ID_LENGTH..].copy_from_slice(header); + let msg = header_message.to_vec(); + Ok((Self { push_stream }, msg)) + } + + // Possible API's: + // encrypted message is (tag + encrypted + mac ) + // to have *zero* alocations we could + // * take a buffer of the expected final length, plantext starts at 1 to 1 + planetext.len() + // * final length is 1 + plaintext.len() + mac.len() + // * we write tag to 0 + // * encrypt plain text part in place + // * write mac to end + // + // it would be akward to take an array like this. We could infer the plaintext via the buffer + // it's range would be (1..(buf.len() - mac.len())) + // encypt-in-place the palintext, + // For now... let's just return the encrypted buffer + /// Encrypts message in the given buffer to the same buffer, returns number of byte + pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result> { + let mut out = buf.to_vec(); + self.push_stream + .push(&mut out, &[], Tag::Message) + .map_err(|err| { + io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) + })?; + Ok(out) + } + /// Get the length needed for encryption, that includes padding. + pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { + // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe + // extra room. + // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ + plaintext_len + 2 * 15 + } +} From 1f6fb96dd9a01e15e41db3f0f520034d3188fd91 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 12 Mar 2025 14:31:46 -0400 Subject: [PATCH 014/206] More tests, logging, lints --- src/crypto/cipher.rs | 41 ++++++++++++++------- src/crypto/handshake.rs | 3 ++ src/crypto/mod.rs | 2 +- src/noise.rs | 80 +++++++++++++++++++++++++---------------- src/protocol.rs | 4 +-- src/test_utils.rs | 19 ++++++++++ src/util.rs | 2 +- 7 files changed, 104 insertions(+), 47 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 28ff1f4..5c8d8a2 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -188,17 +188,9 @@ fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { //NB "raw" here means UN-framed. No frame header. const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; -pub(crate) struct RawDecryptCipher { - pull_stream: PullStream, -} - pub(crate) struct RawEncryptCipher { push_stream: PushStream, -} -impl std::fmt::Debug for RawDecryptCipher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DecryptCipher(crypto_secretstream)") - } + buf: Vec, } impl std::fmt::Debug for RawEncryptCipher { @@ -228,7 +220,13 @@ impl RawEncryptCipher { let header = header.as_ref(); header_message[STREAM_ID_LENGTH..].copy_from_slice(header); let msg = header_message.to_vec(); - Ok((Self { push_stream }, msg)) + Ok(( + Self { + push_stream, + buf: Default::default(), + }, + msg, + )) } // Possible API's: @@ -244,9 +242,12 @@ impl RawEncryptCipher { // it's range would be (1..(buf.len() - mac.len())) // encypt-in-place the palintext, // For now... let's just return the encrypted buffer - /// Encrypts message in the given buffer to the same buffer, returns number of byte - pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result> { - let mut out = buf.to_vec(); + // + /// Encrypts `msg` and returns the encrypted bytes + pub(crate) fn encrypt(&mut self, msg: &[u8]) -> io::Result> { + // NB: the result is written in place to the provided, however the buffer must be able to + // grow, since the encrypted message is bigger. So here we convert the slice to a vec. + let mut out = msg.to_vec(); self.push_stream .push(&mut out, &[], Tag::Message) .map_err(|err| { @@ -254,6 +255,20 @@ impl RawEncryptCipher { })?; Ok(out) } + + pub(crate) fn encrypt_in_place<'a>(&'a mut self, msg: &[u8]) -> io::Result<&'a [u8]> { + let min_safe_length = self.safe_encrypted_len(msg.len()); + if self.buf.len() < min_safe_length { + self.buf.resize(min_safe_length, 0); + } + // write message starting at index 1. we write the tag to index zero + self.buf[1..].copy_from_slice(msg); + // insert tag + // let enc_len = self.encrypt_no_alloc(&mut self.buff, 1..(1 + msg.len()))?; + // self.buf[..enc_len] + todo!() + } + /// Get the length needed for encryption, that includes padding. pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 21c5442..74a1ada 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -7,6 +7,7 @@ use blake2::{ use snow::resolvers::{DefaultResolver, FallbackResolver}; use snow::{Builder, Error as SnowError, HandshakeState}; use std::io::{Error, ErrorKind, Result}; +use tracing::instrument; const CIPHERKEYLEN: usize = 32; const HANDSHAKE_PATTERN: &str = "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b"; @@ -81,6 +82,7 @@ pub(crate) struct Handshake { } impl Handshake { + #[instrument] pub(crate) fn new(is_initiator: bool) -> Result { let (state, local_pubkey) = build_handshake_state(is_initiator).map_err(map_err)?; @@ -132,6 +134,7 @@ impl Handshake { .map_err(map_err) } + #[instrument(skip_all, fields(is_initiator = %self.result.is_initiator))] pub(crate) fn read_raw(&mut self, msg: &[u8]) -> Result>> { // eprintln!("hs read len {}", msg.len()); if self.complete() { diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 66bb62d..27f12b4 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,5 +1,5 @@ mod cipher; mod curve; mod handshake; -pub(crate) use cipher::{DecryptCipher, EncryptCipher}; +pub(crate) use cipher::{DecryptCipher, EncryptCipher, RawEncryptCipher}; pub(crate) use handshake::{Handshake, HandshakeResult}; diff --git a/src/noise.rs b/src/noise.rs index c22c1fd..beff242 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -7,19 +7,9 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tracing::{error, info, instrument, trace, warn}; +use tracing::{debug, error, info, instrument, trace, warn}; -use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult, RawEncryptCipher}; - -macro_rules! name { - ($name:tt) => {{ - if $name { - "initiator" - } else { - "other" - } - }}; -} +use crate::crypto::{DecryptCipher, Handshake, HandshakeResult, RawEncryptCipher}; #[derive(Debug)] pub(crate) enum Step { @@ -28,6 +18,20 @@ pub(crate) enum Step { SecretStream((RawEncryptCipher, HandshakeResult)), Established((RawEncryptCipher, DecryptCipher, HandshakeResult)), } +impl std::fmt::Display for Step { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Step::NotInitialized => "NotInitialized", + Step::Handshake(_) => "Handshake", + Step::SecretStream(_) => "SecretStream", + Step::Established(_) => "Established", + } + ) + } +} /// Wrap a stream with encryption #[derive(Debug)] @@ -55,7 +59,7 @@ where IO: Stream> + Sink> + Send + Unpin + Debug + 'static, { /// Create [`Self`] from a Stream/Sink - #[instrument(skip_all, fields(name = %ename(is_initiator)))] + #[instrument(skip_all, fields(is_initiator = %is_initiator))] pub fn new(is_initiator: bool, io: IO) -> Self { Self { io, @@ -83,14 +87,14 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Ready(Ok(())) } - #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] + #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { trace!("add plain tx"); self.plain_tx.push_back(item); Ok(()) } - #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] + #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -225,7 +229,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } - #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] + #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] fn poll_close( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -238,7 +242,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static { type Item = Vec; - #[instrument(skip_all, fields(name = %ename(self.is_initiator)))] + #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -387,14 +391,15 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { Ok(out) } -#[instrument(skip_all, fields(name = %ename(is_initiator)))] +#[instrument(skip_all, fields(is_initiator = %is_initiator))] fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { match &step { Step::NotInitialized => { - warn!("{} Encrypted state was reset", name!(is_initiator)); + warn!(initiator = %is_initiator, "Encrypted state was reset"); let mut handshake = Handshake::new(is_initiator)?; let start_msg = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); + debug!(initiator = %is_initiator, "Step changed to {step}"); Ok(start_msg.map(|x| vec![x]).unwrap_or(vec![])) } @@ -405,19 +410,18 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu if let Some(response) = match handshake.read_raw(msg) { Ok(x) => x, Err(e) => { - panic!("error in handshake.read_raw(msg) {e:?}"); return Err(e); } } { info!( - "{} read message and emitting response {response:?}", - name!(is_initiator) + initiator = %is_initiator, + "read message and emitting response {response:?}", ); out.push(response); } if handshake.complete() { - info!("{} HS complete. Making result", name!(is_initiator)); + debug!(initiator = %is_initiator, "HS complete. Making result"); let handshake_result = match handshake.into_result() { Ok(x) => x, Err(e) => { @@ -434,9 +438,9 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu return Err(e); } }; - info!("{} made enc cipher", name!(is_initiator)); out.push(init_msg); *step = Step::SecretStream((cipher, handshake_result.clone())); + debug!(initiator = %is_initiator, "Step changed to {step}"); } else { *step = Step::Handshake(handshake); } @@ -449,6 +453,7 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu { let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; *step = Step::Established((enc_cipher, dec_cipher, hs_result)); + debug!(initiator = %is_initiator, "Step changed to {step}"); } Ok(vec![]) } @@ -465,7 +470,6 @@ mod tset { #[tokio::test] async fn steps() -> Result<()> { - // figure out handshake problem let mut left_hs = Handshake::new(true)?; let s1 = left_hs.start_raw()?.unwrap(); @@ -481,21 +485,37 @@ mod tset { let s4 = right_hs.read_raw(&s3)?; println!("s4 {s4:?}"); + // both sides now ready Ok(()) } #[tokio::test] async fn test_encrypted() -> Result<()> { log(); - let expected = b"hello"; + let hello = b"hello"; + let world = b"world"; let (left, right) = create_connected(); let mut left = Encrypted::new(true, left); let mut right = Encrypted::new(false, right); - tokio::task::spawn(async move { - left.send(expected.into()).await.unwrap(); + + // NB: we cannot totally finish 'left.send' until the other side becomes active + // this is because the handshake with the other side ('right') must complete + // before the message is sent. So we must spawn here, so we can proceed to run 'right' + let left_handle = tokio::task::spawn(async move { + left.send(hello.into()).await.unwrap(); + left }); - let result = right.next().await.unwrap(); - assert_eq!(result, expected); + + // right recieves left's message + assert_eq!(right.next().await.unwrap(), hello); + + let mut left = left_handle.await?; + + // now that the encrypted channel is established, we don't need to spawn. + right.send(world.into()).await.unwrap(); + + // left recieves right's message + assert_eq!(left.next().await.unwrap(), world); Ok(()) } } diff --git a/src/protocol.rs b/src/protocol.rs index 3e8a2b5..9d1ebe9 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -10,7 +10,7 @@ use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tracing::trace; +use tracing::{info, trace}; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; @@ -466,7 +466,7 @@ where if self.options.encrypted { // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = EncryptCipher::from_handshake_tx(&handshake_result)?; + let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; self.state = State::SecretStream(Some(cipher)); // Send the secret stream init message header to the other side diff --git a/src/test_utils.rs b/src/test_utils.rs index 273935f..7d8c3a7 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,5 +1,6 @@ use std::{ pin::Pin, + sync::OnceLock, task::{Context, Poll}, }; @@ -71,6 +72,24 @@ impl TwoWay { pub(crate) fn create_connected() -> (Io, Io) { TwoWay::default().split_sides() } + +#[allow(dead_code)] +pub(crate) fn log() { + use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + static START_LOGS: OnceLock<()> = OnceLock::new(); + START_LOGS.get_or_init(|| { + tracing_subscriber::fmt() + .with_target(true) + .with_line_number(true) + // print when instrumented funtion enters + .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) + .with_file(true) + .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable + .without_time() + .init(); + }); +} + #[tokio::test] async fn way_one() { let mut a = Io::default(); diff --git a/src/util.rs b/src/util.rs index c99ff9c..1350728 100644 --- a/src/util.rs +++ b/src/util.rs @@ -31,7 +31,7 @@ pub(crate) fn map_channel_err(err: async_channel::SendError) -> Error { pub(crate) const UINT_24_LENGTH: usize = 3; #[inline] -pub(crate) fn wrap_uint24_le(data: &Vec) -> Vec { +pub(crate) fn wrap_uint24_le(data: &[u8]) -> Vec { let mut buf: Vec = vec![0; 3]; let n = data.len(); write_uint24_le(n, &mut buf); From 730baac9c90a876b1b322c594b6f1dfa510b0050 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 12 Mar 2025 16:10:59 -0400 Subject: [PATCH 015/206] rm encrypt_in_place --- src/crypto/cipher.rs | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 5c8d8a2..cbf84bc 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -190,7 +190,6 @@ const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; pub(crate) struct RawEncryptCipher { push_stream: PushStream, - buf: Vec, } impl std::fmt::Debug for RawEncryptCipher { @@ -220,13 +219,7 @@ impl RawEncryptCipher { let header = header.as_ref(); header_message[STREAM_ID_LENGTH..].copy_from_slice(header); let msg = header_message.to_vec(); - Ok(( - Self { - push_stream, - buf: Default::default(), - }, - msg, - )) + Ok((Self { push_stream }, msg)) } // Possible API's: @@ -255,25 +248,4 @@ impl RawEncryptCipher { })?; Ok(out) } - - pub(crate) fn encrypt_in_place<'a>(&'a mut self, msg: &[u8]) -> io::Result<&'a [u8]> { - let min_safe_length = self.safe_encrypted_len(msg.len()); - if self.buf.len() < min_safe_length { - self.buf.resize(min_safe_length, 0); - } - // write message starting at index 1. we write the tag to index zero - self.buf[1..].copy_from_slice(msg); - // insert tag - // let enc_len = self.encrypt_no_alloc(&mut self.buff, 1..(1 + msg.len()))?; - // self.buf[..enc_len] - todo!() - } - - /// Get the length needed for encryption, that includes padding. - pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { - // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe - // extra room. - // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ - plaintext_len + 2 * 15 - } } From bb7354b97cefbbffd18571274a8bf0f20959618f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 12 Mar 2025 16:11:46 -0400 Subject: [PATCH 016/206] rm stuff used for development --- src/noise.rs | 72 ++++++++++++++++++---------------------------------- 1 file changed, 24 insertions(+), 48 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index beff242..4851c5b 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -44,16 +44,8 @@ pub struct Encrypted { plain_tx: VecDeque>, plain_rx: VecDeque>, flush: bool, - count: usize, } -fn ename(is_initiator: bool) -> String { - if is_initiator { - "initiator".to_string() - } else { - "other".to_string() - } -} impl Encrypted where IO: Stream> + Sink> + Send + Unpin + Debug + 'static, @@ -70,7 +62,6 @@ where plain_tx: Default::default(), plain_rx: Default::default(), flush: false, - count: 0, } } } @@ -108,21 +99,13 @@ impl> + Sink> + Send + Unpin + Debug + 'static plain_tx, plain_rx, flush, - count, .. } = self.get_mut(); - *count += 1; - if *count > 200 { - //panic!(); - } // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!( - name = %ename(*is_initiator), - "enc tx send msg\n{encrypted_out:?}" - ); + trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); let _todo = Sink::start_send(Pin::new(io), encrypted_out); *flush = true; } else { @@ -144,10 +127,10 @@ impl> + Sink> + Send + Unpin + Debug + 'static match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - trace!(name = %ename(*is_initiator), "flushed good"); + trace!(initiator = %is_initiator, "flushed good"); } Poll::Ready(Err(_e)) => error!( - name = %ename(*is_initiator), + initiator = %is_initiator, "Error sending encrypted msg"), Poll::Pending => { // More confusing docs @@ -168,9 +151,9 @@ impl> + Sink> + Send + Unpin + Debug + 'static loop { match Stream::poll_next(Pin::new(io), cx) { Poll::Pending => break, - Poll::Ready(None) => todo!(), + Poll::Ready(None) => break, Poll::Ready(Some(encrypted_msg)) => { - trace!(name = %ename(*is_initiator), "enc rx queue\n{encrypted_msg:?}"); + trace!(initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); encrypted_rx.push_back(encrypted_msg); } } @@ -181,11 +164,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static while let Some(incoming_msg) = encrypted_rx.pop_front() { match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!(name = %ename(*is_initiator), "plain rx queue"); + trace!(initiator = %is_initiator, "plain rx queue"); plain_rx.push_back(plain_msg); } Err(e) => { - error!(name = %ename(*is_initiator), "RX message failed to decrypt: {e:?}") + error!(initiator = %is_initiator, "RX message failed to decrypt: {e:?}") } } } @@ -197,7 +180,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static Ok(x) => x, Err(_e) => todo!("We failed to encrypt our own message...?"), }; - trace!(name = %ename(*is_initiator), "enc from plain tx queue\n{enc_out:?}"); + trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); encrypted_tx.push_back(enc_out); *flush = true; } @@ -212,14 +195,14 @@ impl> + Sink> + Send + Unpin + Debug + 'static // Still setting up if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { // queue the init message to send first - trace!(name = %ename(*is_initiator),"queue initial msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!(name = %ename(*is_initiator),"recieved setup msg"); + trace!(initiator = %is_initiator,"recieved setup msg"); if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator) { for msg in msgs.into_iter().rev() { - trace!(name = %ename(*is_initiator),"queue more setup msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); encrypted_tx.push_front(msg); } } @@ -253,21 +236,13 @@ impl> + Sink> + Send + Unpin + Debug + 'static plain_tx, plain_rx, flush, - count, .. } = self.get_mut(); - *count += 1; - if *count > 200 { - //panic!(); - } // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(name = %ename(*is_initiator), "enc tx send msg - {encrypted_out:?} -" - ); + trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); let _todo = Sink::start_send(Pin::new(io), encrypted_out); *flush = true; } else { @@ -289,10 +264,10 @@ impl> + Sink> + Send + Unpin + Debug + 'static match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - trace!(name = %ename(*is_initiator), "flushed good"); + trace!(initiator = %is_initiator, "flushed good"); } Poll::Ready(Err(_e)) => { - error!(name = %ename(*is_initiator), "Error sending encrypted msg") + error!(initiator = %is_initiator, "Error sending encrypted msg") } Poll::Pending => { // More confusing docs @@ -315,7 +290,8 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Pending => break, Poll::Ready(None) => break, Poll::Ready(Some(encrypted_msg)) => { - trace!(name = %ename(*is_initiator), "enc rx queue\n{encrypted_msg:?}"); + trace!( + initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); encrypted_rx.push_back(encrypted_msg); } } @@ -326,11 +302,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static while let Some(incoming_msg) = encrypted_rx.pop_front() { match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!(name = %ename(*is_initiator), "plain rx queue"); + trace!(initiator = %is_initiator, "plain rx queue"); plain_rx.push_back(plain_msg); } Err(e) => { - error!(name = %ename(*is_initiator),"RX message failed to decrypt: {e:?}") + error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") } } } @@ -341,13 +317,13 @@ impl> + Sink> + Send + Unpin + Debug + 'static Ok(x) => x, Err(_e) => todo!("We failed to encrypt our own message...?"), }; - trace!(name = %ename(*is_initiator), "enc from plain tx queue\n{enc_out:?}"); + trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); encrypted_tx.push_back(enc_out); } // emit any messages that are ready if let Some(msg) = plain_rx.pop_front() { - trace!(name = %ename(*is_initiator), "plain rx emit"); + trace!(initiator = %is_initiator, "plain rx emit"); Poll::Ready(Some(msg)) } else { Poll::Pending @@ -356,11 +332,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static // Still setting up if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { // queue the init message to send first - trace!(name = %ename(*is_initiator),"queue initial msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!(name = %ename(*is_initiator), "recieved setup msg"); + trace!(initiator = %is_initiator, "recieved setup msg"); if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, *is_initiator) { Ok(x) => Ok(x), Err(e) => { @@ -369,7 +345,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } { for msg in msgs.into_iter().rev() { - trace!(name = %ename(*is_initiator),"queue more setup msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); encrypted_tx.push_front(msg); } } @@ -384,7 +360,7 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { if !matches!(step, Step::NotInitialized) { return Ok(None); } - trace!(name = %ename(is_initiator), "Init, state {step:?}"); + trace!(initiator = %is_initiator, "Init, state {step:?}"); let mut handshake = Handshake::new(is_initiator)?; let out = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); From 4e6217a0122d1770c2324e6be297d07e172db620 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 13 Mar 2025 17:23:39 -0400 Subject: [PATCH 017/206] Add LengthPrefixed framing --- src/framing.rs | 249 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 + 2 files changed, 251 insertions(+) create mode 100644 src/framing.rs diff --git a/src/framing.rs b/src/framing.rs new file mode 100644 index 0000000..3504895 --- /dev/null +++ b/src/framing.rs @@ -0,0 +1,249 @@ +use std::{ + collections::VecDeque, + fmt::Debug, + io::Result, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{Sink, Stream}; +use futures_lite::io::{AsyncRead, AsyncWrite}; +use tracing::{debug, instrument, trace}; + +use crate::util::{stat_uint24_le, wrap_uint24_le}; + +const BUF_SIZE: usize = 1024 * 8; +const HEADER_LEN: usize = 3; + +/// Turn a `AsyncWrite` of length prefixed messages and emit the messages with a Stream +pub struct LengthPrefixed { + io: IO, + to_stream: Vec, + from_sink: VecDeque>, + /// The index in [`Self::buf`] of the last byte that was to the [`Stream`]. + last_out_idx: usize, + /// The index in [`Self::buf`] of the last byte that was read from [`Self::io`] via + /// [`AsyncRead`] + last_data_idx: usize, + step: Step, +} +impl Debug for LengthPrefixed { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Format()") + } +} +impl LengthPrefixed +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + /// Build [`LengthPrefixed`] around an [`AsyncWrite`]/[`AsyncRead`] thing. + pub fn new(io: IO) -> Self { + Self { + io, + to_stream: vec![0u8; BUF_SIZE], + from_sink: VecDeque::new(), + last_out_idx: 0, + last_data_idx: 0, + step: Step::Header, + } + } +} + +#[derive(Debug)] +enum Step { + Header, + Body { start: usize, end: u64 }, +} + +impl Stream for LengthPrefixed +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + type Item = Result>; + + #[instrument(skip_all)] + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + trace!("from poll next!!"); + let Self { + io, + to_stream, + last_out_idx, + last_data_idx, + step, + .. + } = self.get_mut(); + let n_bytes_read = match Pin::new(io).poll_read(cx, &mut to_stream[*last_data_idx..]) { + Poll::Ready(Ok(n)) => n, + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Pending => 0, + }; + // TODO handle if to_stream is full + trace!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); + *last_data_idx += n_bytes_read; + // grow buffer if it's full + if *last_data_idx == to_stream.len() - 1 { + to_stream.extend(vec![0; to_stream.len() * 2]); + } + + if let Step::Header = step { + trace!(step = ?*step, "enter"); + if *last_data_idx - *last_out_idx < HEADER_LEN { + trace!("not enough bytes to read header"); + return Poll::Pending; + } + let Some((header_len, body_len)) = + stat_uint24_le(&to_stream[*last_out_idx..(*last_out_idx + HEADER_LEN)]) + else { + // we check above the there is room for header so this should never happen + todo!() + }; + + let cur_frame_start = *last_out_idx + header_len; + let cur_frame_end = (cur_frame_start as u64) + body_len; + *step = Step::Body { + start: cur_frame_start, + end: cur_frame_end, + }; + } + + if let Step::Body { start, end } = step { + let end = *end as usize; + if end <= *last_data_idx { + debug!(frame_size = end - *start, "Frame ready"); + let out = to_stream[*start..end].to_vec(); + *step = Step::Header; + *last_out_idx = end; + + return Poll::Ready(Some(Ok(out))); + } + } + Poll::Pending + } +} +impl Sink> for LengthPrefixed +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + type Error = std::io::Error; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { + self.from_sink.push_back(wrap_uint24_le(&item)); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let Self { from_sink, io, .. } = self.get_mut(); + if let Some(msg) = from_sink.pop_front() { + match Pin::new(io).poll_write(cx, &msg) { + Poll::Pending => { + from_sink.push_front(msg); + return Poll::Pending; + } + Poll::Ready(Ok(n)) => { + if n != msg.len() { + from_sink.push_front(msg[n..].to_vec()); + return Poll::Ready(Ok(())); + } + } + Poll::Ready(Err(_e)) => todo!(), + } + } + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + todo!() + } +} +#[cfg(test)] +mod test { + use crate::test_utils::log; + + use super::*; + use futures::{ + io::{AsyncReadExt, AsyncWriteExt}, + AsyncRead, AsyncWrite, SinkExt, StreamExt, + }; + use tokio_util::compat::TokioAsyncReadCompatExt; + + fn duplex(channel_size: usize) -> (impl AsyncRead + AsyncWrite, impl AsyncRead + AsyncWrite) { + let (left, right) = tokio::io::duplex(channel_size); + (left.compat(), right.compat()) + } + + #[tokio::test] + async fn t_duplex() -> Result<()> { + let (mut left, mut right) = duplex(64); + left.write_all(b"hello").await?; + let mut b = vec![0; 5]; + right.read_exact(&mut b).await?; + assert_eq!(b, b"hello"); + Ok(()) + } + + #[tokio::test] + async fn t_input() -> Result<()> { + log(); + let (left, mut right) = duplex(64); + let mut lp = LengthPrefixed::new(left); + let input = b"yelp"; + let msg = wrap_uint24_le(input); + dbg!(&msg); + right.write_all(&msg).await?; + let Some(Ok(rx)) = lp.next().await else { + panic!() + }; + assert_eq!(rx, input); + Ok(()) + } + #[tokio::test] + async fn t_stream_many() -> Result<()> { + log(); + let (left, mut right) = duplex(64); + let mut lp = LengthPrefixed::new(left); + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; + for d in data { + let msg = wrap_uint24_le(d); + right.write_all(&msg).await?; + } + for d in data { + dbg!(); + let Some(Ok(res)) = lp.next().await else { + panic!(); + }; + dbg!(&res); + assert_eq!(&res, d); + } + Ok(()) + } + #[tokio::test] + async fn t_sink_many() -> Result<()> { + log(); + let (left, mut right) = duplex(64); + let mut lp = LengthPrefixed::new(left); + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; + for d in data { + lp.send(d.to_vec()).await.unwrap(); + } + + let mut expected = vec![]; + data.iter().for_each(|d| expected.extend(wrap_uint24_le(d))); + let mut result = vec![0; expected.len()]; + right.read_exact(&mut result).await?; + assert_eq!(result, expected); + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 646a93e..b1a043a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -122,6 +122,7 @@ mod channels; mod constants; mod crypto; mod duplex; +mod framing; mod message; mod noise; mod protocol; @@ -136,6 +137,7 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; +pub use framing::LengthPrefixed; pub use noise::Encrypted; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ From b576c9f4d2a7fc16885c9bc53a07d1f4b011306b Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 13 Mar 2025 17:49:41 -0400 Subject: [PATCH 018/206] use var for header len --- src/crypto/cipher.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index cbf84bc..8278692 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -148,7 +148,7 @@ impl EncryptCipher { let encrypted_len = to_encrypt.len(); write_uint24_le(encrypted_len, buf); buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); - Ok(3 + encrypted_len) + Ok(header_len + encrypted_len) } else { Err(io::Error::new( io::ErrorKind::InvalidData, From a3a50d071da37309926104e1995cd87e703a35ef Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 13 Mar 2025 17:50:19 -0400 Subject: [PATCH 019/206] refactor encrypted poll functions --- src/noise.rs | 357 +++++++++++++++++++++++---------------------------- 1 file changed, 159 insertions(+), 198 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 4851c5b..84fc381 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -69,7 +69,7 @@ where impl> + Sink> + Send + Unpin + Debug + 'static> Sink> for Encrypted { - type Error = (); + type Error = std::io::Error; fn poll_ready( self: Pin<&mut Self>, @@ -102,88 +102,19 @@ impl> + Sink> + Send + Unpin + Debug + 'static .. } = self.get_mut(); - // send any pending outgoing messages - while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { - if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); - let _todo = Sink::start_send(Pin::new(io), encrypted_out); - *flush = true; - } else { - break; - } - } - if *flush { - // confusing docs related to start send - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send - // First part says: - // "you must use **poll_flush** ... inorder to garuntee - // completions of send" - // Then it says: - // " It is only necessary to call poll_flush if you need to guarantee that all - // of the items placed into the Sink have been sent" - // - // So do I need to do it or not? - // must `poll_flush` be called for **anything** to send? - match Sink::poll_flush(Pin::new(io), cx) { - Poll::Ready(Ok(())) => { - *flush = false; - trace!(initiator = %is_initiator, "flushed good"); - } - Poll::Ready(Err(_e)) => error!( - initiator = %is_initiator, - "Error sending encrypted msg"), - Poll::Pending => { - // More confusing docs - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush - // It says: - // "Returns Poll::Pending if there is more work left to do, in which case the - // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when - // poll_flush should be called again." - // Does this mean, each time this task wakes up again from this code path that - // I must trigger another poll_flush? But how would I know i need more - // flushing? - *flush = true; - } - } - } - - // pull in any incomming encrypted messages - loop { - match Stream::poll_next(Pin::new(io), cx) { - Poll::Pending => break, - Poll::Ready(None) => break, - Poll::Ready(Some(encrypted_msg)) => { - trace!(initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); - encrypted_rx.push_back(encrypted_msg); - } - } - } + poll_encrypted_side_io(io, cx, encrypted_tx, encrypted_rx, *is_initiator, flush); if let Step::Established((encryptor, decryptor, ..)) = step { - // decrypt any incromming encrypted messages - while let Some(incoming_msg) = encrypted_rx.pop_front() { - match decryptor.decrypt_buf(&incoming_msg) { - Ok((plain_msg, _tag)) => { - trace!(initiator = %is_initiator, "plain rx queue"); - plain_rx.push_back(plain_msg); - } - Err(e) => { - error!(initiator = %is_initiator, "RX message failed to decrypt: {e:?}") - } - } - } - - // encrypt any pending plaintext outgoinng messages - while let Some(mut plain_out) = plain_tx.pop_front() { - // it encrypts in-place?? - let enc_out = match encryptor.encrypt(&mut plain_out) { - Ok(x) => x, - Err(_e) => todo!("We failed to encrypt our own message...?"), - }; - trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); - encrypted_tx.push_back(enc_out); - *flush = true; - } + poll_do_encrypt_and_decrypt( + encryptor, + decryptor, + encrypted_tx, + encrypted_rx, + plain_tx, + plain_rx, + *is_initiator, + flush, + ); if *flush { cx.waker().wake_by_ref(); @@ -192,21 +123,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Ready(Ok(())) } } else { - // Still setting up - if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { - // queue the init message to send first - trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); - encrypted_tx.push_front(msg); - } - while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!(initiator = %is_initiator,"recieved setup msg"); - if let Ok(msgs) = handle_setup_message(step, &incoming_msg, *is_initiator) { - for msg in msgs.into_iter().rev() { - trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); - encrypted_tx.push_front(msg); - } - } - } + poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator); cx.waker().wake_by_ref(); Poll::Pending } @@ -220,6 +137,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static todo!() } } + impl> + Sink> + Send + Unpin + Debug + 'static> Stream for Encrypted { @@ -239,88 +157,19 @@ impl> + Sink> + Send + Unpin + Debug + 'static .. } = self.get_mut(); - // send any pending outgoing messages - while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { - if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); - let _todo = Sink::start_send(Pin::new(io), encrypted_out); - *flush = true; - } else { - break; - } - } - if *flush { - // confusing docs related to start send - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send - // First part says: - // "you must use **poll_flush** ... inorder to garuntee - // completions of send" - // Then it says: - // " It is only necessary to call poll_flush if you need to guarantee that all - // of the items placed into the Sink have been sent" - // - // So do I need to do it or not? - // must `poll_flush` be called for **anything** to send? - match Sink::poll_flush(Pin::new(io), cx) { - Poll::Ready(Ok(())) => { - *flush = false; - trace!(initiator = %is_initiator, "flushed good"); - } - Poll::Ready(Err(_e)) => { - error!(initiator = %is_initiator, "Error sending encrypted msg") - } - Poll::Pending => { - // More confusing docs - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush - // It says: - // "Returns Poll::Pending if there is more work left to do, in which case the - // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when - // poll_flush should be called again." - // Does this mean, each time this task wakes up again from this code path that - // I must trigger another poll_flush? But how would I know i need more - // flushing? - *flush = true; - } - } - } - - // pull in any incomming encrypted messages - loop { - match Stream::poll_next(Pin::new(io), cx) { - Poll::Pending => break, - Poll::Ready(None) => break, - Poll::Ready(Some(encrypted_msg)) => { - trace!( - initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); - encrypted_rx.push_back(encrypted_msg); - } - } - } + poll_encrypted_side_io(io, cx, encrypted_tx, encrypted_rx, *is_initiator, flush); if let Step::Established((encryptor, decryptor, ..)) = step { - // decrypt any incromming encrypted messages - while let Some(incoming_msg) = encrypted_rx.pop_front() { - match decryptor.decrypt_buf(&incoming_msg) { - Ok((plain_msg, _tag)) => { - trace!(initiator = %is_initiator, "plain rx queue"); - plain_rx.push_back(plain_msg); - } - Err(e) => { - error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") - } - } - } - - // encrypt any pending plaintext outgoinng messages - while let Some(mut plain_out) = plain_tx.pop_front() { - let enc_out = match encryptor.encrypt(&mut plain_out) { - Ok(x) => x, - Err(_e) => todo!("We failed to encrypt our own message...?"), - }; - trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); - encrypted_tx.push_back(enc_out); - } - + poll_do_encrypt_and_decrypt( + encryptor, + decryptor, + encrypted_tx, + encrypted_rx, + plain_tx, + plain_rx, + *is_initiator, + flush, + ); // emit any messages that are ready if let Some(msg) = plain_rx.pop_front() { trace!(initiator = %is_initiator, "plain rx emit"); @@ -329,31 +178,144 @@ impl> + Sink> + Send + Unpin + Debug + 'static Poll::Pending } } else { - // Still setting up - if let Ok(Some(msg)) = maybe_init(step, *is_initiator) { - // queue the init message to send first - trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); + poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator); + cx.waker().wake_by_ref(); + Poll::Pending + } + } +} + +fn poll_setup( + step: &mut Step, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>, + is_initiator: bool, +) { + // Still setting up + if let Ok(Some(msg)) = maybe_init(step, is_initiator) { + // queue the init message to send first + trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); + encrypted_tx.push_front(msg); + } + while let Some(incoming_msg) = encrypted_rx.pop_front() { + trace!(initiator = %is_initiator, "recieved setup msg"); + if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, is_initiator) { + Ok(x) => Ok(x), + Err(e) => { + error!("handle_setup_message error: {e:?}"); + Err(e) + } + } { + for msg in msgs.into_iter().rev() { + trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); encrypted_tx.push_front(msg); } - while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!(initiator = %is_initiator, "recieved setup msg"); - if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, *is_initiator) { - Ok(x) => Ok(x), - Err(e) => { - error!("handle_setup_message error: {e:?}"); - Err(e) - } - } { - for msg in msgs.into_iter().rev() { - trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); - encrypted_tx.push_front(msg); - } - } + } + } +} + +fn poll_encrypted_side_io< + IO: Stream> + Sink> + Send + Unpin + Debug + 'static, +>( + io: &mut IO, + cx: &mut Context<'_>, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>, + is_initiator: bool, + flush: &mut bool, +) { + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { + if let Some(encrypted_out) = encrypted_tx.pop_front() { + trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); + let _todo = Sink::start_send(Pin::new(io), encrypted_out); + *flush = true; + } else { + break; + } + } + if *flush { + // confusing docs related to start send + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send + // First part says: + // "you must use **poll_flush** ... inorder to garuntee + // completions of send" + // Then it says: + // " It is only necessary to call poll_flush if you need to guarantee that all + // of the items placed into the Sink have been sent" + // + // So do I need to do it or not? + // must `poll_flush` be called for **anything** to send? + match Sink::poll_flush(Pin::new(io), cx) { + Poll::Ready(Ok(())) => { + *flush = false; + trace!(initiator = %is_initiator, "flushed good"); + } + Poll::Ready(Err(_e)) => error!( + initiator = %is_initiator, + "Error sending encrypted msg"), + Poll::Pending => { + // More confusing docs + // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush + // It says: + // "Returns Poll::Pending if there is more work left to do, in which case the + // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when + // poll_flush should be called again." + // Does this mean, each time this task wakes up again from this code path that + // I must trigger another poll_flush? But how would I know i need more + // flushing? + *flush = true; } - cx.waker().wake_by_ref(); - Poll::Pending } } + // pull in any incomming encrypted messages + loop { + match Stream::poll_next(Pin::new(io), cx) { + Poll::Pending => break, + Poll::Ready(None) => break, + Poll::Ready(Some(encrypted_msg)) => { + trace!(initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); + encrypted_rx.push_back(encrypted_msg); + } + } + } +} + +/// Process messages waiting to be encrypted or decrypted +// TODO sholud this return a Result +fn poll_do_encrypt_and_decrypt( + encryptor: &mut RawEncryptCipher, + decryptor: &mut DecryptCipher, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>, + plain_tx: &mut VecDeque>, + plain_rx: &mut VecDeque>, + is_initiator: bool, + flush: &mut bool, +) { + // decrypt any incromming encrypted messages + while let Some(incoming_msg) = encrypted_rx.pop_front() { + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => { + trace!(initiator = %is_initiator, "plain rx queue"); + plain_rx.push_back(plain_msg); + } + Err(e) => { + error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") + } + } + } + + // encrypt any pending plaintext outgoinng messages + while let Some(plain_out) = plain_tx.pop_front() { + let enc_out = match encryptor.encrypt(&plain_out) { + Ok(x) => x, + Err(_e) => todo!("We failed to encrypt our own message...?"), + }; + trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); + encrypted_tx.push_back(enc_out); + *flush = true; + } } fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { @@ -439,7 +401,7 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu #[cfg(test)] mod tset { - use crate::test_utils::{create_connected, log}; + use crate::test_utils::create_connected; use super::*; use futures::{SinkExt, StreamExt}; @@ -467,7 +429,6 @@ mod tset { #[tokio::test] async fn test_encrypted() -> Result<()> { - log(); let hello = b"hello"; let world = b"world"; let (left, right) = create_connected(); From 94785a52b0074356f6b7b78380b6d9e99a6c50ed Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 13 Mar 2025 17:51:50 -0400 Subject: [PATCH 020/206] rename writer fields --- src/protocol.rs | 2 +- src/writer.rs | 59 +++++++++++++++++++++++++++---------------------- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 9d1ebe9..1f24b1a 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -593,7 +593,7 @@ where fn queue_frame_direct(&mut self, body: Vec) -> Result { let mut frame = Frame::RawBatch(vec![body]); - self.write_state.try_queue_direct(&mut frame) + self.write_state.try_encode_frame_for_tx(&mut frame) } fn accept_channel(&mut self, local_id: usize) -> Result<()> { diff --git a/src/writer.rs b/src/writer.rs index e3cc5da..38d6dcf 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -19,11 +19,11 @@ pub(crate) enum Step { pub(crate) struct WriteState { queue: VecDeque, - buf: Vec, current_frame: Option, - start: usize, - end: usize, cipher: Option, + buf: Vec, + written_up_to_idx: usize, + should_write_up_to_idx: usize, step: Step, } @@ -31,12 +31,12 @@ impl fmt::Debug for WriteState { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("WriteState") .field("queue (len)", &self.queue.len()) - .field("step", &self.step) - .field("buf (len)", &self.buf.len()) .field("current_frame", &self.current_frame) - .field("start", &self.start) - .field("end", &self.end) .field("cipher", &self.cipher.is_some()) + .field("buf (len)", &self.buf.len()) + .field("start", &self.written_up_to_idx) + .field("end", &self.should_write_up_to_idx) + .field("step", &self.step) .finish() } } @@ -47,8 +47,8 @@ impl WriteState { queue: VecDeque::new(), buf: vec![0u8; BUF_SIZE], current_frame: None, - start: 0, - end: 0, + written_up_to_idx: 0, + should_write_up_to_idx: 0, cipher: None, step: Step::Processing, } @@ -61,7 +61,7 @@ impl WriteState { self.queue.push_back(frame.into()) } - pub(crate) fn try_queue_direct(&mut self, frame: &mut T) -> Result { + pub(crate) fn try_encode_frame_for_tx(&mut self, frame: &mut T) -> Result { let promised_len = frame.encoded_len()?; let padded_promised_len = self.safe_encrypted_len(promised_len); if self.buf.len() < padded_promised_len { @@ -70,13 +70,15 @@ impl WriteState { if padded_promised_len > self.remaining() { return Ok(false); } - let actual_len = frame.encode(&mut self.buf[self.end..])?; + + // write frame starting at end. fram is from end to end + actual_end + let actual_len = frame.encode(&mut self.buf[self.should_write_up_to_idx..])?; if actual_len != promised_len { panic!( "encoded_len() did not return that right size, expected={promised_len}, actual={actual_len}" ); } - self.advance(padded_promised_len)?; + self.encrypt_frame_contents(padded_promised_len)?; Ok(true) } @@ -93,16 +95,18 @@ impl WriteState { } } - fn advance(&mut self, n: usize) -> Result<()> { - let end = self.end + n; + fn encrypt_frame_contents(&mut self, max_message_size: usize) -> Result<()> { + let end_of_message_index = self.should_write_up_to_idx + max_message_size; let encrypted_end = if let Some(ref mut cipher) = self.cipher { - self.end + cipher.encrypt(&mut self.buf[self.end..end])? + self.should_write_up_to_idx + + cipher + .encrypt(&mut self.buf[self.should_write_up_to_idx..end_of_message_index])? } else { - end + end_of_message_index }; - self.end = encrypted_end; + self.should_write_up_to_idx = encrypted_end; Ok(()) } @@ -111,11 +115,11 @@ impl WriteState { } fn remaining(&self) -> usize { - self.buf.len() - self.end + self.buf.len() - self.should_write_up_to_idx } fn pending(&self) -> usize { - self.end - self.start + self.should_write_up_to_idx - self.written_up_to_idx } pub(crate) fn poll_send( @@ -134,7 +138,7 @@ impl WriteState { } if let Some(mut frame) = self.current_frame.take() { - if !self.try_queue_direct(&mut frame)? { + if !self.try_encode_frame_for_tx(&mut frame)? { self.current_frame = Some(frame); } } @@ -145,13 +149,14 @@ impl WriteState { Step::Writing } Step::Writing => { - let n = ready!( - Pin::new(&mut writer).poll_write(cx, &self.buf[self.start..self.end]) - )?; - self.start += n; - if self.start == self.end { - self.start = 0; - self.end = 0; + let n = ready!(Pin::new(&mut writer).poll_write( + cx, + &self.buf[self.written_up_to_idx..self.should_write_up_to_idx] + ))?; + self.written_up_to_idx += n; + if self.written_up_to_idx == self.should_write_up_to_idx { + self.written_up_to_idx = 0; + self.should_write_up_to_idx = 0; } Step::Flushing } From fb88b8912d100de85c6a3986d758f82a78d78cd9 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 14 Mar 2025 15:41:11 -0400 Subject: [PATCH 021/206] s/3/header_len/g --- src/crypto/cipher.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 8278692..4dfaf19 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -85,8 +85,8 @@ impl DecryptCipher { let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; let decrypted_len = to_decrypt.len(); write_uint24_le(decrypted_len, buf); - let decrypted_end = 3 + to_decrypt.len(); - buf[3..decrypted_end].copy_from_slice(to_decrypt.as_slice()); + let decrypted_end = header_len + to_decrypt.len(); + buf[header_len..decrypted_end].copy_from_slice(to_decrypt.as_slice()); // Set extra bytes in the buffer to 0 let encrypted_end = header_len + body_len; buf[decrypted_end..encrypted_end].fill(0x00); From 1bb619d436b4098e2aa4774c898599f3b5839156 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 15 Mar 2025 15:50:31 -0400 Subject: [PATCH 022/206] add tokio-util for tests --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index a5ac273..ff89935 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,7 @@ futures = "0.3.13" log = "0.4" test-log = { version = "0.2.11", default-features = false, features = ["trace"] } tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } +tokio-util = { version = "0.7.14", features = ["compat"] } [features] default = ["tokio", "sparse"] From b6db23333447d6b476806edbac0e4f91d072b418 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 15 Mar 2025 15:51:14 -0400 Subject: [PATCH 023/206] Add result channel to test utils. refactor to use futures channels bc they implement Sender --- src/test_utils.rs | 107 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 100 insertions(+), 7 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index 7d8c3a7..2e9b6cd 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -4,8 +4,13 @@ use std::{ task::{Context, Poll}, }; -use async_channel::{unbounded, Receiver, SendError, Sender}; -use futures::{Sink, SinkExt, Stream, StreamExt}; +//use async_channel::{unbounded, Receiver, SendError, Sender}; +use futures::{ + channel::mpsc::{ + unbounded, SendError, UnboundedReceiver as Receiver, UnboundedSender as Sender, + }, + Sink, SinkExt, Stream, StreamExt, +}; #[derive(Debug)] pub(crate) struct Io { @@ -29,15 +34,14 @@ impl Stream for Io { } impl Sink> for Io { - type Error = SendError>; + type Error = SendError; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { - let _ = self.sender.try_send(item); - Ok(()) + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + Pin::new(&mut self.sender).start_send(item) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -73,7 +77,6 @@ pub(crate) fn create_connected() -> (Io, Io) { TwoWay::default().split_sides() } -#[allow(dead_code)] pub(crate) fn log() { use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: OnceLock<()> = OnceLock::new(); @@ -108,3 +111,93 @@ async fn split() { }; assert_eq!(res, b"hello"); } + +#[derive(Debug)] +pub(crate) struct Moo { + receiver: Rx, + sender: Tx, +} + +impl + Unpin, Tx: Unpin> Stream for Moo { + type Item = RxItem; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + Pin::new(&mut this.receiver).poll_next(cx) + } +} + +impl + Unpin> Sink + for Moo +{ + type Error = SendError; + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: TxItem) -> Result<(), Self::Error> { + let this = self.get_mut(); + Pin::new(&mut this.sender).start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + todo!() + } +} + +/// Creaee [`Moo`] from return value of [`unbounded`] +impl From<(Tx, Rx)> for Moo { + fn from(value: (Tx, Rx)) -> Self { + Moo { + receiver: value.1, + sender: value.0, + } + } +} + +impl Moo { + /// connect two [`Moo`]s + fn connect( + self, + other: Moo, + ) -> (Moo, Moo) { + let left = Moo { + receiver: self.receiver, + sender: other.sender, + }; + let right = Moo { + receiver: other.receiver, + sender: self.sender, + }; + (left, right) + } +} + +fn result_channel() -> (Sender>, impl Stream, String>>) { + let (tx, rx) = unbounded::>(); + (tx, rx.map(|x| Ok(x))) +} + +pub(crate) fn create_result_connected() -> ( + Moo, String>>, impl Sink>>, + Moo, String>>, impl Sink>>, +) { + let a = Moo::from(result_channel()); + let b = Moo::from(result_channel()); + a.connect(b) +} + +#[tokio::test] +async fn foo() -> Result<(), Box> { + let a = Moo::from(result_channel()); + let b = Moo::from(result_channel()); + let (mut left, mut right) = a.connect(b); + left.send(b"hello".to_vec()).await?; + assert_eq!(right.next().await.unwrap(), Ok(b"hello".into())); + Ok(()) +} From 7780cb2b189f351f7547605f3a2e266e2361394f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 15 Mar 2025 15:52:54 -0400 Subject: [PATCH 024/206] refactor framing tests --- src/framing.rs | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 3504895..64576f3 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -63,7 +63,6 @@ where #[instrument(skip_all)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - trace!("from poll next!!"); let Self { io, to_stream, @@ -169,23 +168,23 @@ where } } #[cfg(test)] -mod test { +pub(crate) mod test { use crate::test_utils::log; use super::*; - use futures::{ - io::{AsyncReadExt, AsyncWriteExt}, - AsyncRead, AsyncWrite, SinkExt, StreamExt, - }; + use futures::{SinkExt, StreamExt}; + use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio_util::compat::TokioAsyncReadCompatExt; - fn duplex(channel_size: usize) -> (impl AsyncRead + AsyncWrite, impl AsyncRead + AsyncWrite) { + pub(crate) fn duplex( + channel_size: usize, + ) -> (impl AsyncRead + AsyncWrite, impl AsyncRead + AsyncWrite) { let (left, right) = tokio::io::duplex(channel_size); (left.compat(), right.compat()) } #[tokio::test] - async fn t_duplex() -> Result<()> { + async fn duplex_works() -> Result<()> { let (mut left, mut right) = duplex(64); left.write_all(b"hello").await?; let mut b = vec![0; 5]; @@ -195,7 +194,7 @@ mod test { } #[tokio::test] - async fn t_input() -> Result<()> { + async fn input() -> Result<()> { log(); let (left, mut right) = duplex(64); let mut lp = LengthPrefixed::new(left); @@ -210,7 +209,7 @@ mod test { Ok(()) } #[tokio::test] - async fn t_stream_many() -> Result<()> { + async fn stream_many() -> Result<()> { log(); let (left, mut right) = duplex(64); let mut lp = LengthPrefixed::new(left); @@ -230,7 +229,7 @@ mod test { Ok(()) } #[tokio::test] - async fn t_sink_many() -> Result<()> { + async fn sink_many() -> Result<()> { log(); let (left, mut right) = duplex(64); let mut lp = LengthPrefixed::new(left); From ec9edc2ea626ad079c68ca1e40b153dd4fe943bd Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 15 Mar 2025 16:43:48 -0400 Subject: [PATCH 025/206] Get Encrypted working with Result> from io --- src/test_utils.rs | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index 2e9b6cd..d35af0e 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,10 +1,11 @@ use std::{ + io::{self, ErrorKind}, pin::Pin, sync::OnceLock, task::{Context, Poll}, }; -//use async_channel::{unbounded, Receiver, SendError, Sender}; +//use async_channel::{unbounded, Receiver, io::Error, Sender}; use futures::{ channel::mpsc::{ unbounded, SendError, UnboundedReceiver as Receiver, UnboundedSender as Sender, @@ -34,14 +35,16 @@ impl Stream for Io { } impl Sink> for Io { - type Error = SendError; + type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { - Pin::new(&mut self.sender).start_send(item) + Pin::new(&mut self.sender) + .start_send(item) + .map_err(|_e| io::Error::new(ErrorKind::Other, "SendError")) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -112,7 +115,6 @@ async fn split() { assert_eq!(res, b"hello"); } -#[derive(Debug)] pub(crate) struct Moo { receiver: Rx, sender: Tx, @@ -127,10 +129,10 @@ impl + Unpin, Tx: Unpin> Stream for Moo } } -impl + Unpin> Sink +impl + Unpin> Sink for Moo { - type Error = SendError; + type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -138,7 +140,9 @@ impl + Unpi fn start_send(self: Pin<&mut Self>, item: TxItem) -> Result<(), Self::Error> { let this = self.get_mut(); - Pin::new(&mut this.sender).start_send(item) + Pin::new(&mut this.sender) + .start_send(item) + .map_err(|_e| io::Error::new(ErrorKind::Other, "SendError")) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -178,14 +182,14 @@ impl Moo { } } -fn result_channel() -> (Sender>, impl Stream, String>>) { +fn result_channel() -> (Sender>, impl Stream>>) { let (tx, rx) = unbounded::>(); (tx, rx.map(|x| Ok(x))) } pub(crate) fn create_result_connected() -> ( - Moo, String>>, impl Sink>>, - Moo, String>>, impl Sink>>, + Moo>>, impl Sink>>, + Moo>>, impl Sink>>, ) { let a = Moo::from(result_channel()); let b = Moo::from(result_channel()); @@ -198,6 +202,6 @@ async fn foo() -> Result<(), Box> { let b = Moo::from(result_channel()); let (mut left, mut right) = a.connect(b); left.send(b"hello".to_vec()).await?; - assert_eq!(right.next().await.unwrap(), Ok(b"hello".into())); + assert_eq!(right.next().await.unwrap()?, b"hello".to_vec()); Ok(()) } From c9a7709b5ba09b4de44a71c121126622331c850b Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 16 Mar 2025 15:42:48 -0400 Subject: [PATCH 026/206] Make Encrypted receive a Result --- src/noise.rs | 118 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 82 insertions(+), 36 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 84fc381..dd687ec 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -34,24 +34,38 @@ impl std::fmt::Display for Step { } /// Wrap a stream with encryption -#[derive(Debug)] pub struct Encrypted { io: IO, step: Step, is_initiator: bool, encrypted_tx: VecDeque>, - encrypted_rx: VecDeque>, + encrypted_rx: VecDeque>>, plain_tx: VecDeque>, - plain_rx: VecDeque>, + plain_rx: VecDeque>>, flush: bool, } +impl std::fmt::Debug for Encrypted { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Encrypted") + //.field("io", &self.io) + .field("step", &self.step) + .field("is_initiator", &self.is_initiator) + //.field("encrypted_tx", &self.encrypted_tx) + .field("encrypted_rx", &self.encrypted_rx) + .field("plain_tx", &self.plain_tx) + .field("plain_rx", &self.plain_rx) + .field("flush", &self.flush) + .finish() + } +} + impl Encrypted where - IO: Stream> + Sink> + Send + Unpin + Debug + 'static, + IO: Stream>> + Sink> + Send + Unpin + 'static, { /// Create [`Self`] from a Stream/Sink - #[instrument(skip_all, fields(is_initiator = %is_initiator))] + #[instrument(skip_all, fields(initiator = %is_initiator))] pub fn new(is_initiator: bool, io: IO) -> Self { Self { io, @@ -66,26 +80,31 @@ where } } -impl> + Sink> + Send + Unpin + Debug + 'static> Sink> - for Encrypted +impl< + IO: Stream>> + + Sink, Error = std::io::Error> + + Send + + Unpin + + 'static, + > Sink> for Encrypted { type Error = std::io::Error; fn poll_ready( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, ) -> Poll> { - Poll::Ready(Ok(())) + Sink::poll_ready(Pin::new(&mut self.io), cx) } - #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { - trace!("add plain tx"); + info!(initiator = %self.is_initiator, "enqueue plain_tx\n{item:?}"); self.plain_tx.push_back(item); Ok(()) } - #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -129,7 +148,7 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } - #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn poll_close( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -138,12 +157,12 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } -impl> + Sink> + Send + Unpin + Debug + 'static> Stream +impl>> + Sink> + Send + Unpin + 'static> Stream for Encrypted { - type Item = Vec; + type Item = Result>; - #[instrument(skip_all, fields(is_initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -185,10 +204,11 @@ impl> + Sink> + Send + Unpin + Debug + 'static } } +#[instrument(skip_all, fields(initiator = %is_initiator))] fn poll_setup( step: &mut Step, encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, is_initiator: bool, ) { // Still setting up @@ -197,30 +217,54 @@ fn poll_setup( trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } - while let Some(incoming_msg) = encrypted_rx.pop_front() { - trace!(initiator = %is_initiator, "recieved setup msg"); - if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, is_initiator) { - Ok(x) => Ok(x), - Err(e) => { - error!("handle_setup_message error: {e:?}"); - Err(e) + // TODO handle error + loop { + match encrypted_rx.pop_front() { + None => { + debug!( + " + num_encrypted_rx = {} + num_encrypted_tx = {} +no more encrp incoming", + encrypted_rx.len(), + encrypted_tx.len(), + ); + break; } - } { - for msg in msgs.into_iter().rev() { - trace!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); - encrypted_tx.push_front(msg); + Some(Err(e)) => { + error!( + num_encrypted_rx = 0, + num_encrypted_tx = encrypted_tx.len(), + "{e:?}" + ); + break; + } + Some(Ok(incoming_msg)) => { + info!(initiator = %is_initiator, "recieved setup msg"); + if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, is_initiator) { + Ok(x) => Ok(x), + Err(e) => { + error!("handle_setup_message error: {e:?}"); + Err(e) + } + } { + for msg in msgs.into_iter().rev() { + info!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); + encrypted_tx.push_front(msg); + } + } } } } } fn poll_encrypted_side_io< - IO: Stream> + Sink> + Send + Unpin + Debug + 'static, + IO: Stream>> + Sink> + Send + Unpin + 'static, >( io: &mut IO, cx: &mut Context<'_>, encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, is_initiator: bool, flush: &mut bool, ) { @@ -287,18 +331,20 @@ fn poll_do_encrypt_and_decrypt( encryptor: &mut RawEncryptCipher, decryptor: &mut DecryptCipher, encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, plain_tx: &mut VecDeque>, - plain_rx: &mut VecDeque>, + plain_rx: &mut VecDeque>>, is_initiator: bool, flush: &mut bool, ) { // decrypt any incromming encrypted messages - while let Some(incoming_msg) = encrypted_rx.pop_front() { + // TODO handle error + while let Some(Ok(incoming_msg)) = encrypted_rx.pop_front() { + info!(initiator = %is_initiator, "enc rx decrypting\n{incoming_msg:?}"); match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!(initiator = %is_initiator, "plain rx queue"); - plain_rx.push_back(plain_msg); + info!(initiator = %is_initiator, "plain rx queue"); + plain_rx.push_back(Ok(plain_msg)); } Err(e) => { error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") From f2cd806b0a849c0b9ade8347cb0159db7c7d6aec Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 16 Mar 2025 16:13:08 -0400 Subject: [PATCH 027/206] Fix impl of Sink fro Framing poll_flush fixes the issue of the messages not being sent --- src/framing.rs | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 64576f3..32ac6d8 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -132,32 +132,40 @@ where Poll::Ready(Ok(())) } + #[instrument(skip_all)] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { self.from_sink.push_back(wrap_uint24_le(&item)); Ok(()) } + #[instrument(skip_all)] fn poll_flush( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { let Self { from_sink, io, .. } = self.get_mut(); - if let Some(msg) = from_sink.pop_front() { - match Pin::new(io).poll_write(cx, &msg) { - Poll::Pending => { - from_sink.push_front(msg); - return Poll::Pending; - } - Poll::Ready(Ok(n)) => { - if n != msg.len() { - from_sink.push_front(msg[n..].to_vec()); - return Poll::Ready(Ok(())); + loop { + if let Some(msg) = from_sink.pop_front() { + match Pin::new(&mut *io).poll_write(cx, &msg) { + Poll::Pending => { + from_sink.push_front(msg); + debug!("AsyncWrite busy, could not flush"); + return Poll::Pending; } + Poll::Ready(Ok(n)) => { + if n != msg.len() { + from_sink.push_front(msg[n..].to_vec()); + warn!("only wrote [{n} / {}]", msg.len()); + } + debug!("flushed whole message of N=[{n}] bytes"); + } + Poll::Ready(Err(_e)) => todo!(), } - Poll::Ready(Err(_e)) => todo!(), + } else { + debug!("No messages in self.from_sink. Flush done"); + return Poll::Ready(Ok(())); } } - Poll::Ready(Ok(())) } fn poll_close( From 08056577e27f6a9727fe24e7185b99b85b9fcca3 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 16 Mar 2025 23:46:00 -0400 Subject: [PATCH 028/206] Add docs handle todos --- src/framing.rs | 96 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 77 insertions(+), 19 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 32ac6d8..66ce84f 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -7,24 +7,28 @@ use std::{ }; use futures::{Sink, Stream}; + use futures_lite::io::{AsyncRead, AsyncWrite}; -use tracing::{debug, instrument, trace}; +use tracing::{debug, error, info, instrument, trace, warn}; use crate::util::{stat_uint24_le, wrap_uint24_le}; -const BUF_SIZE: usize = 1024 * 8; +const BUF_SIZE: usize = 1024 * 64; const HEADER_LEN: usize = 3; /// Turn a `AsyncWrite` of length prefixed messages and emit the messages with a Stream pub struct LengthPrefixed { io: IO, + /// Data from [`Self::io`]'s [`AsyncRead`] interface to be sent out via the [`Stream`] interface. to_stream: Vec, + /// Data from the `Sink` interface to be written out to [`Self::io`]'s [`AsyncWrite`] interface. from_sink: VecDeque>, - /// The index in [`Self::buf`] of the last byte that was to the [`Stream`]. + /// The index in [`Self::to_stream`] of the last byte that was to the [`Stream`]. last_out_idx: usize, - /// The index in [`Self::buf`] of the last byte that was read from [`Self::io`] via + /// The index in [`Self::to_stream`] of the last byte that was read from [`Self::io`]'s /// [`AsyncRead`] last_data_idx: usize, + /// Current step of a message being parsed step: Step, } impl Debug for LengthPrefixed { @@ -71,30 +75,36 @@ where step, .. } = self.get_mut(); + debug!( + "Try to AsyncRead up to (buff_size[{}] - last_data_idx[{}]) = [{}]", + to_stream.len(), + *last_data_idx, + to_stream.len() - *last_data_idx + ); let n_bytes_read = match Pin::new(io).poll_read(cx, &mut to_stream[*last_data_idx..]) { Poll::Ready(Ok(n)) => n, Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), - Poll::Pending => 0, + Poll::Pending => { + cx.waker().wake_by_ref(); + 0 + } }; // TODO handle if to_stream is full - trace!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); + debug!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); *last_data_idx += n_bytes_read; // grow buffer if it's full if *last_data_idx == to_stream.len() - 1 { + warn!("We filled our buffer!"); to_stream.extend(vec![0; to_stream.len() * 2]); } if let Step::Header = step { trace!(step = ?*step, "enter"); - if *last_data_idx - *last_out_idx < HEADER_LEN { + let cur_data = &to_stream[*last_out_idx..*last_data_idx]; + + let Some((header_len, body_len)) = stat_uint24_le(cur_data) else { trace!("not enough bytes to read header"); return Poll::Pending; - } - let Some((header_len, body_len)) = - stat_uint24_le(&to_stream[*last_out_idx..(*last_out_idx + HEADER_LEN)]) - else { - // we check above the there is room for header so this should never happen - todo!() }; let cur_frame_start = *last_out_idx + header_len; @@ -105,6 +115,7 @@ where }; } + info!(step = ?*step, "enter"); if let Step::Body { start, end } = step { let end = *end as usize; if end <= *last_data_idx { @@ -149,20 +160,22 @@ where match Pin::new(&mut *io).poll_write(cx, &msg) { Poll::Pending => { from_sink.push_front(msg); - debug!("AsyncWrite busy, could not flush"); return Poll::Pending; } Poll::Ready(Ok(n)) => { if n != msg.len() { from_sink.push_front(msg[n..].to_vec()); - warn!("only wrote [{n} / {}]", msg.len()); + warn!("only wrote [{n} / {}] bytes of message", msg.len()); } debug!("flushed whole message of N=[{n}] bytes"); } - Poll::Ready(Err(_e)) => todo!(), + Poll::Ready(Err(e)) => { + error!("Error flushing data"); + return Poll::Ready(Err(e)); + } } } else { - debug!("No messages in self.from_sink. Flush done"); + debug!("No more messages to flush"); return Poll::Ready(Ok(())); } } @@ -170,9 +183,10 @@ where fn poll_close( self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, ) -> Poll> { - todo!() + let Self { io, .. } = self.get_mut(); + Pin::new(&mut *io).poll_close(cx) } } #[cfg(test)] @@ -253,4 +267,48 @@ pub(crate) mod test { assert_eq!(result, expected); Ok(()) } + + #[tokio::test] + async fn left_and_right() -> Result<()> { + let (left, right) = duplex(64); + + let mut leftlp = LengthPrefixed::new(left); + let mut rightlp = LengthPrefixed::new(right); + + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; + for d in data { + rightlp.send(d.to_vec()).await.unwrap(); + } + + let mut result1 = vec![]; + for _ in data { + result1.push(leftlp.next().await.unwrap().unwrap()); + } + assert_eq!(result1, data); + + for d in data { + leftlp.send(d.to_vec()).await.unwrap(); + } + let mut result2 = vec![]; + for _ in data { + result2.push(rightlp.next().await.unwrap().unwrap()); + } + assert_eq!(result2, data); + + let mut r3 = vec![]; + let mut r4 = vec![]; + for d in data { + rightlp.send(d.to_vec()).await.unwrap(); + leftlp.send(d.to_vec()).await.unwrap(); + } + + for _ in data { + r3.push(rightlp.next().await.unwrap().unwrap()); + r4.push(leftlp.next().await.unwrap().unwrap()); + } + assert_eq!(r3, data); + assert_eq!(r4, data); + + Ok(()) + } } From b9c5482f69056a7ce0d2f17ed8f404457375d295 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 17 Mar 2025 00:48:48 -0400 Subject: [PATCH 029/206] Add encryption_established more tests better logs --- src/noise.rs | 184 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 119 insertions(+), 65 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index dd687ec..15525de 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -78,6 +78,11 @@ where flush: false, } } + + /// Wether an encrypted connection has been established. + pub fn encryption_established(&self) -> bool { + matches!(self.step, Step::Established(_)) + } } impl< @@ -214,29 +219,17 @@ fn poll_setup( // Still setting up if let Ok(Some(msg)) = maybe_init(step, is_initiator) { // queue the init message to send first - trace!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); + info!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); encrypted_tx.push_front(msg); } // TODO handle error loop { match encrypted_rx.pop_front() { None => { - debug!( - " - num_encrypted_rx = {} - num_encrypted_tx = {} -no more encrp incoming", - encrypted_rx.len(), - encrypted_tx.len(), - ); break; } Some(Err(e)) => { - error!( - num_encrypted_rx = 0, - num_encrypted_tx = encrypted_tx.len(), - "{e:?}" - ); + error!("Recieved an error during setup encryption setup: {e:?}"); break; } Some(Ok(incoming_msg)) => { @@ -258,6 +251,7 @@ no more encrp incoming", } } +#[instrument(skip_all, fields(initiator = %is_initiator))] fn poll_encrypted_side_io< IO: Stream>> + Sink> + Send + Unpin + 'static, >( @@ -271,33 +265,25 @@ fn poll_encrypted_side_io< // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(initiator = %is_initiator, "enc tx send msg\n{encrypted_out:?}"); - let _todo = Sink::start_send(Pin::new(io), encrypted_out); + info!(initiator = %is_initiator, msg_len = encrypted_out.len(), "enc tx send msg\n{encrypted_out:?}"); + if let Err(_e) = Sink::start_send(Pin::new(io), encrypted_out) { + error!("Error polling encyrpted side io") + } + *flush = true; } else { break; } } if *flush { - // confusing docs related to start send - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.start_send - // First part says: - // "you must use **poll_flush** ... inorder to garuntee - // completions of send" - // Then it says: - // " It is only necessary to call poll_flush if you need to guarantee that all - // of the items placed into the Sink have been sent" - // - // So do I need to do it or not? - // must `poll_flush` be called for **anything** to send? match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - trace!(initiator = %is_initiator, "flushed good"); + info!(initiator = %is_initiator, "flushed good"); + } + Poll::Ready(Err(_e)) => { + error!(initiator = %is_initiator, "Error sending encrypted msg") } - Poll::Ready(Err(_e)) => error!( - initiator = %is_initiator, - "Error sending encrypted msg"), Poll::Pending => { // More confusing docs // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush @@ -308,6 +294,7 @@ fn poll_encrypted_side_io< // Does this mean, each time this task wakes up again from this code path that // I must trigger another poll_flush? But how would I know i need more // flushing? + debug!("flush not completed"); *flush = true; } } @@ -327,6 +314,7 @@ fn poll_encrypted_side_io< /// Process messages waiting to be encrypted or decrypted // TODO sholud this return a Result +#[instrument(skip_all)] fn poll_do_encrypt_and_decrypt( encryptor: &mut RawEncryptCipher, decryptor: &mut DecryptCipher, @@ -375,7 +363,7 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { Ok(out) } -#[instrument(skip_all, fields(is_initiator = %is_initiator))] +#[instrument(skip_all, fields(initiator = %is_initiator))] fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { match &step { Step::NotInitialized => { @@ -432,7 +420,6 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu Ok(out) } Step::SecretStream(_) => { - info!("E're a secret stream now!!!!!"); if let Step::SecretStream((enc_cipher, hs_result)) = replace(step, Step::NotInitialized) { let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; @@ -447,58 +434,125 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu #[cfg(test)] mod tset { - use crate::test_utils::create_connected; + + use crate::{framing::test::duplex, test_utils::create_result_connected, LengthPrefixed}; use super::*; - use futures::{SinkExt, StreamExt}; + use futures::{future::join, SinkExt, StreamExt}; #[tokio::test] - async fn steps() -> Result<()> { - let mut left_hs = Handshake::new(true)?; - let s1 = left_hs.start_raw()?.unwrap(); + async fn encrypted() -> Result<()> { + let hello = b"hello".to_vec(); + let world = b"world".to_vec(); + let (lc, rc) = create_result_connected(); + let mut left = Encrypted::new(true, lc); + let mut right = Encrypted::new(false, rc); - println!("s1 {s1:?}"); - let mut right_hs = Handshake::new(false)?; + let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(receieved.unwrap()?, hello); - let s2 = right_hs.read_raw(&s1)?.unwrap(); - println!("s2 {s2:?}"); + assert!(left.encryption_established()); + assert!(right.encryption_established()); - let s3 = left_hs.read_raw(&s2)?.unwrap(); - println!("s3 {s3:?}"); + // NB: we cannot totally finish 'left.send' until the other side becomes active + // because the handshake with the other side ('right') must complete + // before the 'hello' message is sent. So we poll both the send and receive concurrently. + let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; + // right recieves left's message + assert_eq!(receieved.unwrap()?, hello); - let s4 = right_hs.read_raw(&s3)?; + // now that the encrypted channel is established, we don't need to spawn. + right.send(world.clone()).await.unwrap(); - println!("s4 {s4:?}"); - // both sides now ready + // left recieves right's message + assert_eq!(left.next().await.unwrap()?, world); + Ok(()) + } + #[tokio::test] + async fn encrypted_many() -> Result<()> { + let hello = b"hello".to_vec(); + let data = vec![ + b"yolo".to_vec(), + b"squalor".to_vec(), + b"idleness".to_vec(), + b"hello".to_vec(), + b"stuff".to_vec(), + ]; + let (lc, rc) = create_result_connected(); + let mut left = Encrypted::new(true, lc); + let mut right = Encrypted::new(false, rc); + + let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(receieved.unwrap()?, hello); + + for d in &data { + right.send(d.to_vec()).await?; + } + let mut result = vec![]; + for _ in &data { + result.push(left.next().await.unwrap()?); + } + assert_eq!(result, data); Ok(()) } #[tokio::test] - async fn test_encrypted() -> Result<()> { - let hello = b"hello"; - let world = b"world"; - let (left, right) = create_connected(); + async fn with_framing() -> Result<()> { + crate::test_utils::log(); + let hello = b"hello".to_vec(); + + let (left, right) = duplex(1024 * 64); + let left = LengthPrefixed::new(left); + let right = LengthPrefixed::new(right); + let mut left = Encrypted::new(true, left); let mut right = Encrypted::new(false, right); - // NB: we cannot totally finish 'left.send' until the other side becomes active - // this is because the handshake with the other side ('right') must complete - // before the message is sent. So we must spawn here, so we can proceed to run 'right' - let left_handle = tokio::task::spawn(async move { - left.send(hello.into()).await.unwrap(); - left - }); + let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(receieved.unwrap()?, hello); - // right recieves left's message - assert_eq!(right.next().await.unwrap(), hello); + let data = vec![ + b"yolo".to_vec(), + b"squalor".to_vec(), + b"idleness".to_vec(), + b"hello".to_vec(), + b"stuff".to_vec(), + ]; - let mut left = left_handle.await?; + // send right to left + for d in &data { + right.send(d.to_vec()).await?; + } + let mut result = vec![]; + for _ in &data { + result.push(left.next().await.unwrap()?); + } + assert_eq!(result, data); - // now that the encrypted channel is established, we don't need to spawn. - right.send(world.into()).await.unwrap(); + // send left to right + for d in &data { + left.send(d.to_vec()).await?; + } + let mut result = vec![]; + for _ in &data { + result.push(right.next().await.unwrap()?); + } + assert_eq!(result, data); + + // send both ways + for d in &data { + left.send(d.to_vec()).await?; + right.send(d.to_vec()).await?; + } + let mut left_result = vec![]; + let mut right_result = vec![]; + for _ in &data { + right_result.push(right.next().await.unwrap()?); + left_result.push(left.next().await.unwrap()?); + } + assert_eq!(right_result, data); + assert_eq!(left_result, data); - // left recieves right's message - assert_eq!(left.next().await.unwrap(), world); Ok(()) } } From e547a1f7de2437546231fc0a673308d5ff7e6e05 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 17 Mar 2025 12:18:28 -0400 Subject: [PATCH 030/206] logs --- src/noise.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/noise.rs b/src/noise.rs index 15525de..a846b28 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -346,7 +346,7 @@ fn poll_do_encrypt_and_decrypt( Ok(x) => x, Err(_e) => todo!("We failed to encrypt our own message...?"), }; - trace!(initiator = %is_initiator, "enc from plain tx queue\n{enc_out:?}"); + trace!(initiator = %is_initiator, encrypted_msg_length = enc_out.len(), "enqueue new encrypted message from plain tx queue\n{enc_out:?}"); encrypted_tx.push_back(enc_out); *flush = true; } From 83dbe3a77cb0cd2d5b4fe2cee8a7beb74df4222d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 17 Mar 2025 14:24:05 -0400 Subject: [PATCH 031/206] Add framing buffer rotation --- src/framing.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 66ce84f..38b0f77 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -84,18 +84,15 @@ where let n_bytes_read = match Pin::new(io).poll_read(cx, &mut to_stream[*last_data_idx..]) { Poll::Ready(Ok(n)) => n, Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), - Poll::Pending => { - cx.waker().wake_by_ref(); - 0 - } + Poll::Pending => 0, }; // TODO handle if to_stream is full debug!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); *last_data_idx += n_bytes_read; // grow buffer if it's full if *last_data_idx == to_stream.len() - 1 { - warn!("We filled our buffer!"); - to_stream.extend(vec![0; to_stream.len() * 2]); + warn!("Buffer full, double it's size"); + to_stream.extend(vec![0; to_stream.len()]); } if let Step::Header = step { @@ -122,14 +119,18 @@ where debug!(frame_size = end - *start, "Frame ready"); let out = to_stream[*start..end].to_vec(); *step = Step::Header; - *last_out_idx = end; + // remove bytes we're done with + to_stream.rotate_left(end); + *last_data_idx -= end; + *last_out_idx = 0; return Poll::Ready(Some(Ok(out))); } } Poll::Pending } } + impl Sink> for LengthPrefixed where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, From a585aefad5705d79794818fbd94689d432083efd Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 17 Mar 2025 15:34:42 -0400 Subject: [PATCH 032/206] bump futures to non-yanked version --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ff89935..170c32f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ futures-lite = "1" sha2 = "0.10" curve25519-dalek = "4" crypto_secretstream = "0.2" -futures = "0.3.13" +futures = "0.3.31" [dependencies.hypercore] version = "0.14.0" From 31e628134b44433c0855a6296b5562c5998c8486 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 17 Mar 2025 17:55:12 -0400 Subject: [PATCH 033/206] handle setup errors and add test --- src/noise.rs | 202 ++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 159 insertions(+), 43 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index a846b28..5126576 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -18,21 +18,6 @@ pub(crate) enum Step { SecretStream((RawEncryptCipher, HandshakeResult)), Established((RawEncryptCipher, DecryptCipher, HandshakeResult)), } -impl std::fmt::Display for Step { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Step::NotInitialized => "NotInitialized", - Step::Handshake(_) => "Handshake", - Step::SecretStream(_) => "SecretStream", - Step::Established(_) => "Established", - } - ) - } -} - /// Wrap a stream with encryption pub struct Encrypted { io: IO, @@ -45,21 +30,6 @@ pub struct Encrypted { flush: bool, } -impl std::fmt::Debug for Encrypted { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Encrypted") - //.field("io", &self.io) - .field("step", &self.step) - .field("is_initiator", &self.is_initiator) - //.field("encrypted_tx", &self.encrypted_tx) - .field("encrypted_rx", &self.encrypted_rx) - .field("plain_tx", &self.plain_tx) - .field("plain_rx", &self.plain_rx) - .field("flush", &self.flush) - .finish() - } -} - impl Encrypted where IO: Stream>> + Sink> + Send + Unpin + 'static, @@ -78,7 +48,6 @@ where flush: false, } } - /// Wether an encrypted connection has been established. pub fn encryption_established(&self) -> bool { matches!(self.step, Step::Established(_)) @@ -147,7 +116,7 @@ impl< Poll::Ready(Ok(())) } } else { - poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator); + poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator, flush); cx.waker().wake_by_ref(); Poll::Pending } @@ -202,7 +171,7 @@ impl>> + Sink> + Send + Unpin + 'static Poll::Pending } } else { - poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator); + poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator, flush); cx.waker().wake_by_ref(); Poll::Pending } @@ -215,7 +184,12 @@ fn poll_setup( encrypted_tx: &mut VecDeque>, encrypted_rx: &mut VecDeque>>, is_initiator: bool, + flush: &mut bool, ) { + // if we get an error, it could be because the other side reset, and is sending a new + // initialization message. + // If this is the case, we should retry this message after the error. + // But to avoid repeatedly retrying the first message, we should only retry if it is *not* the first msg. // Still setting up if let Ok(Some(msg)) = maybe_init(step, is_initiator) { // queue the init message to send first @@ -234,7 +208,14 @@ fn poll_setup( } Some(Ok(incoming_msg)) => { info!(initiator = %is_initiator, "recieved setup msg"); - if let Ok(msgs) = match handle_setup_message(step, &incoming_msg, is_initiator) { + if let Ok(msgs) = match handle_setup_message( + step, + &incoming_msg, + is_initiator, + encrypted_tx, + encrypted_rx, + flush, + ) { Ok(x) => Ok(x), Err(e) => { error!("handle_setup_message error: {e:?}"); @@ -252,6 +233,7 @@ fn poll_setup( } #[instrument(skip_all, fields(initiator = %is_initiator))] +/// Fills `encrypted_rx` and drains `encrypted_tx`. fn poll_encrypted_side_io< IO: Stream>> + Sink> + Send + Unpin + 'static, >( @@ -363,25 +345,64 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { Ok(out) } +fn reset_encrypted( + step: &mut Step, + maybe_init_message: Option>, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, + flush: &mut bool, +) { + *step = Step::NotInitialized; + encrypted_tx.clear(); + encrypted_rx.clear(); + if let Some(msg) = maybe_init_message { + encrypted_rx.push_front(Ok(msg)); + } + *flush = false; +} + +/// handle setup messages: if any are incorrect (cause an error) the state is reset #[instrument(skip_all, fields(initiator = %is_initiator))] -fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Result>> { +fn handle_setup_message( + step: &mut Step, + msg: &[u8], + is_initiator: bool, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, + flush: &mut bool, +) -> Result>> { + // this would only happen after reset with a bad message. + let mut first_message = false; + if let Step::NotInitialized = step { + first_message = true; + assert!(!is_initiator); + warn!(initiator = %is_initiator, "Encrypted state was reset"); + let mut handshake = Handshake::new(is_initiator)?; + let _ = handshake.start_raw()?; + *step = Step::Handshake(Box::new(handshake)); + } match &step { Step::NotInitialized => { - warn!(initiator = %is_initiator, "Encrypted state was reset"); - let mut handshake = Handshake::new(is_initiator)?; - let start_msg = handshake.start_raw()?; - *step = Step::Handshake(Box::new(handshake)); - debug!(initiator = %is_initiator, "Step changed to {step}"); - - Ok(start_msg.map(|x| vec![x]).unwrap_or(vec![])) + unreachable!("should not happen") } Step::Handshake(_) => { + dbg!(); let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { trace!("Read in handshake msg\n{msg:?}"); if let Some(response) = match handshake.read_raw(msg) { Ok(x) => x, Err(e) => { + let maybe_init_message = + (!first_message && !is_initiator).then_some(msg.to_vec()); + + reset_encrypted( + step, + maybe_init_message, + encrypted_tx, + encrypted_rx, + flush, + ); return Err(e); } } { @@ -432,13 +453,46 @@ fn handle_setup_message(step: &mut Step, msg: &[u8], is_initiator: bool) -> Resu } } +impl std::fmt::Display for Step { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Step::NotInitialized => "NotInitialized", + Step::Handshake(_) => "Handshake", + Step::SecretStream(_) => "SecretStream", + Step::Established(_) => "Established", + } + ) + } +} + +impl std::fmt::Debug for Encrypted { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Encrypted") + //.field("io", &self.io) + //.field("step", &self.step) + .field("is_initiator", &self.is_initiator) + .field("encrypted_tx", &self.encrypted_tx) + .field("encrypted_rx", &self.encrypted_rx) + .field("plain_tx", &self.plain_tx) + .field("plain_rx", &self.plain_rx) + .field("flush", &self.flush) + .finish() + } +} + #[cfg(test)] mod tset { use crate::{framing::test::duplex, test_utils::create_result_connected, LengthPrefixed}; use super::*; - use futures::{future::join, SinkExt, StreamExt}; + use futures::{ + future::{join, select, Either}, + SinkExt, StreamExt, + }; #[tokio::test] async fn encrypted() -> Result<()> { @@ -555,4 +609,66 @@ mod tset { Ok(()) } + + #[tokio::test] + async fn test_setup_error_causes_re_init() -> Result<()> { + let (lc, mut init_side_messages) = create_result_connected(); + let (mut other_side_messages, rc) = create_result_connected(); + let mut left = Encrypted::new(true, lc); + let mut right = Encrypted::new(false, rc); + let hello = b"hello".to_vec(); + + let send_fut = tokio::task::spawn(async move { + left.send(hello).await.unwrap(); + left + }); + + let init_msg = init_side_messages.next().await.unwrap()?; + + other_side_messages.send(init_msg).await?; + // other side encrypted needs to be polled to do work and send a response + let other_send_fut = tokio::task::spawn(async move { + right.send(b"other hello".to_vec()).await.unwrap(); + right + }); + + let _first_response = other_side_messages.next().await.unwrap()?; + // both sides now have a handshake in progress + + // send a bad message to init side. It should reset, and emit new init msg + init_side_messages.send(b"bad msg".to_vec()).await?; + let new_init_msg = init_side_messages.next().await.unwrap()?; + + other_side_messages.send(new_init_msg).await?; + let new_response = other_side_messages.next().await.unwrap()?; + init_side_messages.send(new_response).await?; + let final_setup_message = init_side_messages.next().await.unwrap()?; + other_side_messages.send(final_setup_message).await?; + + // exchange one more message then we're set up + init_side_messages + .send(other_side_messages.next().await.unwrap()?) + .await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; + // now our spawned sends can complete + let mut left = send_fut.await?; + let mut right = other_send_fut.await?; + + // exchange hellos + init_side_messages + .send(other_side_messages.next().await.unwrap()?) + .await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; + + assert!(left.encryption_established()); + assert!(right.encryption_established()); + assert_eq!(right.next().await.unwrap()?, b"hello"); + assert_eq!(left.next().await.unwrap()?, b"other hello"); + + Ok(()) + } } From 8a2370bae2709cb35732344117f9ec410b8e2908 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:35:18 -0400 Subject: [PATCH 034/206] RMME show logs in example --- examples-nodejs/run.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples-nodejs/run.js b/examples-nodejs/run.js index c96541f..ac77bba 100644 --- a/examples-nodejs/run.js +++ b/examples-nodejs/run.js @@ -37,7 +37,8 @@ function startRust (mode, key, color, name) { color: color || 'blue', env: { ...process.env, - RUST_LOG_STYLE: 'always' + RUST_LOG_STYLE: 'always', + RUST_LOG: 'trace' } }) return rust From 918a307c75d6e773bb22c260c990f0866b9e965f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:37:45 -0400 Subject: [PATCH 035/206] RMME extra docs --- src/crypto/cipher.rs | 3 +++ src/noise.rs | 2 +- src/protocol.rs | 9 +++++++++ src/reader.rs | 5 +++++ src/writer.rs | 19 +++++++++++++++++++ 5 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 4dfaf19..f2dc9b9 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -88,6 +88,7 @@ impl DecryptCipher { let decrypted_end = header_len + to_decrypt.len(); buf[header_len..decrypted_end].copy_from_slice(to_decrypt.as_slice()); // Set extra bytes in the buffer to 0 + // Why? let encrypted_end = header_len + body_len; buf[decrypted_end..encrypted_end].fill(0x00); Ok(decrypted_end) @@ -136,6 +137,8 @@ impl EncryptCipher { /// Encrypts message in the given buffer to the same buffer, returns number of bytes /// of total message. + /// NB: we expect the first 3 bytes of the buffer to a size header. + /// The encrypted buffer will also be written prepended with a size header, with it's new size. pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { let stat = stat_uint24_le(buf); if let Some((header_len, body_len)) = stat { diff --git a/src/noise.rs b/src/noise.rs index 5126576..a7bd306 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -295,7 +295,7 @@ fn poll_encrypted_side_io< } /// Process messages waiting to be encrypted or decrypted -// TODO sholud this return a Result +// TODO sholud this return a Result? #[instrument(skip_all)] fn poll_do_encrypt_and_decrypt( encryptor: &mut RawEncryptCipher, diff --git a/src/protocol.rs b/src/protocol.rs index 1f24b1a..673b307 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -301,6 +301,7 @@ where let mut handshake = Handshake::new(self.options.is_initiator)?; // If the handshake start returns a buffer, send it now. if let Some(buf) = handshake.start()? { + // TODO what if this fails? or returns false self.queue_frame_direct(buf.to_vec()).unwrap(); } self.read_state.set_frame_type(FrameType::Raw); @@ -375,6 +376,7 @@ where if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { return Err(e); } + // if no parking or setup in progress if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { return Ok(()); } @@ -406,11 +408,17 @@ where State::SecretStream(_) => self.on_secret_stream_message(buf)?, State::Established => { if let Some(processed_state) = processed_state.as_ref() { + // last state before established let previous_state = if self.options.encrypted { + // was SecretStream if we're encrypted State::SecretStream(None) } else { + // or wa hasdshake if we're not encrypted State::Handshake(None) }; + + // if htis raw_batch included regular messages (not handshake) + // after handshake stuff if processed_state == &format!("{previous_state:?}") { // This is the unlucky case where the batch had two or more messages where // the first one was correctly identified as Raw but everything @@ -591,6 +599,7 @@ where self.queued_events.push_back(event); } + /// enequeu a buf to be sent fn queue_frame_direct(&mut self, body: Vec) -> Result { let mut frame = Frame::RawBatch(vec![body]); self.write_state.try_encode_frame_for_tx(&mut frame) diff --git a/src/reader.rs b/src/reader.rs index 51b370b..5664d56 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -113,6 +113,9 @@ impl ReadState { if success { if let Some(ref mut cipher) = self.cipher { let mut dec_end = self.start; + // What happens if decrypt fails here? + // next call to this func would have same start, corret? + // so it'd fail repeatedly? for (index, header_len, body_len) in segments { let de = cipher.decrypt( &mut self.buf[self.start + index..end], @@ -137,6 +140,7 @@ impl ReadState { } } + /// Moves start of unprocessed data to the start of the buffer. And resize if necessary. fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) { let (last_index, last_header_len, last_body_len) = last_segment; let total_incoming_length = last_index + last_header_len + last_body_len; @@ -207,6 +211,7 @@ impl ReadState { } #[allow(clippy::type_complexity)] +// get segments from buff fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> { let mut index: usize = 0; let len = buf.len(); diff --git a/src/writer.rs b/src/writer.rs index 38d6dcf..56bbaf6 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -9,6 +9,10 @@ use std::pin::Pin; use std::task::{Context, Poll}; const BUF_SIZE: usize = 1024 * 64; +// This is the largest size that will fit in u24. +// a message is larger than this we should error. +// also check message is smaller than this when we are encrypting. +const _MAX_MSG_SIZE: usize = 2usize.pow(24) - 1; #[derive(Debug)] pub(crate) enum Step { @@ -64,9 +68,12 @@ impl WriteState { pub(crate) fn try_encode_frame_for_tx(&mut self, frame: &mut T) -> Result { let promised_len = frame.encoded_len()?; let padded_promised_len = self.safe_encrypted_len(promised_len); + // this handles when a message would be longer than the entire buffer if self.buf.len() < padded_promised_len { self.buf.resize(padded_promised_len, 0u8); } + + // check we have enough room if padded_promised_len > self.remaining() { return Ok(false); } @@ -78,6 +85,14 @@ impl WriteState { "encoded_len() did not return that right size, expected={promised_len}, actual={actual_len}" ); } + // Instead of the above, write the buffer to a new vec `foo` of length `promised_length` + // encode frame.to this buff + // slice `foo[(header_len /* 3*/)..actual_len]` this is the fram data + // encrypt this in place + // replace header at start of foo + // write its len to self.buf and then write it to self.buf + // slice from + self.encrypt_frame_contents(padded_promised_len)?; Ok(true) } @@ -95,6 +110,10 @@ impl WriteState { } } + /// The frame should be written to `self.buf` before calling this. And + /// `self.should_write_up_to_idx` should mark the start of the message. + /// `max_message_size` is the maximum size the message could be when it is encrypted + /// We encrypt the message in-place on `self.buf`. fn encrypt_frame_contents(&mut self, max_message_size: usize) -> Result<()> { let end_of_message_index = self.should_write_up_to_idx + max_message_size; From ecc499a14ca979204748d04779f47542567f422a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:38:03 -0400 Subject: [PATCH 036/206] add const header len --- src/framing.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/framing.rs b/src/framing.rs index 38b0f77..2e8cf32 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -14,7 +14,7 @@ use tracing::{debug, error, info, instrument, trace, warn}; use crate::util::{stat_uint24_le, wrap_uint24_le}; const BUF_SIZE: usize = 1024 * 64; -const HEADER_LEN: usize = 3; +const _HEADER_LEN: usize = 3; /// Turn a `AsyncWrite` of length prefixed messages and emit the messages with a Stream pub struct LengthPrefixed { From 8a96462b73174b5578a7b0dc766c809c72f5565e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:38:29 -0400 Subject: [PATCH 037/206] lint --- src/noise.rs | 5 +---- src/protocol.rs | 2 +- src/test_utils.rs | 10 ++-------- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index a7bd306..cf800f3 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -489,10 +489,7 @@ mod tset { use crate::{framing::test::duplex, test_utils::create_result_connected, LengthPrefixed}; use super::*; - use futures::{ - future::{join, select, Either}, - SinkExt, StreamExt, - }; + use futures::{future::join, SinkExt, StreamExt}; #[tokio::test] async fn encrypted() -> Result<()> { diff --git a/src/protocol.rs b/src/protocol.rs index 673b307..89f3df1 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -10,7 +10,7 @@ use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tracing::{info, trace}; +use tracing::trace; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; diff --git a/src/test_utils.rs b/src/test_utils.rs index d35af0e..ff1a3c2 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -7,9 +7,7 @@ use std::{ //use async_channel::{unbounded, Receiver, io::Error, Sender}; use futures::{ - channel::mpsc::{ - unbounded, SendError, UnboundedReceiver as Receiver, UnboundedSender as Sender, - }, + channel::mpsc::{unbounded, UnboundedReceiver as Receiver, UnboundedSender as Sender}, Sink, SinkExt, Stream, StreamExt, }; @@ -76,10 +74,6 @@ impl TwoWay { } } -pub(crate) fn create_connected() -> (Io, Io) { - TwoWay::default().split_sides() -} - pub(crate) fn log() { use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: OnceLock<()> = OnceLock::new(); @@ -123,7 +117,7 @@ pub(crate) struct Moo { impl + Unpin, Tx: Unpin> Stream for Moo { type Item = RxItem; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); Pin::new(&mut this.receiver).poll_next(cx) } From a85e42ea35a2a88073dff9dcf87b6b42e0912261 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:49:30 -0400 Subject: [PATCH 038/206] helpful names --- src/protocol.rs | 3 ++- src/writer.rs | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 89f3df1..930f9bd 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -602,7 +602,8 @@ where /// enequeu a buf to be sent fn queue_frame_direct(&mut self, body: Vec) -> Result { let mut frame = Frame::RawBatch(vec![body]); - self.write_state.try_encode_frame_for_tx(&mut frame) + self.write_state + .try_encode_and_enqueue_frame_for_tx(&mut frame) } fn accept_channel(&mut self, local_id: usize) -> Result<()> { diff --git a/src/writer.rs b/src/writer.rs index 56bbaf6..d91adfb 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -65,7 +65,10 @@ impl WriteState { self.queue.push_back(frame.into()) } - pub(crate) fn try_encode_frame_for_tx(&mut self, frame: &mut T) -> Result { + pub(crate) fn try_encode_and_enqueue_frame_for_tx( + &mut self, + frame: &mut T, + ) -> Result { let promised_len = frame.encoded_len()?; let padded_promised_len = self.safe_encrypted_len(promised_len); // this handles when a message would be longer than the entire buffer @@ -93,7 +96,7 @@ impl WriteState { // write its len to self.buf and then write it to self.buf // slice from - self.encrypt_frame_contents(padded_promised_len)?; + self.encrypt_frame_contents_onto_buf(padded_promised_len)?; Ok(true) } @@ -114,7 +117,7 @@ impl WriteState { /// `self.should_write_up_to_idx` should mark the start of the message. /// `max_message_size` is the maximum size the message could be when it is encrypted /// We encrypt the message in-place on `self.buf`. - fn encrypt_frame_contents(&mut self, max_message_size: usize) -> Result<()> { + fn encrypt_frame_contents_onto_buf(&mut self, max_message_size: usize) -> Result<()> { let end_of_message_index = self.should_write_up_to_idx + max_message_size; let encrypted_end = if let Some(ref mut cipher) = self.cipher { @@ -157,7 +160,7 @@ impl WriteState { } if let Some(mut frame) = self.current_frame.take() { - if !self.try_encode_frame_for_tx(&mut frame)? { + if !self.try_encode_and_enqueue_frame_for_tx(&mut frame)? { self.current_frame = Some(frame); } } From 3e1483a21f40cd70030cebf1743293f9c8af4512 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 12:53:02 -0400 Subject: [PATCH 039/206] rename framing struct --- src/framing.rs | 20 ++++++++++---------- src/lib.rs | 2 +- src/noise.rs | 8 +++++--- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 2e8cf32..e7c9c5c 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -17,7 +17,7 @@ const BUF_SIZE: usize = 1024 * 64; const _HEADER_LEN: usize = 3; /// Turn a `AsyncWrite` of length prefixed messages and emit the messages with a Stream -pub struct LengthPrefixed { +pub struct Uint24LELengthPrefixedFraming { io: IO, /// Data from [`Self::io`]'s [`AsyncRead`] interface to be sent out via the [`Stream`] interface. to_stream: Vec, @@ -31,12 +31,12 @@ pub struct LengthPrefixed { /// Current step of a message being parsed step: Step, } -impl Debug for LengthPrefixed { +impl Debug for Uint24LELengthPrefixedFraming { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Format()") } } -impl LengthPrefixed +impl Uint24LELengthPrefixedFraming where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, { @@ -59,7 +59,7 @@ enum Step { Body { start: usize, end: u64 }, } -impl Stream for LengthPrefixed +impl Stream for Uint24LELengthPrefixedFraming where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, { @@ -131,7 +131,7 @@ where } } -impl Sink> for LengthPrefixed +impl Sink> for Uint24LELengthPrefixedFraming where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, { @@ -220,7 +220,7 @@ pub(crate) mod test { async fn input() -> Result<()> { log(); let (left, mut right) = duplex(64); - let mut lp = LengthPrefixed::new(left); + let mut lp = Uint24LELengthPrefixedFraming::new(left); let input = b"yelp"; let msg = wrap_uint24_le(input); dbg!(&msg); @@ -235,7 +235,7 @@ pub(crate) mod test { async fn stream_many() -> Result<()> { log(); let (left, mut right) = duplex(64); - let mut lp = LengthPrefixed::new(left); + let mut lp = Uint24LELengthPrefixedFraming::new(left); let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; for d in data { let msg = wrap_uint24_le(d); @@ -255,7 +255,7 @@ pub(crate) mod test { async fn sink_many() -> Result<()> { log(); let (left, mut right) = duplex(64); - let mut lp = LengthPrefixed::new(left); + let mut lp = Uint24LELengthPrefixedFraming::new(left); let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; for d in data { lp.send(d.to_vec()).await.unwrap(); @@ -273,8 +273,8 @@ pub(crate) mod test { async fn left_and_right() -> Result<()> { let (left, right) = duplex(64); - let mut leftlp = LengthPrefixed::new(left); - let mut rightlp = LengthPrefixed::new(right); + let mut leftlp = Uint24LELengthPrefixedFraming::new(left); + let mut rightlp = Uint24LELengthPrefixedFraming::new(right); let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; for d in data { diff --git a/src/lib.rs b/src/lib.rs index b1a043a..07a677b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -137,7 +137,7 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; -pub use framing::LengthPrefixed; +pub use framing::Uint24LELengthPrefixedFraming; pub use noise::Encrypted; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ diff --git a/src/noise.rs b/src/noise.rs index cf800f3..cf651ff 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -486,7 +486,9 @@ impl std::fmt::Debug for Encrypted { #[cfg(test)] mod tset { - use crate::{framing::test::duplex, test_utils::create_result_connected, LengthPrefixed}; + use crate::{ + framing::test::duplex, test_utils::create_result_connected, Uint24LELengthPrefixedFraming, + }; use super::*; use futures::{future::join, SinkExt, StreamExt}; @@ -553,8 +555,8 @@ mod tset { let hello = b"hello".to_vec(); let (left, right) = duplex(1024 * 64); - let left = LengthPrefixed::new(left); - let right = LengthPrefixed::new(right); + let left = Uint24LELengthPrefixedFraming::new(left); + let right = Uint24LELengthPrefixedFraming::new(right); let mut left = Encrypted::new(true, left); let mut right = Encrypted::new(false, right); From 2a602911489e65e1a316950328a43359ffab0ac3 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 13:03:13 -0400 Subject: [PATCH 040/206] Add func for building encrypted framed channel --- src/lib.rs | 2 +- src/noise.rs | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 07a677b..e4c0744 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -138,7 +138,7 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; pub use framing::Uint24LELengthPrefixedFraming; -pub use noise::Encrypted; +pub use noise::{encyrpted_framed_message_channel, Encrypted}; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, diff --git a/src/noise.rs b/src/noise.rs index cf651ff..c2269ea 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -1,4 +1,4 @@ -use futures::{Sink, Stream}; +use futures::{AsyncRead, AsyncWrite, Sink, Stream}; use std::{ collections::VecDeque, fmt::Debug, @@ -9,8 +9,19 @@ use std::{ }; use tracing::{debug, error, info, instrument, trace, warn}; -use crate::crypto::{DecryptCipher, Handshake, HandshakeResult, RawEncryptCipher}; +use crate::{ + crypto::{DecryptCipher, Handshake, HandshakeResult, RawEncryptCipher}, + Uint24LELengthPrefixedFraming, +}; +/// Create a framed and encrypted Stream/Sink that reads/writes to an AsyncRead/AsyncWrite. +pub fn encyrpted_framed_message_channel( + is_initiator: bool, + io: IO, +) -> Encrypted> { + let framed = Uint24LELengthPrefixedFraming::new(io); + Encrypted::new(is_initiator, framed) +} #[derive(Debug)] pub(crate) enum Step { NotInitialized, From 41fd39e76d54a53e340821650fa0f0f2b80eb62c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 13:24:52 -0400 Subject: [PATCH 041/206] Move Protocol in prep for feature flagging --- src/protocol/mod.rs | 4 ++++ src/{protocol.rs => protocol/old.rs} | 0 2 files changed, 4 insertions(+) create mode 100644 src/protocol/mod.rs rename src/{protocol.rs => protocol/old.rs} (100%) diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs new file mode 100644 index 0000000..18be509 --- /dev/null +++ b/src/protocol/mod.rs @@ -0,0 +1,4 @@ +mod old; + +pub(crate) use old::Options; +pub use old::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; diff --git a/src/protocol.rs b/src/protocol/old.rs similarity index 100% rename from src/protocol.rs rename to src/protocol/old.rs From 8fca159fd9b5983dd678f3eb77a7a490e6b6ae9f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 13:47:10 -0400 Subject: [PATCH 042/206] add second protocol impl behind feature flag --- Cargo.toml | 1 + src/protocol/mod.rs | 11 +- src/protocol/modern.rs | 697 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 708 insertions(+), 1 deletion(-) create mode 100644 src/protocol/modern.rs diff --git a/Cargo.toml b/Cargo.toml index 170c32f..82790eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,7 @@ tokio-util = { version = "0.7.14", features = ["compat"] } [features] default = ["tokio", "sparse"] +protocol = [] wasm-bindgen = [ "futures-timer/wasm-bindgen" ] diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 18be509..7382df8 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,4 +1,13 @@ -mod old; +#[cfg(feature = "protocol")] +mod modern; +#[cfg(feature = "protocol")] +pub(crate) use modern::Options; +#[cfg(feature = "protocol")] +pub use modern::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; +#[cfg(not(feature = "protocol"))] +mod old; +#[cfg(not(feature = "protocol"))] pub(crate) use old::Options; +#[cfg(not(feature = "protocol"))] pub use old::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs new file mode 100644 index 0000000..930f9bd --- /dev/null +++ b/src/protocol/modern.rs @@ -0,0 +1,697 @@ +use async_channel::{Receiver, Sender}; +use futures_lite::io::{AsyncRead, AsyncWrite}; +use futures_lite::stream::Stream; +use futures_timer::Delay; +use std::collections::VecDeque; +use std::convert::TryInto; +use std::fmt; +use std::future::Future; +use std::io::{self, Error, ErrorKind, Result}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; +use tracing::trace; + +use crate::channels::{Channel, ChannelMap}; +use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; +use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; +use crate::message::{ChannelMessage, Frame, FrameType, Message}; +use crate::reader::ReadState; +use crate::schema::*; +use crate::util::{map_channel_err, pretty_hash}; +use crate::writer::WriteState; + +macro_rules! return_error { + ($msg:expr) => { + if let Err(e) = $msg { + return Poll::Ready(Err(e)); + } + }; +} + +const CHANNEL_CAP: usize = 1000; +const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); + +/// Options for a Protocol instance. +#[derive(Debug)] +pub(crate) struct Options { + /// Whether this peer initiated the IO connection for this protoccol + pub(crate) is_initiator: bool, + /// Enable or disable the handshake. + /// Disabling the handshake will also disable capabilitity verification. + /// Don't disable this if you're not 100% sure you want this. + pub(crate) noise: bool, + /// Enable or disable transport encryption. + pub(crate) encrypted: bool, +} + +impl Options { + /// Create with default options. + pub(crate) fn new(is_initiator: bool) -> Self { + Self { + is_initiator, + noise: true, + encrypted: true, + } + } +} + +/// Remote public key (32 bytes). +pub(crate) type RemotePublicKey = [u8; 32]; +/// Discovery key (32 bytes). +pub type DiscoveryKey = [u8; 32]; +/// Key (32 bytes). +pub type Key = [u8; 32]; + +/// A protocol event. +#[non_exhaustive] +#[derive(PartialEq)] +pub enum Event { + /// Emitted after the handshake with the remote peer is complete. + /// This is the first event (if the handshake is not disabled). + Handshake(RemotePublicKey), + /// Emitted when the remote peer opens a channel that we did not yet open. + DiscoveryKey(DiscoveryKey), + /// Emitted when a channel is established. + Channel(Channel), + /// Emitted when a channel is closed. + Close(DiscoveryKey), + /// Convenience event to make it possible to signal the protocol from a channel. + /// See channel.signal_local() and protocol.commands().signal_local(). + LocalSignal((String, Vec)), +} + +/// A protocol command. +#[derive(Debug)] +pub enum Command { + /// Open a channel + Open(Key), + /// Close a channel by discovery key + Close(DiscoveryKey), + /// Signal locally to protocol + SignalLocal((String, Vec)), +} + +impl fmt::Debug for Event { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Event::Handshake(remote_key) => { + write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key)) + } + Event::DiscoveryKey(discovery_key) => { + write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key)) + } + Event::Channel(channel) => { + write!(f, "Channel({})", &pretty_hash(channel.discovery_key())) + } + Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)), + Event::LocalSignal((name, data)) => { + write!(f, "LocalSignal(name={},len={})", name, data.len()) + } + } + } +} + +/// Protocol state +#[allow(clippy::large_enum_variant)] +pub(crate) enum State { + NotInitialized, + // The Handshake struct sits behind an option only so that we can .take() + // it out, it's never actually empty when in State::Handshake. + Handshake(Option), + SecretStream(Option), + Established, +} + +impl fmt::Debug for State { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + State::NotInitialized => write!(f, "NotInitialized"), + State::Handshake(_) => write!(f, "Handshaking"), + State::SecretStream(_) => write!(f, "SecretStream"), + State::Established => write!(f, "Established"), + } + } +} + +/// A Protocol stream. +pub struct Protocol { + write_state: WriteState, + read_state: ReadState, + io: IO, + state: State, + options: Options, + handshake: Option, + channels: ChannelMap, + command_rx: Receiver, + command_tx: CommandTx, + outbound_rx: Receiver>, + outbound_tx: Sender>, + keepalive: Delay, + queued_events: VecDeque, +} + +impl std::fmt::Debug for Protocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Protocol") + .field("write_state", &self.write_state) + .field("read_state", &self.read_state) + //.field("io", &self.io) + .field("state", &self.state) + .field("options", &self.options) + .field("handshake", &self.handshake) + .field("channels", &self.channels) + .field("command_rx", &self.command_rx) + .field("command_tx", &self.command_tx) + .field("outbound_rx", &self.outbound_rx) + .field("outbound_tx", &self.outbound_tx) + .field("keepalive", &self.keepalive) + .field("queued_events", &self.queued_events) + .finish() + } +} + +impl Protocol +where + IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, +{ + /// Create a new protocol instance. + pub(crate) fn new(io: IO, options: Options) -> Self { + let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP); + let (outbound_tx, outbound_rx): ( + Sender>, + Receiver>, + ) = async_channel::bounded(1); + Protocol { + io, + read_state: ReadState::new(), + write_state: WriteState::new(), + options, + state: State::NotInitialized, + channels: ChannelMap::new(), + handshake: None, + command_rx, + command_tx: CommandTx(command_tx), + outbound_tx, + outbound_rx, + keepalive: Delay::new(Duration::from_secs(DEFAULT_KEEPALIVE as u64)), + queued_events: VecDeque::new(), + } + } + + /// Whether this protocol stream initiated the underlying IO connection. + pub fn is_initiator(&self) -> bool { + self.options.is_initiator + } + + /// Get your own Noise public key. + /// + /// Empty before the handshake completed. + pub fn public_key(&self) -> Option<&[u8]> { + match &self.handshake { + None => None, + Some(handshake) => Some(handshake.local_pubkey.as_slice()), + } + } + + /// Get the remote's Noise public key. + /// + /// Empty before the handshake completed. + pub fn remote_public_key(&self) -> Option<&[u8]> { + match &self.handshake { + None => None, + Some(handshake) => Some(handshake.remote_pubkey.as_slice()), + } + } + + /// Get a sender to send commands. + pub fn commands(&self) -> CommandTx { + self.command_tx.clone() + } + + /// Give a command to the protocol. + pub async fn command(&mut self, command: Command) -> Result<()> { + self.command_tx.send(command).await + } + + /// Open a new protocol channel. + /// + /// Once the other side proofed that it also knows the `key`, the channel is emitted as + /// `Event::Channel` on the protocol event stream. + pub async fn open(&mut self, key: Key) -> Result<()> { + self.command_tx.open(key).await + } + + /// Iterator of all currently opened channels. + pub fn channels(&self) -> impl Iterator { + self.channels.iter().map(|c| c.discovery_key()) + } + + /// Stop the protocol and return the inner reader and writer. + pub fn release(self) -> IO { + self.io + } + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if let State::NotInitialized = this.state { + return_error!(this.init()); + } + + // Drain queued events first. + if let Some(event) = this.queued_events.pop_front() { + return Poll::Ready(Ok(event)); + } + + // Read and process incoming messages. + return_error!(this.poll_inbound_read(cx)); + + if let State::Established = this.state { + // Check for commands, but only once the connection is established. + return_error!(this.poll_commands(cx)); + } + + // Poll the keepalive timer. + this.poll_keepalive(cx); + + // Write everything we can write. + return_error!(this.poll_outbound_write(cx)); + + // Check if any events are enqueued. + if let Some(event) = this.queued_events.pop_front() { + Poll::Ready(Ok(event)) + } else { + Poll::Pending + } + } + + fn init(&mut self) -> Result<()> { + trace!( + "protocol Init, state {:?}, options {:?}", + self.state, + self.options + ); + match self.state { + State::NotInitialized => {} + _ => return Ok(()), + }; + + self.state = if self.options.noise { + let mut handshake = Handshake::new(self.options.is_initiator)?; + // If the handshake start returns a buffer, send it now. + if let Some(buf) = handshake.start()? { + // TODO what if this fails? or returns false + self.queue_frame_direct(buf.to_vec()).unwrap(); + } + self.read_state.set_frame_type(FrameType::Raw); + State::Handshake(Some(handshake)) + } else { + self.read_state.set_frame_type(FrameType::Message); + State::Established + }; + + Ok(()) + } + + /// Poll commands. + fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { + while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { + self.on_command(command)?; + } + Ok(()) + } + + /// Poll the keepalive timer and queue a ping message if needed. + fn poll_keepalive(&mut self, cx: &mut Context<'_>) { + if Pin::new(&mut self.keepalive).poll(cx).is_ready() { + if let State::Established = self.state { + // 24 bit header for the empty message, hence the 3 + self.write_state + .queue_frame(Frame::RawBatch(vec![vec![0u8; 3]])); + } + self.keepalive.reset(KEEPALIVE_DURATION); + } + } + + fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool { + // If message is close, close the local channel. + if let ChannelMessage { + channel, + message: Message::Close(_), + .. + } = message + { + self.close_local(*channel); + // If message is a LocalSignal, emit an event and return false to indicate + // this message should be filtered out. + } else if let ChannelMessage { + message: Message::LocalSignal((name, data)), + .. + } = message + { + self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec()))); + return false; + } + true + } + + /// Poll for inbound messages and processs them. + fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { + loop { + let msg = self.read_state.poll_reader(cx, &mut self.io); + match msg { + Poll::Ready(Ok(message)) => { + self.on_inbound_frame(message)?; + } + Poll::Ready(Err(e)) => return Err(e), + Poll::Pending => return Ok(()), + } + } + } + + /// Poll for outbound messages and write them. + fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { + loop { + if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { + return Err(e); + } + // if no parking or setup in progress + if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { + return Ok(()); + } + + match Pin::new(&mut self.outbound_rx).poll_next(cx) { + Poll::Ready(Some(mut messages)) => { + if !messages.is_empty() { + messages.retain(|message| self.on_outbound_message(message)); + if !messages.is_empty() { + let frame = Frame::MessageBatch(messages); + self.write_state.park_frame(frame); + } + } + } + Poll::Ready(None) => unreachable!("Channel closed before end"), + Poll::Pending => return Ok(()), + } + } + } + + fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { + match frame { + Frame::RawBatch(raw_batch) => { + let mut processed_state: Option = None; + for buf in raw_batch { + let state_name: String = format!("{:?}", self.state); + match self.state { + State::Handshake(_) => self.on_handshake_message(buf)?, + State::SecretStream(_) => self.on_secret_stream_message(buf)?, + State::Established => { + if let Some(processed_state) = processed_state.as_ref() { + // last state before established + let previous_state = if self.options.encrypted { + // was SecretStream if we're encrypted + State::SecretStream(None) + } else { + // or wa hasdshake if we're not encrypted + State::Handshake(None) + }; + + // if htis raw_batch included regular messages (not handshake) + // after handshake stuff + if processed_state == &format!("{previous_state:?}") { + // This is the unlucky case where the batch had two or more messages where + // the first one was correctly identified as Raw but everything + // after that should have been (decrypted and) a MessageBatch. Correct the mistake + // here post-hoc. + let buf = self.read_state.decrypt_buf(&buf)?; + let frame = Frame::decode(&buf, &FrameType::Message)?; + self.on_inbound_frame(frame)?; + continue; + } + } + unreachable!( + "May not receive raw frames in Established state" + ) + } + _ => unreachable!( + "May not receive raw frames outside of handshake or secretstream state, was {:?}", + self.state + ), + }; + if processed_state.is_none() { + processed_state = Some(state_name) + } + } + Ok(()) + } + Frame::MessageBatch(channel_messages) => match self.state { + State::Established => { + for channel_message in channel_messages { + self.on_inbound_message(channel_message)? + } + Ok(()) + } + _ => unreachable!("May not receive message batch frames when not established"), + }, + } + } + + fn on_handshake_message(&mut self, buf: Vec) -> Result<()> { + let mut handshake = match &mut self.state { + State::Handshake(handshake) => handshake.take().unwrap(), + _ => unreachable!("May not call on_handshake_message when not in Handshake state"), + }; + + if let Some(response_buf) = handshake.read(&buf)? { + self.queue_frame_direct(response_buf.to_vec()).unwrap(); + } + + if !handshake.complete() { + self.state = State::Handshake(Some(handshake)); + } else { + let handshake_result = handshake.into_result()?; + + if self.options.encrypted { + // The cipher will be put to use to the writer only after the peer's answer has come + let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; + self.state = State::SecretStream(Some(cipher)); + + // Send the secret stream init message header to the other side + self.queue_frame_direct(init_msg).unwrap(); + } else { + // Skip secret stream and go straight to Established, then notify about + // handshake + self.read_state.set_frame_type(FrameType::Message); + let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; + self.queue_event(Event::Handshake(remote_public_key)); + self.state = State::Established; + } + // Store handshake result + self.handshake = Some(handshake_result.clone()); + } + Ok(()) + } + + fn on_secret_stream_message(&mut self, buf: Vec) -> Result<()> { + let encrypt_cipher = match &mut self.state { + State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(), + _ => { + unreachable!("May not call on_secret_stream_message when not in SecretStream state") + } + }; + let handshake_result = &self + .handshake + .as_ref() + .expect("Handshake result must be set before secret stream"); + let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?; + self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher); + self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher); + self.read_state.set_frame_type(FrameType::Message); + + // Lastly notify that handshake is ready and set state to established + let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; + self.queue_event(Event::Handshake(remote_public_key)); + self.state = State::Established; + Ok(()) + } + + fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { + // let channel_message = ChannelMessage::decode(buf)?; + let (remote_id, message) = channel_message.into_split(); + match message { + Message::Open(msg) => self.on_open(remote_id, msg)?, + Message::Close(msg) => self.on_close(remote_id, msg)?, + _ => self + .channels + .forward_inbound_message(remote_id as usize, message)?, + } + Ok(()) + } + + fn on_command(&mut self, command: Command) -> Result<()> { + match command { + Command::Open(key) => self.command_open(key), + Command::Close(discovery_key) => self.command_close(discovery_key), + Command::SignalLocal((name, data)) => self.command_signal_local(name, data), + } + } + + /// Open a Channel with the given key. Adding it to our channel map + fn command_open(&mut self, key: Key) -> Result<()> { + // Create a new channel. + let channel_handle = self.channels.attach_local(key); + // Safe because attach_local always puts Some(local_id) + let local_id = channel_handle.local_id().unwrap(); + let discovery_key = *channel_handle.discovery_key(); + + // If the channel was already opened from the remote end, verify, and if + // verification is ok, push a channel open event. + if channel_handle.is_connected() { + self.accept_channel(local_id)?; + } + + // Tell the remote end about the new channel. + let capability = self.capability(&key); + let channel = local_id as u64; + let message = Message::Open(Open { + channel, + protocol: PROTOCOL_NAME.to_string(), + discovery_key: discovery_key.to_vec(), + capability, + }); + let channel_message = ChannelMessage::new(channel, message); + self.write_state + .queue_frame(Frame::MessageBatch(vec![channel_message])); + Ok(()) + } + + fn command_close(&mut self, discovery_key: DiscoveryKey) -> Result<()> { + if self.channels.has_channel(&discovery_key) { + self.channels.remove(&discovery_key); + self.queue_event(Event::Close(discovery_key)); + } + Ok(()) + } + + fn command_signal_local(&mut self, name: String, data: Vec) -> Result<()> { + self.queue_event(Event::LocalSignal((name, data))); + Ok(()) + } + + fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> { + let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?; + let channel_handle = + self.channels + .attach_remote(discovery_key, ch as usize, msg.capability); + + if channel_handle.is_connected() { + let local_id = channel_handle.local_id().unwrap(); + self.accept_channel(local_id)?; + } else { + self.queue_event(Event::DiscoveryKey(discovery_key)); + } + + Ok(()) + } + + fn queue_event(&mut self, event: Event) { + self.queued_events.push_back(event); + } + + /// enequeu a buf to be sent + fn queue_frame_direct(&mut self, body: Vec) -> Result { + let mut frame = Frame::RawBatch(vec![body]); + self.write_state + .try_encode_and_enqueue_frame_for_tx(&mut frame) + } + + fn accept_channel(&mut self, local_id: usize) -> Result<()> { + let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; + self.verify_remote_capability(remote_capability.cloned(), key)?; + let channel = self.channels.accept(local_id, self.outbound_tx.clone())?; + self.queue_event(Event::Channel(channel)); + Ok(()) + } + + fn close_local(&mut self, local_id: u64) { + if let Some(channel) = self.channels.get_local(local_id as usize) { + let discovery_key = *channel.discovery_key(); + self.channels.remove(&discovery_key); + self.queue_event(Event::Close(discovery_key)); + } + } + + fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> { + if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) { + let discovery_key = *channel_handle.discovery_key(); + // There is a possibility both sides will close at the same time, so + // the channel could be closed already, let's tolerate that. + self.channels + .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?; + self.channels.remove(&discovery_key); + self.queue_event(Event::Close(discovery_key)); + } + Ok(()) + } + + fn capability(&self, key: &[u8]) -> Option> { + match self.handshake.as_ref() { + Some(handshake) => handshake.capability(key), + None => None, + } + } + + fn verify_remote_capability(&self, capability: Option>, key: &[u8]) -> Result<()> { + match self.handshake.as_ref() { + Some(handshake) => handshake.verify_remote_capability(capability, key), + None => Err(Error::new( + ErrorKind::PermissionDenied, + "Missing handshake state for capability verification", + )), + } + } +} + +impl Stream for Protocol +where + IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, +{ + type Item = Result; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Protocol::poll_next(self, cx).map(Some) + } +} + +/// Send [Command](Command)s to the [Protocol](Protocol). +#[derive(Clone, Debug)] +pub struct CommandTx(Sender); + +impl CommandTx { + /// Send a protocol command + pub async fn send(&mut self, command: Command) -> Result<()> { + self.0.send(command).await.map_err(map_channel_err) + } + /// Open a protocol channel. + /// + /// The channel will be emitted on the main protocol. + pub async fn open(&mut self, key: Key) -> Result<()> { + self.send(Command::Open(key)).await + } + + /// Close a protocol channel. + pub async fn close(&mut self, discovery_key: DiscoveryKey) -> Result<()> { + self.send(Command::Close(discovery_key)).await + } + + /// Send a local signal event to the protocol. + pub async fn signal_local(&mut self, name: &str, data: Vec) -> Result<()> { + self.send(Command::SignalLocal((name.to_string(), data))) + .await + } +} + +fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> { + key.try_into() + .map_err(|_e| io::Error::new(io::ErrorKind::InvalidInput, "Key must be 32 bytes long")) +} From e7e7dd7f619afd761f836de64d3ffdd428165074 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 13:53:17 -0400 Subject: [PATCH 043/206] fix spelling --- src/lib.rs | 2 +- src/noise.rs | 2 +- src/protocol/modern.rs | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e4c0744..88b3d32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -138,7 +138,7 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; pub use framing::Uint24LELengthPrefixedFraming; -pub use noise::{encyrpted_framed_message_channel, Encrypted}; +pub use noise::{encrypted_framed_message_channel, Encrypted}; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, diff --git a/src/noise.rs b/src/noise.rs index c2269ea..3dbf660 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -15,7 +15,7 @@ use crate::{ }; /// Create a framed and encrypted Stream/Sink that reads/writes to an AsyncRead/AsyncWrite. -pub fn encyrpted_framed_message_channel( +pub fn encrypted_framed_message_channel( is_initiator: bool, io: IO, ) -> Encrypted> { diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index 930f9bd..6e599da 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -17,9 +17,11 @@ use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; use crate::message::{ChannelMessage, Frame, FrameType, Message}; use crate::reader::ReadState; -use crate::schema::*; use crate::util::{map_channel_err, pretty_hash}; use crate::writer::WriteState; +use crate::{ + encrypted_framed_message_channel, schema::*, Encrypted, Uint24LELengthPrefixedFraming, +}; macro_rules! return_error { ($msg:expr) => { From bb390b463d5077dccb2b7add5caa5ffb20c94c8d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 14:03:48 -0400 Subject: [PATCH 044/206] WIP Drop in encrypted channel compiles, fixed some errors, 'unreachable_code' warnings bc of todo!()'s --- benches/throughput.rs | 7 +++---- src/protocol/modern.rs | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/benches/throughput.rs b/benches/throughput.rs index 76d6874..7f9890d 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -4,7 +4,7 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use futures::future::Either; use futures::io::{AsyncRead, AsyncWrite}; use futures::stream::{FuturesUnordered, StreamExt}; -use hypercore_protocol::{schema::*, Duplex}; +use hypercore_protocol::schema::*; use hypercore_protocol::{Channel, Event, Message, ProtocolBuilder}; use log::*; use std::time::Instant; @@ -88,7 +88,7 @@ async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { kill_tx } -async fn onconnection(reader: R, writer: W, is_initiator: bool) -> Duplex +async fn onconnection(reader: R, writer: W, is_initiator: bool) where R: AsyncRead + Send + Unpin + 'static, W: AsyncWrite + Send + Unpin + 'static, @@ -108,12 +108,11 @@ where task::spawn(onchannel(channel, is_initiator)); } Event::Close(_dkey) => { - return protocol.release(); + return; } _ => {} } } - protocol.release() } async fn onchannel(mut channel: Channel, is_initiator: bool) { diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index 6e599da..e0e3c3a 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -140,7 +140,7 @@ impl fmt::Debug for State { pub struct Protocol { write_state: WriteState, read_state: ReadState, - io: IO, + io: Encrypted>, state: State, options: Options, handshake: Option, @@ -184,8 +184,9 @@ where Sender>, Receiver>, ) = async_channel::bounded(1); + Protocol { - io, + io: encrypted_framed_message_channel(options.is_initiator, io), read_state: ReadState::new(), write_state: WriteState::new(), options, @@ -250,7 +251,7 @@ where } /// Stop the protocol and return the inner reader and writer. - pub fn release(self) -> IO { + pub fn release(self) -> Encrypted> { self.io } @@ -359,10 +360,10 @@ where } /// Poll for inbound messages and processs them. - fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { + fn poll_inbound_read(&mut self, _cx: &mut Context<'_>) -> Result<()> { loop { - let msg = self.read_state.poll_reader(cx, &mut self.io); - match msg { + //let msg = self.read_state.poll_reader(cx, &mut self.io); + match todo!() { Poll::Ready(Ok(message)) => { self.on_inbound_frame(message)?; } @@ -375,9 +376,9 @@ where /// Poll for outbound messages and write them. fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { - if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { - return Err(e); - } + //if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { + // return Err(e); + //} // if no parking or setup in progress if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { return Ok(()); From df1e63405aaeb6f3ce58e1276c8c7678bfbf6f76 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 15:24:16 -0400 Subject: [PATCH 045/206] make Encoder trait immutable self --- src/message.rs | 82 +++++++++++++++++++++++--------------------------- 1 file changed, 38 insertions(+), 44 deletions(-) diff --git a/src/message.rs b/src/message.rs index 27b74c1..832a5f4 100644 --- a/src/message.rs +++ b/src/message.rs @@ -20,20 +20,20 @@ pub(crate) enum FrameType { /// (channel messages, messages, and individual message types through prost). pub(crate) trait Encoder: Sized + fmt::Debug { /// Calculates the length that the encoded message needs. - fn encoded_len(&mut self) -> Result; + fn encoded_len(&self) -> Result; /// Encodes the message to a buffer. /// /// An error will be returned if the buffer does not have sufficient capacity. - fn encode(&mut self, buf: &mut [u8]) -> Result; + fn encode(&self, buf: &mut [u8]) -> Result; } impl Encoder for &[u8] { - fn encoded_len(&mut self) -> Result { + fn encoded_len(&self) -> Result { Ok(self.len()) } - fn encode(&mut self, buf: &mut [u8]) -> Result { + fn encode(&self, buf: &mut [u8]) -> Result { let len = self.encoded_len()?; if len > buf.len() { return Err(EncodingError::new( @@ -232,7 +232,7 @@ impl Frame { } } - fn preencode(&mut self, state: &mut State) -> Result { + fn preencode(&self, state: &mut State) -> Result { match self { Self::RawBatch(raw_batch) => { for raw in raw_batch { @@ -257,7 +257,7 @@ impl Frame { state.add_end(2)?; let mut current_channel: u64 = messages[0].channel; state.preencode(¤t_channel)?; - for message in messages.iter_mut() { + for message in messages.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel @@ -277,7 +277,7 @@ impl Frame { } impl Encoder for Frame { - fn encoded_len(&mut self) -> Result { + fn encoded_len(&self) -> Result { let body_len = self.preencode(&mut State::new())?; match self { Self::RawBatch(_) => Ok(body_len), @@ -285,7 +285,7 @@ impl Encoder for Frame { } } - fn encode(&mut self, buf: &mut [u8]) -> Result { + fn encode(&self, buf: &mut [u8]) -> Result { let mut state = State::new(); let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; let body_len = self.preencode(&mut state)?; @@ -303,7 +303,7 @@ impl Encoder for Frame { } } #[allow(clippy::comparison_chain)] - Self::MessageBatch(ref mut messages) => { + Self::MessageBatch(ref messages) => { write_uint24_le(body_len, buf); let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); if messages.len() == 1 { @@ -326,7 +326,7 @@ impl Encoder for Frame { state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; let mut current_channel: u64 = messages[0].channel; state.encode(¤t_channel, buf)?; - for message in messages.iter_mut() { + for message in messages.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel @@ -582,54 +582,48 @@ impl ChannelMessage { /// Performance optimization for letting calling encoded_len() already do /// the preencode phase of compact_encoding. - fn prepare_state(&mut self) -> Result<(), EncodingError> { - if self.state.is_none() { - let state = if let Message::Open(_) = self.message { - // Open message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else if let Message::Close(_) = self.message { - // Close message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else { - // The header is the channel id uint followed by message type uint - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - let mut state = HypercoreState::new(); - let typ = self.message.typ(); - (*state).preencode(&typ)?; - self.message.preencode(&mut state)?; - state - }; - self.state = Some(state); - } - Ok(()) + fn prepare_state(&self) -> Result { + Ok(if let Message::Open(_) = self.message { + // Open message doesn't have a type + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 + let mut state = HypercoreState::new(); + self.message.preencode(&mut state)?; + state + } else if let Message::Close(_) = self.message { + // Close message doesn't have a type + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 + let mut state = HypercoreState::new(); + self.message.preencode(&mut state)?; + state + } else { + // The header is the channel id uint followed by message type uint + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 + let mut state = HypercoreState::new(); + let typ = self.message.typ(); + (*state).preencode(&typ)?; + self.message.preencode(&mut state)?; + state + }) } } impl Encoder for ChannelMessage { - fn encoded_len(&mut self) -> Result { - self.prepare_state()?; - Ok(self.state.as_ref().unwrap().end()) + fn encoded_len(&self) -> Result { + Ok(self.prepare_state()?.end()) } fn encode(&mut self, buf: &mut [u8]) -> Result { - self.prepare_state()?; - let state = self.state.as_mut().unwrap(); + let mut state = self.prepare_state()?; if let Message::Open(_) = self.message { // Open message is different in that the type byte is missing - self.message.encode(state, buf)?; + self.message.encode(&mut state, buf)?; } else if let Message::Close(_) = self.message { // Close message is different in that the type byte is missing - self.message.encode(state, buf)?; + self.message.encode(&mut state, buf)?; } else { let typ = self.message.typ(); state.0.encode(&typ, buf)?; - self.message.encode(state, buf)?; + self.message.encode(&mut state, buf)?; } Ok(state.start()) } From a860221b79353dedbae00864bcb900f7bf368136 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 15:53:57 -0400 Subject: [PATCH 046/206] rm unused State from ChannelMessage --- src/message.rs | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/message.rs b/src/message.rs index 832a5f4..78ccff5 100644 --- a/src/message.rs +++ b/src/message.rs @@ -479,7 +479,6 @@ impl fmt::Display for Message { pub(crate) struct ChannelMessage { pub(crate) channel: u64, pub(crate) message: Message, - state: Option, } impl PartialEq for ChannelMessage { @@ -497,11 +496,7 @@ impl fmt::Debug for ChannelMessage { impl ChannelMessage { /// Create a new message. pub(crate) fn new(channel: u64, message: Message) -> Self { - Self { - channel, - message, - state: None, - } + Self { channel, message } } /// Consume self and return (channel, Message). @@ -527,7 +522,6 @@ impl ChannelMessage { Self { channel: open_msg.channel, message: Message::Open(open_msg), - state: None, }, state.start(), )) @@ -550,7 +544,6 @@ impl ChannelMessage { Self { channel: close_msg.channel, message: Message::Close(close_msg), - state: None, }, state.start(), )) @@ -570,14 +563,7 @@ impl ChannelMessage { let mut state = State::from_buffer(buf); let typ: u64 = state.decode(buf)?; let (message, length) = Message::decode(&buf[state.start()..], typ)?; - Ok(( - Self { - channel, - message, - state: None, - }, - state.start() + length, - )) + Ok((Self { channel, message }, state.start() + length)) } /// Performance optimization for letting calling encoded_len() already do From d6e1e72e4331f65d84373a5127b4025d42030862 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 15:54:12 -0400 Subject: [PATCH 047/206] split out Vec encoding --- src/message.rs | 161 +++++++++++++++++++++++++++++-------------------- 1 file changed, 95 insertions(+), 66 deletions(-) diff --git a/src/message.rs b/src/message.rs index 78ccff5..a807bf9 100644 --- a/src/message.rs +++ b/src/message.rs @@ -239,37 +239,8 @@ impl Frame { state.add_end(raw.as_slice().encoded_len()?)?; } } - #[allow(clippy::comparison_chain)] Self::MessageBatch(messages) => { - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else { - (*state).preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; - let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.add_end(1)?; - state.preencode(&message.channel)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; - } - } + state.add_end(messages.encoded_len()?)?; } } Ok(state.end()) @@ -302,46 +273,104 @@ impl Encoder for Frame { raw.as_slice().encode(buf)?; } } - #[allow(clippy::comparison_chain)] Self::MessageBatch(ref messages) => { - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes + messages.encode(buf)?; + } + }; + Ok(len) + } +} + +fn prencode_channel_messages( + messages: &[ChannelMessage], + state: &mut State, +) -> Result { + match messages.len().cmp(&1) { + std::cmp::Ordering::Less => {} + std::cmp::Ordering::Equal => { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else { + state.preencode(&messages[0].channel)?; + state.add_end(messages[0].encoded_len()?)?; + } + } + std::cmp::Ordering::Greater => { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.add_end(2)?; + let mut current_channel: u64 = messages[0].channel; + state.preencode(¤t_channel)?; + for message in messages.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.add_end(1)?; + state.preencode(&message.channel)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.preencode(&message_length)?; + state.add_end(message_length)?; + } + } + }; + Ok(state.end()) +} + +impl Encoder for Vec { + fn encoded_len(&self) -> Result { + let mut state = State::new(); + prencode_channel_messages(self, &mut state) + } + + fn encode(&self, buf: &mut [u8]) -> Result { + const HEADER_LEN: usize = 3; + let mut state = State::new(); + let body_len = prencode_channel_messages(self, &mut state)?; + write_uint24_le(body_len, buf); + let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + match self.len().cmp(&1) { + std::cmp::Ordering::Less => {} + std::cmp::Ordering::Equal => { + if let Message::Open(_) = &self[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(1_u8), buf)?; + state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + } else if let Message::Close(_) = &self[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(3_u8), buf)?; + state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + } else { + state.encode(&self[0].channel, buf)?; + state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + } + } + std::cmp::Ordering::Greater => { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + let mut current_channel: u64 = self[0].channel; + state.encode(¤t_channel, buf)?; + for message in self.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else { - state.encode(&messages[0].channel, buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; - let mut current_channel: u64 = messages[0].channel; - state.encode(¤t_channel, buf)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; + state.encode(&message.channel, buf)?; + current_channel = message.channel; } + let message_length = message.encoded_len()?; + state.encode(&message_length, buf)?; + state.add_start(message.encode(&mut buf[state.start()..])?)?; } } - }; - Ok(len) + } + Ok(HEADER_LEN + body_len) } } @@ -598,7 +627,7 @@ impl Encoder for ChannelMessage { Ok(self.prepare_state()?.end()) } - fn encode(&mut self, buf: &mut [u8]) -> Result { + fn encode(&self, buf: &mut [u8]) -> Result { let mut state = self.prepare_state()?; if let Message::Open(_) = self.message { // Open message is different in that the type byte is missing From 7dbb12574ff6566b24ebb079c700fe6b9bd937f1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 18 Mar 2025 18:15:43 -0400 Subject: [PATCH 048/206] make integration tests use tokio --- tests/_util.rs | 58 +++++++++++++++++++------------------------------- tests/basic.rs | 46 ++++++++++++++++++--------------------- 2 files changed, 43 insertions(+), 61 deletions(-) diff --git a/tests/_util.rs b/tests/_util.rs index 9d0f9bf..3064c08 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -1,10 +1,27 @@ use async_std::net::TcpStream; -use async_std::prelude::*; -use async_std::task::{self, JoinHandle}; use futures_lite::io::{AsyncRead, AsyncWrite}; +use futures_lite::StreamExt; use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; use instant::Duration; +use std::future::Future; use std::io; +use tokio::task::JoinHandle; + +pub(crate) fn log() { + use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); + START_LOGS.get_or_init(|| { + tracing_subscriber::fmt() + .with_target(true) + .with_line_number(true) + // print when instrumented funtion enters + .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) + .with_file(true) + .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable + .without_time() + .init(); + }); +} pub type MemoryProtocol = Protocol>; pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol)> { @@ -18,21 +35,11 @@ pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol) Ok((a, b)) } -pub type TcpProtocol = Protocol; -pub async fn create_pair_tcp() -> io::Result<(TcpProtocol, TcpProtocol)> { - let (stream_a, stream_b) = tcp::pair().await?; - let a = ProtocolBuilder::new(true).connect(stream_a); - let b = ProtocolBuilder::new(false).connect(stream_b); - Ok((a, b)) -} - -pub fn next_event( - mut proto: Protocol, -) -> impl Future, io::Result)> +pub fn next_event(mut proto: Protocol) -> JoinHandle<(Protocol, io::Result)> where IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - task::spawn(async move { + tokio::task::spawn(async move { let e1 = proto.next().await; let e1 = e1.unwrap(); (proto, e1) @@ -62,7 +69,7 @@ pub fn drive_until_channel( where IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - task::spawn(async move { + tokio::task::spawn(async move { while let Some(event) = proto.next().await { let event = event?; if let Event::Channel(channel) = event { @@ -76,27 +83,6 @@ where }) } -pub mod tcp { - use async_std::net::{TcpListener, TcpStream}; - use async_std::prelude::*; - use async_std::task; - use std::io::{Error, ErrorKind, Result}; - pub async fn pair() -> Result<(TcpStream, TcpStream)> { - let address = "localhost:9999"; - let listener = TcpListener::bind(&address).await?; - let mut incoming = listener.incoming(); - - let connect_task = task::spawn(async move { TcpStream::connect(&address).await }); - - let server_stream = incoming.next().await; - let server_stream = - server_stream.ok_or_else(|| Error::new(ErrorKind::Other, "Stream closed"))?; - let server_stream = server_stream?; - let client_stream = connect_task.await?; - Ok((server_stream, client_stream)) - } -} - const RETRY_TIMEOUT: u64 = 100_u64; const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; pub async fn wait_for_localhost_port(port: u32) { diff --git a/tests/basic.rs b/tests/basic.rs index 8a99c7e..062cf35 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,26 +1,22 @@ -#![allow(dead_code, unused_imports)] - -use async_std::net::TcpStream; -use async_std::prelude::*; -use async_std::task; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use hypercore_protocol::{discovery_key, Channel, Event, Message, Protocol, ProtocolBuilder}; +use _util::{ + create_pair_memory, drive_until_channel, event_channel, event_discovery_key, next_event, +}; +use futures_lite::StreamExt; +use hypercore_protocol::{discovery_key, Event, Message}; use hypercore_protocol::{schema::*, DiscoveryKey}; use std::io; -use test_log::test; +use tokio::task; mod _util; -use _util::*; -#[test(async_std::test)] +#[tokio::test] async fn basic_protocol() -> anyhow::Result<()> { - // env_logger::init(); let (proto_a, proto_b) = create_pair_memory().await?; let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_a, event_a) = next_a.await; - let (proto_b, event_b) = next_b.await; + let (mut proto_a, event_a) = next_a.await?; + let (proto_b, event_b) = next_b.await?; assert!(matches!(event_a, Ok(Event::Handshake(_)))); assert!(matches!(event_b, Ok(Event::Handshake(_)))); @@ -35,18 +31,18 @@ async fn basic_protocol() -> anyhow::Result<()> { let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_b, event_b) = next_b.await; + let (mut proto_b, event_b) = next_b.await?; assert!(matches!(event_b, Ok(Event::DiscoveryKey(_)))); assert_eq!(event_discovery_key(event_b.unwrap()), discovery_key(&key)); proto_b.open(key).await?; let next_b = next_event(proto_b); - let (proto_b, event_b) = next_b.await; + let (proto_b, event_b) = next_b.await?; assert!(matches!(event_b, Ok(Event::Channel(_)))); let mut channel_b = event_channel(event_b.unwrap()); - let (proto_a, event_a) = next_a.await; + let (proto_a, event_a) = next_a.await?; assert!(matches!(event_a, Ok(Event::Channel(_)))); let mut channel_a = event_channel(event_a.unwrap()); @@ -68,8 +64,8 @@ async fn basic_protocol() -> anyhow::Result<()> { channel_a.close().await?; - let (_, event_a) = next_a.await; - let (_, event_b) = next_b.await; + let (_, event_a) = next_a.await?; + let (_, event_b) = next_b.await?; assert!(matches!(event_a, Ok(Event::Close(_)))); assert!(matches!(event_b, Ok(Event::Close(_)))); @@ -78,7 +74,7 @@ async fn basic_protocol() -> anyhow::Result<()> { Ok(()) } -#[test(async_std::test)] +#[tokio::test] async fn open_close_channels() -> anyhow::Result<()> { let (mut proto_a, mut proto_b) = create_pair_memory().await?; @@ -91,8 +87,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = drive_until_channel(proto_a); let next_b = drive_until_channel(proto_b); - let (mut proto_a, mut channel_a1) = next_a.await?; - let (mut proto_b, mut channel_b1) = next_b.await?; + let (mut proto_a, mut channel_a1) = next_a.await??; + let (mut proto_b, mut channel_b1) = next_b.await??; proto_a.open(key2).await?; proto_b.open(key2).await?; @@ -100,8 +96,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = drive_until_channel(proto_a); let next_b = drive_until_channel(proto_b); - let (proto_a, mut channel_a2) = next_a.await?; - let (proto_b, mut channel_b2) = next_b.await?; + let (proto_a, mut channel_a2) = next_a.await??; + let (proto_b, mut channel_b2) = next_b.await??; eprintln!( "got channels: {:?}", @@ -119,8 +115,8 @@ async fn open_close_channels() -> anyhow::Result<()> { let next_a = next_event(proto_a); let next_b = next_event(proto_b); - let (mut proto_a, ev_a) = next_a.await; - let (mut proto_b, ev_b) = next_b.await; + let (mut proto_a, ev_a) = next_a.await?; + let (mut proto_b, ev_b) = next_b.await?; let ev_a = ev_a?; let ev_b = ev_b?; eprintln!("next a: {ev_a:?}"); From bbad25b265dc08598bdf4aab3c9acbc014ee370e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 19 Mar 2025 11:19:10 -0400 Subject: [PATCH 049/206] RESETME --- Cargo.toml | 7 +- src/framing.rs | 3 - src/lib.rs | 1 + src/message.rs | 148 ++++++++++++++------------- src/mqueue.rs | 208 +++++++++++++++++++++++++++++++++++++ src/noise.rs | 3 +- src/protocol/modern.rs | 226 +++++++---------------------------------- src/protocol/old.rs | 1 + src/schema.rs | 9 +- src/test_utils.rs | 5 +- tests/_util.rs | 7 +- tests/basic.rs | 11 +- 12 files changed, 349 insertions(+), 280 deletions(-) create mode 100644 src/mqueue.rs diff --git a/Cargo.toml b/Cargo.toml index 82790eb..7c15c8b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,8 +42,9 @@ crypto_secretstream = "0.2" futures = "0.3.31" [dependencies.hypercore] -version = "0.14.0" -default-features = false +path = "../core" +#version = "0.14.0" +#default-features = false [dev-dependencies] @@ -64,7 +65,7 @@ tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } tokio-util = { version = "0.7.14", features = ["compat"] } [features] -default = ["tokio", "sparse"] +default = ["tokio", "sparse", "protocol"] protocol = [] wasm-bindgen = [ "futures-timer/wasm-bindgen" diff --git a/src/framing.rs b/src/framing.rs index e7c9c5c..ce4d7ba 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -223,7 +223,6 @@ pub(crate) mod test { let mut lp = Uint24LELengthPrefixedFraming::new(left); let input = b"yelp"; let msg = wrap_uint24_le(input); - dbg!(&msg); right.write_all(&msg).await?; let Some(Ok(rx)) = lp.next().await else { panic!() @@ -242,11 +241,9 @@ pub(crate) mod test { right.write_all(&msg).await?; } for d in data { - dbg!(); let Some(Ok(res)) = lp.next().await else { panic!(); }; - dbg!(&res); assert_eq!(&res, d); } Ok(()) diff --git a/src/lib.rs b/src/lib.rs index 88b3d32..9fca95e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,6 +124,7 @@ mod crypto; mod duplex; mod framing; mod message; +mod mqueue; mod noise; mod protocol; mod reader; diff --git a/src/message.rs b/src/message.rs index a807bf9..869baf0 100644 --- a/src/message.rs +++ b/src/message.rs @@ -6,6 +6,7 @@ use hypercore::encoding::{ use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; +use tracing::instrument; /// The type of a data frame. #[derive(Debug, Clone, PartialEq)] @@ -76,6 +77,79 @@ impl From> for Frame { } } +pub(crate) fn decode_channel_messages( + buf: &[u8], +) -> Result<(Vec, usize), io::Error> { + if buf.len() >= 3 && buf[0] == 0x00 { + if buf[1] == 0x00 { + // Batch of messages + dbg!(); + let mut messages: Vec = vec![]; + let mut state = State::new_with_start_and_end(2, buf.len()); + + // First, there is the original channel + let mut current_channel: u64 = state.decode(buf)?; + while state.start() < state.end() { + // Length of the message is inbetween here + let channel_message_length: usize = state.decode(buf)?; + if state.start() + channel_message_length > state.end() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length, {} + {} > {}", + state.start(), + channel_message_length, + state.end() + ), + )); + } + // Then the actual message + let (channel_message, _) = ChannelMessage::decode( + &buf[state.start()..state.start() + channel_message_length], + current_channel, + )?; + messages.push(channel_message); + state.add_start(channel_message_length)?; + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if state.start() < state.end() && buf[state.start()] == 0x00 { + state.add_start(1)?; + current_channel = state.decode(buf)?; + } + } + Ok((messages, state.start())) + } else if buf[1] == 0x01 { + dbg!(); + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((vec![channel_message], length + 2)) + } else if buf[1] == 0x03 { + dbg!(); + // Close message + let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; + Ok((vec![channel_message], length + 2)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid special message", + )) + } + } else if buf.len() >= 2 { + dbg!(); + // Single message + let mut state = State::from_buffer(buf); + let channel: u64 = state.decode(buf)?; + let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; + Ok((vec![channel_message], state.start() + length)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:02X?}"), + )) + } +} + impl Frame { /// Decodes a frame from a buffer containing multiple concurrent messages. pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { @@ -163,73 +237,8 @@ impl Frame { } fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - if buf.len() >= 3 && buf[0] == 0x00 { - if buf[1] == 0x00 { - // Batch of messages - let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); - - // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { - // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() - ), - )); - } - // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; - messages.push(channel_message); - state.add_start(channel_message_length)?; - // After that, if there is an extra 0x00, that means the channel - // changed. This works because of LE encoding, and channels starting - // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; - } - } - Ok((Frame::MessageBatch(messages), state.start())) - } else if buf[1] == 0x01 { - // Open message - let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else if buf[1] == 0x03 { - // Close message - let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid special message", - )) - } - } else if buf.len() >= 2 { - // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok(( - Frame::MessageBatch(vec![channel_message]), - state.start() + length, - )) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), - )) - } + let (channel_messages, bytes_read) = decode_channel_messages(buf)?; + Ok((Frame::MessageBatch(channel_messages), bytes_read)) } fn preencode(&self, state: &mut State) -> Result { @@ -290,7 +299,7 @@ fn prencode_channel_messages( std::cmp::Ordering::Equal => { if let Message::Open(_) = &messages[0].message { // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; + state.add_end(2 + dbg!(&messages[0].encoded_len()?))?; } else if let Message::Close(_) = &messages[0].message { // This is a special case with 0x00, 0x03 intro bytes state.add_end(2 + &messages[0].encoded_len()?)?; @@ -327,6 +336,7 @@ impl Encoder for Vec { prencode_channel_messages(self, &mut state) } + #[instrument] fn encode(&self, buf: &mut [u8]) -> Result { const HEADER_LEN: usize = 3; let mut state = State::new(); @@ -655,7 +665,7 @@ mod tests { ($( $msg:expr ),*) => { $( let channel = rand::random::() as u64; - let mut channel_message = ChannelMessage::new(channel, $msg); + let channel_message = ChannelMessage::new(channel, $msg); let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); let mut buf = vec![0u8; encoded_len]; let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); diff --git a/src/mqueue.rs b/src/mqueue.rs new file mode 100644 index 0000000..b968937 --- /dev/null +++ b/src/mqueue.rs @@ -0,0 +1,208 @@ +//! Interface for reading and writing message to a Stream/Sink + +use std::{ + collections::VecDeque, + io::Result, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{AsyncRead, AsyncWrite, Sink, Stream}; +use tracing::{debug, error, info, instrument, trace}; + +use crate::{ + encrypted_framed_message_channel, + message::{decode_channel_messages, ChannelMessage, Encoder as _}, +}; + +pub(crate) struct MessageIo { + io: IO, + write_queue: VecDeque, +} + +use crate::{framing::Uint24LELengthPrefixedFraming, noise::Encrypted}; + +pub(crate) fn encrypted_and_framed( + is_initiator: bool, + io: BytesTxRx, +) -> MessageIo>> { + let io = encrypted_framed_message_channel(is_initiator, io); + MessageIo { + io, + write_queue: Default::default(), + } +} +impl>> + Sink> + Send + Unpin + 'static> MessageIo { + pub(crate) fn new(io: IO) -> Self { + Self { + io, + write_queue: Default::default(), + } + } + + pub(crate) fn enqueue(&mut self, msg: ChannelMessage) { + self.write_queue.push_back(msg) + } + + #[instrument(skip_all)] + pub(crate) fn poll_outbound(&mut self, cx: &mut Context<'_>) -> Poll> { + let mut pending = true; + // TODO handle error? + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(&mut self.io), cx) { + pending = false; + if self.write_queue.is_empty() { + break; + } + let mut messages = vec![]; + while let Some(msg) = self.write_queue.pop_front() { + messages.push(msg); + } + + let mut buf = vec![0; messages.encoded_len()?]; + dbg!(&buf); + match messages.encode(&mut buf) { + Ok(_) => {} + Err(e) => { + error!(error = ?e, "error encoding messages"); + return Poll::Ready(Err(e.into())); + } + } + if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf) { + error!("error in start_send"); + todo!() + } + + match Sink::poll_flush(Pin::new(&mut self.io), cx) { + Poll::Ready(Ok(())) => { + debug!("flushed"); + } + Poll::Ready(Err(_e)) => { + error!("Error flushing"); + return todo!(); + } + Poll::Pending => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + } + + if pending { + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Poll::Ready(Ok(())) + } + } + + pub(crate) fn poll_inbound( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + match Pin::new(&mut self.io).poll_next(cx) { + Poll::Ready(Some(Ok(encoded))) => { + match decode_channel_messages(&encoded) { + Ok((messsages, n_read)) => { + assert_eq!(n_read, encoded.len()); // I think this is always true + Poll::Ready(Ok(messsages)) + } + Err(_) => todo!(), + } + } + Poll::Ready(Some(Err(_e))) => todo!(), + Poll::Ready(None) => todo!(), + Poll::Pending => Poll::Pending, + } + } +} + +impl>> + Sink> + Send + Unpin + 'static> Stream + for MessageIo +{ + type Item = Result>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let out_res = self.poll_outbound(cx); + match out_res { + Poll::Ready(res) => match res { + Ok(okres) => trace!(res = ?okres, "MessageIo poll_outbound"), + Err(e) => error!(error = ?e, "MessageIo error in poll_outbound"), + }, + Poll::Pending => trace!("MessageIo poll_outbound Pending"), + } + + let in_res = self.poll_inbound(cx); + trace!(poll_inbound = ?in_res, "MessageIo"); + + in_res.map(Some) + } +} + +#[cfg(test)] +mod test { + use std::io::Result; + + use futures::future::{join, select}; + use futures_lite::StreamExt; + + use crate::{ + framing::test::duplex, + message::{decode_channel_messages, ChannelMessage, Encoder as _}, + mqueue::encrypted_and_framed, + schema::{NoData, Open}, + test_utils::log, + }; + fn new_msg(channel: u64) -> ChannelMessage { + ChannelMessage { + channel, + message: crate::Message::NoData(NoData { request: channel }), + } + } + + #[tokio::test] + async fn mqueue() -> Result<()> { + log(); + let m = vec![new_msg(0)]; + let mut buf = vec![0; m.encoded_len()?]; + dbg!(&buf.len()); + dbg!(); + m.encode(&mut buf)?; + dbg!(&buf); + + let res = dbg!(decode_channel_messages(&buf))?; + assert_eq!(vec![new_msg(42402)], res.0); + dbg!(&buf); + + Ok(()) + + /* + let (left, right) = duplex(1024 * 64); + let mut left = encrypted_and_framed(true, left); + let mut right = encrypted_and_framed(false, right); + left.enqueue(new_msg(42)); + right.enqueue(new_msg(38)); + + match select(left.next(), right.next()).await { + futures::future::Either::Left(ll) => { + println!( + "left + + ooooooooooooooooooooo + + " + ); + } + futures::future::Either::Right(rr) => { + println!( + "rightllllllllllllllll + + ------------------------- + + " + ); + } + } + Ok(()) + */ + } +} diff --git a/src/noise.rs b/src/noise.rs index 3dbf660..d2a3a1d 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -22,6 +22,7 @@ pub fn encrypted_framed_message_channel { - dbg!(); let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { trace!("Read in handshake msg\n{msg:?}"); @@ -562,7 +562,6 @@ mod tset { #[tokio::test] async fn with_framing() -> Result<()> { - crate::test_utils::log(); let hello = b"hello".to_vec(); let (left, right) = duplex(1024 * 64); diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index e0e3c3a..cb71f50 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -5,20 +5,18 @@ use futures_timer::Delay; use std::collections::VecDeque; use std::convert::TryInto; use std::fmt; -use std::future::Future; use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tracing::trace; +use tracing::instrument; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; -use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; -use crate::message::{ChannelMessage, Frame, FrameType, Message}; -use crate::reader::ReadState; +use crate::crypto::{EncryptCipher, Handshake, HandshakeResult}; +use crate::message::{ChannelMessage, Frame, Message}; +use crate::mqueue::MessageIo; use crate::util::{map_channel_err, pretty_hash}; -use crate::writer::WriteState; use crate::{ encrypted_framed_message_channel, schema::*, Encrypted, Uint24LELengthPrefixedFraming, }; @@ -138,9 +136,7 @@ impl fmt::Debug for State { /// A Protocol stream. pub struct Protocol { - write_state: WriteState, - read_state: ReadState, - io: Encrypted>, + io: MessageIo>>, state: State, options: Options, handshake: Option, @@ -156,8 +152,6 @@ pub struct Protocol { impl std::fmt::Debug for Protocol { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Protocol") - .field("write_state", &self.write_state) - .field("read_state", &self.read_state) //.field("io", &self.io) .field("state", &self.state) .field("options", &self.options) @@ -186,9 +180,7 @@ where ) = async_channel::bounded(1); Protocol { - io: encrypted_framed_message_channel(options.is_initiator, io), - read_state: ReadState::new(), - write_state: WriteState::new(), + io: MessageIo::new(encrypted_framed_message_channel(options.is_initiator, io)), options, state: State::NotInitialized, channels: ChannelMap::new(), @@ -251,17 +243,14 @@ where } /// Stop the protocol and return the inner reader and writer. - pub fn release(self) -> Encrypted> { + pub fn release(self) -> MessageIo>> { self.io } + #[instrument(skip_all)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - if let State::NotInitialized = this.state { - return_error!(this.init()); - } - // Drain queued events first. if let Some(event) = this.queued_events.pop_front() { return Poll::Ready(Ok(event)); @@ -270,10 +259,8 @@ where // Read and process incoming messages. return_error!(this.poll_inbound_read(cx)); - if let State::Established = this.state { - // Check for commands, but only once the connection is established. - return_error!(this.poll_commands(cx)); - } + // Check for commands, but only once the connection is established. + return_error!(this.poll_commands(cx)); // Poll the keepalive timer. this.poll_keepalive(cx); @@ -289,34 +276,6 @@ where } } - fn init(&mut self) -> Result<()> { - trace!( - "protocol Init, state {:?}, options {:?}", - self.state, - self.options - ); - match self.state { - State::NotInitialized => {} - _ => return Ok(()), - }; - - self.state = if self.options.noise { - let mut handshake = Handshake::new(self.options.is_initiator)?; - // If the handshake start returns a buffer, send it now. - if let Some(buf) = handshake.start()? { - // TODO what if this fails? or returns false - self.queue_frame_direct(buf.to_vec()).unwrap(); - } - self.read_state.set_frame_type(FrameType::Raw); - State::Handshake(Some(handshake)) - } else { - self.read_state.set_frame_type(FrameType::Message); - State::Established - }; - - Ok(()) - } - /// Poll commands. fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { @@ -325,8 +284,9 @@ where Ok(()) } - /// Poll the keepalive timer and queue a ping message if needed. - fn poll_keepalive(&mut self, cx: &mut Context<'_>) { + /// TODO Poll the keepalive timer and queue a ping message if needed. + fn poll_keepalive(&mut self, _cx: &mut Context<'_>) { + /* if Pin::new(&mut self.keepalive).poll(cx).is_ready() { if let State::Established = self.state { // 24 bit header for the empty message, hence the 3 @@ -335,8 +295,10 @@ where } self.keepalive.reset(KEEPALIVE_DURATION); } + */ } + // just handles Close and LocalSignal?? fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool { // If message is close, close the local channel. if let ChannelMessage { @@ -360,12 +322,11 @@ where } /// Poll for inbound messages and processs them. - fn poll_inbound_read(&mut self, _cx: &mut Context<'_>) -> Result<()> { + fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { - //let msg = self.read_state.poll_reader(cx, &mut self.io); - match todo!() { - Poll::Ready(Ok(message)) => { - self.on_inbound_frame(message)?; + match self.io.poll_inbound(cx) { + Poll::Ready(Ok(messages)) => { + self.on_inbound_channel_messages(messages)?; } Poll::Ready(Err(e)) => return Err(e), Poll::Pending => return Ok(()), @@ -374,23 +335,20 @@ where } /// Poll for outbound messages and write them. + /// Reads messages from Self::outbound and sends them over io fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { - //if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { - // return Err(e); - //} // if no parking or setup in progress - if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { - return Ok(()); + if let Poll::Ready(Err(e)) = self.io.poll_outbound(cx) { + return Err(e); } - + // send messages outbound_rx match Pin::new(&mut self.outbound_rx).poll_next(cx) { Poll::Ready(Some(mut messages)) => { if !messages.is_empty() { messages.retain(|message| self.on_outbound_message(message)); - if !messages.is_empty() { - let frame = Frame::MessageBatch(messages); - self.write_state.park_frame(frame); + for msg in messages { + self.io.enqueue(msg); } } } @@ -400,125 +358,13 @@ where } } - fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { - match frame { - Frame::RawBatch(raw_batch) => { - let mut processed_state: Option = None; - for buf in raw_batch { - let state_name: String = format!("{:?}", self.state); - match self.state { - State::Handshake(_) => self.on_handshake_message(buf)?, - State::SecretStream(_) => self.on_secret_stream_message(buf)?, - State::Established => { - if let Some(processed_state) = processed_state.as_ref() { - // last state before established - let previous_state = if self.options.encrypted { - // was SecretStream if we're encrypted - State::SecretStream(None) - } else { - // or wa hasdshake if we're not encrypted - State::Handshake(None) - }; - - // if htis raw_batch included regular messages (not handshake) - // after handshake stuff - if processed_state == &format!("{previous_state:?}") { - // This is the unlucky case where the batch had two or more messages where - // the first one was correctly identified as Raw but everything - // after that should have been (decrypted and) a MessageBatch. Correct the mistake - // here post-hoc. - let buf = self.read_state.decrypt_buf(&buf)?; - let frame = Frame::decode(&buf, &FrameType::Message)?; - self.on_inbound_frame(frame)?; - continue; - } - } - unreachable!( - "May not receive raw frames in Established state" - ) - } - _ => unreachable!( - "May not receive raw frames outside of handshake or secretstream state, was {:?}", - self.state - ), - }; - if processed_state.is_none() { - processed_state = Some(state_name) - } - } - Ok(()) - } - Frame::MessageBatch(channel_messages) => match self.state { - State::Established => { - for channel_message in channel_messages { - self.on_inbound_message(channel_message)? - } - Ok(()) - } - _ => unreachable!("May not receive message batch frames when not established"), - }, - } - } - - fn on_handshake_message(&mut self, buf: Vec) -> Result<()> { - let mut handshake = match &mut self.state { - State::Handshake(handshake) => handshake.take().unwrap(), - _ => unreachable!("May not call on_handshake_message when not in Handshake state"), - }; - - if let Some(response_buf) = handshake.read(&buf)? { - self.queue_frame_direct(response_buf.to_vec()).unwrap(); - } - - if !handshake.complete() { - self.state = State::Handshake(Some(handshake)); - } else { - let handshake_result = handshake.into_result()?; - - if self.options.encrypted { - // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; - self.state = State::SecretStream(Some(cipher)); - - // Send the secret stream init message header to the other side - self.queue_frame_direct(init_msg).unwrap(); - } else { - // Skip secret stream and go straight to Established, then notify about - // handshake - self.read_state.set_frame_type(FrameType::Message); - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - } - // Store handshake result - self.handshake = Some(handshake_result.clone()); + fn on_inbound_channel_messages(&mut self, channel_messages: Vec) -> Result<()> { + for channel_message in channel_messages { + self.on_inbound_message(channel_message)? } Ok(()) } - fn on_secret_stream_message(&mut self, buf: Vec) -> Result<()> { - let encrypt_cipher = match &mut self.state { - State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(), - _ => { - unreachable!("May not call on_secret_stream_message when not in SecretStream state") - } - }; - let handshake_result = &self - .handshake - .as_ref() - .expect("Handshake result must be set before secret stream"); - let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?; - self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher); - self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher); - self.read_state.set_frame_type(FrameType::Message); - - // Lastly notify that handshake is ready and set state to established - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - Ok(()) - } - fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { // let channel_message = ChannelMessage::decode(buf)?; let (remote_id, message) = channel_message.into_split(); @@ -564,8 +410,7 @@ where capability, }); let channel_message = ChannelMessage::new(channel, message); - self.write_state - .queue_frame(Frame::MessageBatch(vec![channel_message])); + self.io.enqueue(channel_message); Ok(()) } @@ -602,13 +447,6 @@ where self.queued_events.push_back(event); } - /// enequeu a buf to be sent - fn queue_frame_direct(&mut self, body: Vec) -> Result { - let mut frame = Frame::RawBatch(vec![body]); - self.write_state - .try_encode_and_enqueue_frame_for_tx(&mut frame) - } - fn accept_channel(&mut self, local_id: usize) -> Result<()> { let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; self.verify_remote_capability(remote_capability.cloned(), key)?; @@ -662,7 +500,11 @@ where { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Protocol::poll_next(self, cx).map(Some) + match Protocol::poll_next(self, cx) { + Poll::Ready(Ok(e)) => Poll::Ready(Some(Ok(e))), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + } } } diff --git a/src/protocol/old.rs b/src/protocol/old.rs index 930f9bd..01af713 100644 --- a/src/protocol/old.rs +++ b/src/protocol/old.rs @@ -387,6 +387,7 @@ where messages.retain(|message| self.on_outbound_message(message)); if !messages.is_empty() { let frame = Frame::MessageBatch(messages); + // TODO try replacing this with queue_frame self.write_state.park_frame(frame); } } diff --git a/src/schema.rs b/src/schema.rs index ef58e77..bf35416 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -18,9 +18,9 @@ pub struct Open { impl CompactEncoding for State { fn preencode(&mut self, value: &Open) -> Result { - self.preencode(&value.channel)?; - self.preencode(&value.protocol)?; - self.preencode(&value.discovery_key)?; + dbg!(self.preencode(&value.channel)?); + dbg!(self.preencode(&value.protocol)?); + dbg!(self.preencode(&value.discovery_key)?); if value.capability.is_some() { self.add_end(1)?; // flags for future use self.preencode_fixed_32()?; @@ -29,6 +29,7 @@ impl CompactEncoding for State { } fn encode(&mut self, value: &Open, buffer: &mut [u8]) -> Result { + dbg!(); self.encode(&value.channel, buffer)?; self.encode(&value.protocol, buffer)?; self.encode(&value.discovery_key, buffer)?; @@ -369,7 +370,7 @@ pub struct NoData { impl CompactEncoding for State { fn preencode(&mut self, value: &NoData) -> Result { - self.preencode(&value.request) + dbg!(self.preencode(dbg!(&value.request))) } fn encode(&mut self, value: &NoData, buffer: &mut [u8]) -> Result { diff --git a/src/test_utils.rs b/src/test_utils.rs index ff1a3c2..b12f440 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,7 +1,6 @@ use std::{ io::{self, ErrorKind}, pin::Pin, - sync::OnceLock, task::{Context, Poll}, }; @@ -74,9 +73,9 @@ impl TwoWay { } } -pub(crate) fn log() { +pub fn log() { use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; - static START_LOGS: OnceLock<()> = OnceLock::new(); + static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); START_LOGS.get_or_init(|| { tracing_subscriber::fmt() .with_target(true) diff --git a/tests/_util.rs b/tests/_util.rs index 3064c08..d1fa197 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -3,10 +3,10 @@ use futures_lite::io::{AsyncRead, AsyncWrite}; use futures_lite::StreamExt; use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; use instant::Duration; -use std::future::Future; use std::io; use tokio::task::JoinHandle; +#[allow(unused)] pub(crate) fn log() { use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); @@ -83,9 +83,10 @@ where }) } -const RETRY_TIMEOUT: u64 = 100_u64; -const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; +#[allow(unused)] pub async fn wait_for_localhost_port(port: u32) { + const RETRY_TIMEOUT: u64 = 100_u64; + const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; loop { let timeout = async_std::future::timeout( Duration::from_millis(NO_RESPONSE_TIMEOUT), diff --git a/tests/basic.rs b/tests/basic.rs index 062cf35..92e4c8b 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -13,10 +13,19 @@ mod _util; async fn basic_protocol() -> anyhow::Result<()> { let (proto_a, proto_b) = create_pair_memory().await?; + dbg!(); let next_a = next_event(proto_a); + dbg!(); let next_b = next_event(proto_b); - let (mut proto_a, event_a) = next_a.await?; + dbg!(); let (proto_b, event_b) = next_b.await?; + dbg!(); + let (mut proto_a, event_a) = next_a.await?; + //let (a, b) = join(next_a, next_b).await; + dbg!(); + //let (mut proto_a, event_a) = a?; + dbg!(); + //let (proto_b, event_b) = b?; assert!(matches!(event_a, Ok(Event::Handshake(_)))); assert!(matches!(event_b, Ok(Event::Handshake(_)))); From e9d6236a2a2b47bd71b376453a190b9a00f61569 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 19 Mar 2025 12:22:57 -0400 Subject: [PATCH 050/206] Add frame encoding test --- src/message.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/message.rs b/src/message.rs index 869baf0..25a2545 100644 --- a/src/message.rs +++ b/src/message.rs @@ -676,6 +676,28 @@ mod tests { } } + #[test] + fn frame_encode_decode() -> std::io::Result<()> { + let msg = Message::Synchronize(Synchronize { + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, + }); + + let channel = rand::random::() as u64; + let channel_message = ChannelMessage::new(channel, msg); + + let frame = Frame::from(channel_message); + let mut buf = vec![0; frame.encoded_len()?]; + frame.encode(&mut buf)?; + let res_frame = Frame::decode(&buf, &FrameType::Message)?; + assert_eq!(res_frame, frame); + Ok(()) + } + #[test] fn message_encode_decode() { message_enc_dec! { From 15d2895ba5ed704e4302a4bd9dacc97c7779de32 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 20 Mar 2025 15:45:34 -0400 Subject: [PATCH 051/206] fix ChannelMessage Encoding. Restore Frame encoding --- src/message.rs | 322 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 292 insertions(+), 30 deletions(-) diff --git a/src/message.rs b/src/message.rs index 25a2545..90028aa 100644 --- a/src/message.rs +++ b/src/message.rs @@ -8,6 +8,8 @@ use std::fmt; use std::io; use tracing::instrument; +const UINT24_HEADER_LEN: usize = 3; + /// The type of a data frame. #[derive(Debug, Clone, PartialEq)] pub(crate) enum FrameType { @@ -71,13 +73,60 @@ impl From for Frame { } } +impl From> for Frame { + fn from(m: Vec) -> Self { + Self::MessageBatch(m) + } +} + impl From> for Frame { fn from(m: Vec) -> Self { Self::RawBatch(vec![m]) } } -pub(crate) fn decode_channel_messages( +pub(crate) fn decode_many_channel_messages( + buf: &[u8], +) -> Result<(Vec, usize), io::Error> { + let mut index = 0; + let mut combined_messages: Vec = vec![]; + while index < buf.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buf[index] == 0 { + index += 1; + continue; + } + + let stat = stat_uint24_le(&buf[index..]); + if let Some((header_len, body_len)) = stat { + let (msgs, length) = decode_one_channel_message( + &buf[index + header_len..index + header_len + body_len as usize], + )?; + if length != body_len as usize { + tracing::warn!( + "Did not know what to do with all the bytes, got {} but decoded {}. \ + This may be because the peer implements a newer protocol version \ + that has extra fields.", + body_len, + length + ); + } + for message in msgs { + combined_messages.push(message); + } + index += header_len + body_len as usize; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid data in multi-message chunk", + )); + } + } + Ok((combined_messages, index)) +} +// bad name bc it returns many. More like, decode unframed channel messages +pub(crate) fn decode_one_channel_message( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { if buf.len() >= 3 && buf[0] == 0x00 { @@ -237,8 +286,76 @@ impl Frame { } fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - let (channel_messages, bytes_read) = decode_channel_messages(buf)?; - Ok((Frame::MessageBatch(channel_messages), bytes_read)) + println!("decode_message {buf:02X?}"); + // buffer length >= 3 or more and starts with 0 is message batch + if buf.len() >= 3 && buf[0] == 0x00 { + if buf[1] == 0x00 { + // Batch of messages + let mut messages: Vec = vec![]; + let mut state = State::new_with_start_and_end(2, buf.len()); + + // First, there is the original channel + let mut current_channel: u64 = state.decode(buf)?; + while state.start() < state.end() { + // Length of the message is inbetween here + let channel_message_length: usize = state.decode(buf)?; + if state.start() + channel_message_length > state.end() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length, {} + {} > {}", + state.start(), + channel_message_length, + state.end() + ), + )); + } + // Then the actual message + let (channel_message, _) = ChannelMessage::decode( + &buf[state.start()..state.start() + channel_message_length], + current_channel, + )?; + messages.push(channel_message); + state.add_start(channel_message_length)?; + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if state.start() < state.end() && buf[state.start()] == 0x00 { + state.add_start(1)?; + current_channel = state.decode(buf)?; + } + } + Ok((Frame::MessageBatch(messages), state.start())) + } else if buf[1] == 0x01 { + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else if buf[1] == 0x03 { + // Close message + let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid special message", + )) + } + } else if buf.len() >= 2 { + // len >= and + // Single message + let mut state = State::from_buffer(buf); + let channel: u64 = state.decode(buf)?; + let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; + Ok(( + Frame::MessageBatch(vec![channel_message]), + state.start() + length, + )) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:02X?}"), + )) + } } fn preencode(&self, state: &mut State) -> Result { @@ -248,8 +365,37 @@ impl Frame { state.add_end(raw.as_slice().encoded_len()?)?; } } + #[allow(clippy::comparison_chain)] Self::MessageBatch(messages) => { - state.add_end(messages.encoded_len()?)?; + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else { + (*state).preencode(&messages[0].channel)?; + state.add_end(messages[0].encoded_len()?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.add_end(2)?; + let mut current_channel: u64 = messages[0].channel; + state.preencode(¤t_channel)?; + for message in messages.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.add_end(1)?; + state.preencode(&message.channel)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.preencode(&message_length)?; + state.add_end(message_length)?; + } + } } } Ok(state.end()) @@ -282,8 +428,43 @@ impl Encoder for Frame { raw.as_slice().encode(buf)?; } } + #[allow(clippy::comparison_chain)] Self::MessageBatch(ref messages) => { - messages.encode(buf)?; + write_uint24_le(body_len, buf); + let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(1_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(3_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else { + state.encode(&messages[0].channel, buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + let mut current_channel: u64 = messages[0].channel; + state.encode(¤t_channel, buf)?; + for message in messages.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.encode(&(0_u8), buf)?; + state.encode(&message.channel, buf)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.encode(&message_length, buf)?; + state.add_start(message.encode(&mut buf[state.start()..])?)?; + } + } } }; Ok(len) @@ -333,12 +514,11 @@ fn prencode_channel_messages( impl Encoder for Vec { fn encoded_len(&self) -> Result { let mut state = State::new(); - prencode_channel_messages(self, &mut state) + Ok(prencode_channel_messages(self, &mut state)? + UINT24_HEADER_LEN) } #[instrument] fn encode(&self, buf: &mut [u8]) -> Result { - const HEADER_LEN: usize = 3; let mut state = State::new(); let body_len = prencode_channel_messages(self, &mut state)?; write_uint24_le(body_len, buf); @@ -380,7 +560,7 @@ impl Encoder for Vec { } } } - Ok(HEADER_LEN + body_len) + Ok(UINT24_HEADER_LEN + body_len) } } @@ -656,6 +836,7 @@ impl Encoder for ChannelMessage { #[cfg(test)] mod tests { + use super::*; use hypercore::{ DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, @@ -676,28 +857,6 @@ mod tests { } } - #[test] - fn frame_encode_decode() -> std::io::Result<()> { - let msg = Message::Synchronize(Synchronize { - fork: 0, - can_upgrade: true, - downloading: true, - uploading: true, - length: 5, - remote_length: 0, - }); - - let channel = rand::random::() as u64; - let channel_message = ChannelMessage::new(channel, msg); - - let frame = Frame::from(channel_message); - let mut buf = vec![0; frame.encoded_len()?]; - frame.encode(&mut buf)?; - let res_frame = Frame::decode(&buf, &FrameType::Message)?; - assert_eq!(res_frame, frame); - Ok(()) - } - #[test] fn message_encode_decode() { message_enc_dec! { @@ -781,4 +940,107 @@ mod tests { }) }; } + + fn message_test_data() -> Vec { + vec![ + Message::Synchronize(Synchronize { + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, + }), + Message::Request(Request { + id: 1, + fork: 1, + block: Some(RequestBlock { + index: 5, + nodes: 10, + }), + hash: Some(RequestBlock { + index: 20, + nodes: 0, + }), + seek: Some(RequestSeek { bytes: 10 }), + upgrade: Some(RequestUpgrade { + start: 0, + length: 10, + }), + }), + Message::Cancel(Cancel { request: 1 }), + Message::Data(Data { + request: 1, + fork: 5, + block: Some(DataBlock { + index: 5, + nodes: vec![Node::new(1, vec![0x01; 32], 100)], + value: vec![0xFF; 10], + }), + hash: Some(DataHash { + index: 20, + nodes: vec![Node::new(2, vec![0x02; 32], 200)], + }), + seek: Some(DataSeek { + bytes: 10, + nodes: vec![Node::new(3, vec![0x03; 32], 300)], + }), + upgrade: Some(DataUpgrade { + start: 0, + length: 10, + nodes: vec![Node::new(4, vec![0x04; 32], 400)], + additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)], + signature: vec![0xAB; 32], + }), + }), + Message::NoData(NoData { request: 2 }), + Message::Want(Want { + start: 0, + length: 100, + }), + Message::Unwant(Unwant { + start: 10, + length: 2, + }), + Message::Bitfield(Bitfield { + start: 20, + bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF], + }), + Message::Range(Range { + drop: true, + start: 12345, + length: 100000, + }), + Message::Extension(Extension { + name: "custom_extension/v1/open".to_string(), + message: vec![0x44, 20], + }), + ] + } + + #[test] + fn compare_with_frame_encoding_decoding() -> std::io::Result<()> { + let channel = 42; + for msg in message_test_data() { + let channel_message = ChannelMessage::new(channel, msg); + let frame = Frame::from(channel_message.clone()); + let cmvec = vec![channel_message.clone()]; + + let mut fbuf = vec![0; frame.encoded_len()?]; + let mut cbuf = vec![0; cmvec.encoded_len()?]; + + assert_eq!(cbuf, fbuf); + + frame.encode(&mut fbuf)?; + cmvec.encode(&mut cbuf)?; + + assert_eq!(cbuf, fbuf); + + let fres = Frame::decode_multiple(&fbuf, &FrameType::Message)?; + assert_eq!(fres, frame); + let cres_m = decode_many_channel_messages(&cbuf)?.0; + assert_eq!(cres_m, cmvec); + } + Ok(()) + } } From 002ba3681f6a9a957e63851b646b008bf65de272 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 25 Mar 2025 12:29:49 -0400 Subject: [PATCH 052/206] rm/gate old unused stuff. add instrument --- src/framing.rs | 67 +++- src/lib.rs | 4 + src/message.rs | 28 -- src/mqueue.rs | 102 +++--- src/oldmessage.rs | 814 +++++++++++++++++++++++++++++++++++++++++ src/protocol/modern.rs | 38 +- src/protocol/old.rs | 7 +- src/test_utils.rs | 2 +- src/util.rs | 26 ++ tests/_util.rs | 19 + tests/basic.rs | 4 +- 11 files changed, 991 insertions(+), 120 deletions(-) create mode 100644 src/oldmessage.rs diff --git a/src/framing.rs b/src/framing.rs index ce4d7ba..8b8ae8f 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -192,11 +192,12 @@ where } #[cfg(test)] pub(crate) mod test { - use crate::test_utils::log; + use crate::{test_utils::log, Duplex}; use super::*; use futures::{SinkExt, StreamExt}; use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + use tokio::spawn; use tokio_util::compat::TokioAsyncReadCompatExt; pub(crate) fn duplex( @@ -307,6 +308,70 @@ pub(crate) mod test { assert_eq!(r3, data); assert_eq!(r4, data); + Ok(()) + } + #[tokio::test] + async fn left_and_right_sluice() -> Result<()> { + let (ar, bw) = sluice::pipe::pipe(); + let (br, aw) = sluice::pipe::pipe(); + let left = Duplex::new(ar, aw); + let right = Duplex::new(br, bw); + + let mut leftlp = Uint24LELengthPrefixedFraming::new(left); + let mut rightlp = Uint24LELengthPrefixedFraming::new(right); + + // NB sluice has a max "chunk" thing of 4 + // so we limit the data we're sending to 3 things + let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle"]; + // NB this sluice pipe + // + for d in data { + rightlp.feed(d.to_vec()).await.unwrap(); + } + let rflush = spawn(async move { + rightlp.flush().await.unwrap(); + rightlp + }); + + let mut result1 = vec![]; + for _ in data { + result1.push(leftlp.next().await.unwrap().unwrap()); + } + let mut rightlp = rflush.await?; + + assert_eq!(result1, data); + + for d in data { + leftlp.feed(d.to_vec()).await.unwrap(); + } + let lflush = spawn(async move { + leftlp.flush().await.unwrap(); + leftlp + }); + + let mut result2 = vec![]; + for _ in data { + result2.push(rightlp.next().await.unwrap().unwrap()); + } + let mut leftlp = lflush.await?; + assert_eq!(result2, data); + + let mut r3 = vec![]; + let mut r4 = vec![]; + + for d in data { + rightlp.send(d.to_vec()).await.unwrap(); + leftlp.send(d.to_vec()).await.unwrap(); + } + + for _ in data { + r3.push(rightlp.next().await.unwrap().unwrap()); + r4.push(leftlp.next().await.unwrap().unwrap()); + } + + assert_eq!(r3, data); + assert_eq!(r4, data); + Ok(()) } } diff --git a/src/lib.rs b/src/lib.rs index 9fca95e..999aa26 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -126,11 +126,15 @@ mod framing; mod message; mod mqueue; mod noise; +#[cfg(not(feature = "protocol"))] +mod oldmessage; mod protocol; +#[cfg(not(feature = "protocol"))] mod reader; #[cfg(test)] mod test_utils; mod util; +#[cfg(not(feature = "protocol"))] mod writer; /// The wire messages used by the protocol. diff --git a/src/message.rs b/src/message.rs index 90028aa..6d5200c 100644 --- a/src/message.rs +++ b/src/message.rs @@ -13,7 +13,6 @@ const UINT24_HEADER_LEN: usize = 3; /// The type of a data frame. #[derive(Debug, Clone, PartialEq)] pub(crate) enum FrameType { - Raw, Message, } @@ -203,32 +202,6 @@ impl Frame { /// Decodes a frame from a buffer containing multiple concurrent messages. pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { match frame_type { - FrameType::Raw => { - let mut index = 0; - let mut raw_batch: Vec> = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - raw_batch.push( - buf[index + header_len..index + header_len + body_len as usize] - .to_vec(), - ); - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in raw batch", - )); - } - } - Ok(Frame::RawBatch(raw_batch)) - } FrameType::Message => { let mut index = 0; let mut combined_messages: Vec = vec![]; @@ -277,7 +250,6 @@ impl Frame { /// Decode a frame from a buffer. pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { match frame_type { - FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])), FrameType::Message => { let (frame, _) = Self::decode_message(buf)?; Ok(frame) diff --git a/src/mqueue.rs b/src/mqueue.rs index b968937..39b4c9b 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -12,7 +12,7 @@ use tracing::{debug, error, info, instrument, trace}; use crate::{ encrypted_framed_message_channel, - message::{decode_channel_messages, ChannelMessage, Encoder as _}, + message::{decode_many_channel_messages, ChannelMessage, Encoder as _}, }; pub(crate) struct MessageIo { @@ -22,16 +22,6 @@ pub(crate) struct MessageIo { use crate::{framing::Uint24LELengthPrefixedFraming, noise::Encrypted}; -pub(crate) fn encrypted_and_framed( - is_initiator: bool, - io: BytesTxRx, -) -> MessageIo>> { - let io = encrypted_framed_message_channel(is_initiator, io); - MessageIo { - io, - write_queue: Default::default(), - } -} impl>> + Sink> + Send + Unpin + 'static> MessageIo { pub(crate) fn new(io: IO) -> Self { Self { @@ -101,7 +91,7 @@ impl>> + Sink> + Send + Unpin + 'static ) -> Poll>> { match Pin::new(&mut self.io).poll_next(cx) { Poll::Ready(Some(Ok(encoded))) => { - match decode_channel_messages(&encoded) { + match decode_many_channel_messages(&encoded) { Ok((messsages, n_read)) => { assert_eq!(n_read, encoded.len()); // I think this is always true Poll::Ready(Ok(messsages)) @@ -142,16 +132,27 @@ impl>> + Sink> + Send + Unpin + 'static mod test { use std::io::Result; - use futures::future::{join, select}; + use futures::{future::select, AsyncRead, AsyncWrite}; use futures_lite::StreamExt; use crate::{ - framing::test::duplex, - message::{decode_channel_messages, ChannelMessage, Encoder as _}, - mqueue::encrypted_and_framed, - schema::{NoData, Open}, - test_utils::log, + encrypted_framed_message_channel, framing::test::duplex, message::ChannelMessage, + schema::NoData, test_utils::log, Encrypted, Uint24LELengthPrefixedFraming, }; + + use super::MessageIo; + pub(crate) fn encrypted_and_framed< + BytesTxRx: AsyncRead + AsyncWrite + Send + Unpin + 'static, + >( + is_initiator: bool, + io: BytesTxRx, + ) -> MessageIo>> { + let io = encrypted_framed_message_channel(is_initiator, io); + MessageIo { + io, + write_queue: Default::default(), + } + } fn new_msg(channel: u64) -> ChannelMessage { ChannelMessage { channel, @@ -162,47 +163,32 @@ mod test { #[tokio::test] async fn mqueue() -> Result<()> { log(); - let m = vec![new_msg(0)]; - let mut buf = vec![0; m.encoded_len()?]; - dbg!(&buf.len()); - dbg!(); - m.encode(&mut buf)?; - dbg!(&buf); - - let res = dbg!(decode_channel_messages(&buf))?; - assert_eq!(vec![new_msg(42402)], res.0); - dbg!(&buf); - - Ok(()) - - /* - let (left, right) = duplex(1024 * 64); - let mut left = encrypted_and_framed(true, left); - let mut right = encrypted_and_framed(false, right); - left.enqueue(new_msg(42)); - right.enqueue(new_msg(38)); - match select(left.next(), right.next()).await { - futures::future::Either::Left(ll) => { - println!( - "left - - ooooooooooooooooooooo - - " - ); - } - futures::future::Either::Right(rr) => { - println!( - "rightllllllllllllllll - - ------------------------- - - " - ); - } + let rtolm = new_msg(38); + let ltorm = new_msg(42); + + let (left, right) = duplex(1024 * 64); + let mut left = encrypted_and_framed(true, left); + let mut right = encrypted_and_framed(false, right); + left.enqueue(ltorm.clone()); + right.enqueue(rtolm.clone()); + + match select(left.next(), right.next()).await { + futures::future::Either::Left((m, _)) => { + if let Some(Ok(res)) = m { + assert_eq!(res, vec![rtolm]); + } else { + panic!(); } - Ok(()) - */ + } + futures::future::Either::Right((m, _)) => { + if let Some(Ok(res)) = m { + assert_eq!(res, vec![ltorm]); + } else { + panic!(); + } + } + } + Ok(()) } } diff --git a/src/oldmessage.rs b/src/oldmessage.rs new file mode 100644 index 0000000..8cb2c61 --- /dev/null +++ b/src/oldmessage.rs @@ -0,0 +1,814 @@ +use crate::schema::*; +use crate::util::{stat_uint24_le, write_uint24_le}; +use hypercore::encoding::{ + CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State, +}; +use pretty_hash::fmt as pretty_fmt; +use std::fmt; +use std::io; + +/// The type of a data frame. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum FrameType { + Raw, + Message, +} + +/// Encode data into a buffer. +/// +/// This trait is implemented on data frames and their components +/// (channel messages, messages, and individual message types through prost). +pub(crate) trait Encoder: Sized + fmt::Debug { + /// Calculates the length that the encoded message needs. + fn encoded_len(&mut self) -> Result; + + /// Encodes the message to a buffer. + /// + /// An error will be returned if the buffer does not have sufficient capacity. + fn encode(&mut self, buf: &mut [u8]) -> Result; +} + +impl Encoder for &[u8] { + fn encoded_len(&mut self) -> Result { + Ok(self.len()) + } + + fn encode(&mut self, buf: &mut [u8]) -> Result { + let len = self.encoded_len()?; + if len > buf.len() { + return Err(EncodingError::new( + EncodingErrorKind::Overflow, + &format!("Length does not fit buffer, {} > {}", len, buf.len()), + )); + } + buf[..len].copy_from_slice(&self[..]); + Ok(len) + } +} + +/// A frame of data, either a buffer or a message. +#[derive(Clone, PartialEq)] +pub(crate) enum Frame { + /// A raw batch binary buffer. Used in the handshaking phase. + RawBatch(Vec>), + /// Message batch, containing one or more channel messsages. Used for everything after the handshake. + MessageBatch(Vec), +} + +impl fmt::Debug for Frame { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), + Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), + } + } +} + +impl From for Frame { + fn from(m: ChannelMessage) -> Self { + Self::MessageBatch(vec![m]) + } +} + +impl From> for Frame { + fn from(m: Vec) -> Self { + Self::RawBatch(vec![m]) + } +} + +impl Frame { + /// Decodes a frame from a buffer containing multiple concurrent messages. + pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { + match frame_type { + FrameType::Raw => { + let mut index = 0; + let mut raw_batch: Vec> = vec![]; + while index < buf.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buf[index] == 0 { + index += 1; + continue; + } + let stat = stat_uint24_le(&buf[index..]); + if let Some((header_len, body_len)) = stat { + raw_batch.push( + buf[index + header_len..index + header_len + body_len as usize] + .to_vec(), + ); + index += header_len + body_len as usize; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid data in raw batch", + )); + } + } + Ok(Frame::RawBatch(raw_batch)) + } + FrameType::Message => { + let mut index = 0; + let mut combined_messages: Vec = vec![]; + while index < buf.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buf[index] == 0 { + index += 1; + continue; + } + + let stat = stat_uint24_le(&buf[index..]); + if let Some((header_len, body_len)) = stat { + let (frame, length) = Self::decode_message( + &buf[index + header_len..index + header_len + body_len as usize], + )?; + if length != body_len as usize { + tracing::warn!( + "Did not know what to do with all the bytes, got {} but decoded {}. \ + This may be because the peer implements a newer protocol version \ + that has extra fields.", + body_len, + length + ); + } + if let Frame::MessageBatch(messages) = frame { + for message in messages { + combined_messages.push(message); + } + } else { + unreachable!("Can not get Raw messages"); + } + index += header_len + body_len as usize; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid data in multi-message chunk", + )); + } + } + Ok(Frame::MessageBatch(combined_messages)) + } + } + } + + /// Decode a frame from a buffer. + pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { + match frame_type { + FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])), + FrameType::Message => { + let (frame, _) = Self::decode_message(buf)?; + Ok(frame) + } + } + } + + fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { + println!("decode_message {buf:02X?}"); + // buffer length >= 3 or more and starts with 0 is message batch + if buf.len() >= 3 && buf[0] == 0x00 { + if buf[1] == 0x00 { + // Batch of messages + let mut messages: Vec = vec![]; + let mut state = State::new_with_start_and_end(2, buf.len()); + + // First, there is the original channel + let mut current_channel: u64 = state.decode(buf)?; + while state.start() < state.end() { + // Length of the message is inbetween here + let channel_message_length: usize = state.decode(buf)?; + if state.start() + channel_message_length > state.end() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length, {} + {} > {}", + state.start(), + channel_message_length, + state.end() + ), + )); + } + // Then the actual message + let (channel_message, _) = ChannelMessage::decode( + &buf[state.start()..state.start() + channel_message_length], + current_channel, + )?; + messages.push(channel_message); + state.add_start(channel_message_length)?; + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if state.start() < state.end() && buf[state.start()] == 0x00 { + state.add_start(1)?; + current_channel = state.decode(buf)?; + } + } + Ok((Frame::MessageBatch(messages), state.start())) + } else if buf[1] == 0x01 { + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else if buf[1] == 0x03 { + // Close message + let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid special message", + )) + } + } else if buf.len() >= 2 { + // len >= and + // Single message + let mut state = State::from_buffer(buf); + let channel: u64 = state.decode(buf)?; + let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; + Ok(( + Frame::MessageBatch(vec![channel_message]), + state.start() + length, + )) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:02X?}"), + )) + } + } + + fn preencode(&mut self, state: &mut State) -> Result { + match self { + Self::RawBatch(raw_batch) => { + for raw in raw_batch { + state.add_end(raw.as_slice().encoded_len()?)?; + } + } + #[allow(clippy::comparison_chain)] + Self::MessageBatch(messages) => { + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else { + (*state).preencode(&messages[0].channel)?; + state.add_end(messages[0].encoded_len()?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.add_end(2)?; + let mut current_channel: u64 = messages[0].channel; + state.preencode(¤t_channel)?; + for message in messages.iter_mut() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.add_end(1)?; + state.preencode(&message.channel)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.preencode(&message_length)?; + state.add_end(message_length)?; + } + } + } + } + Ok(state.end()) + } +} + +impl Encoder for Frame { + fn encoded_len(&mut self) -> Result { + let body_len = self.preencode(&mut State::new())?; + match self { + Self::RawBatch(_) => Ok(body_len), + Self::MessageBatch(_) => Ok(3 + body_len), + } + } + + fn encode(&mut self, buf: &mut [u8]) -> Result { + let mut state = State::new(); + let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; + let body_len = self.preencode(&mut state)?; + let len = body_len + header_len; + if buf.len() < len { + return Err(EncodingError::new( + EncodingErrorKind::Overflow, + &format!("Length does not fit buffer, {} > {}", len, buf.len()), + )); + } + match self { + Self::RawBatch(ref raw_batch) => { + for raw in raw_batch { + raw.as_slice().encode(buf)?; + } + } + #[allow(clippy::comparison_chain)] + Self::MessageBatch(ref mut messages) => { + write_uint24_le(body_len, buf); + let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(1_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(3_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else { + state.encode(&messages[0].channel, buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + let mut current_channel: u64 = messages[0].channel; + state.encode(¤t_channel, buf)?; + for message in messages.iter_mut() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.encode(&(0_u8), buf)?; + state.encode(&message.channel, buf)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.encode(&message_length, buf)?; + state.add_start(message.encode(&mut buf[state.start()..])?)?; + } + } + } + }; + Ok(len) + } +} + +/// A protocol message. +#[derive(Debug, Clone, PartialEq)] +#[allow(missing_docs)] +pub enum Message { + Open(Open), + Close(Close), + Synchronize(Synchronize), + Request(Request), + Cancel(Cancel), + Data(Data), + NoData(NoData), + Want(Want), + Unwant(Unwant), + Bitfield(Bitfield), + Range(Range), + Extension(Extension), + /// A local signalling message never sent over the wire + LocalSignal((String, Vec)), +} + +impl Message { + /// Wire type of this message. + pub(crate) fn typ(&self) -> u64 { + match self { + Self::Synchronize(_) => 0, + Self::Request(_) => 1, + Self::Cancel(_) => 2, + Self::Data(_) => 3, + Self::NoData(_) => 4, + Self::Want(_) => 5, + Self::Unwant(_) => 6, + Self::Bitfield(_) => 7, + Self::Range(_) => 8, + Self::Extension(_) => 9, + value => unimplemented!("{} does not have a type", value), + } + } + + /// Decode a message from a buffer based on type. + pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> { + let mut state = HypercoreState::from_buffer(buf); + let message = match typ { + 0 => Ok(Self::Synchronize((*state).decode(buf)?)), + 1 => Ok(Self::Request(state.decode(buf)?)), + 2 => Ok(Self::Cancel((*state).decode(buf)?)), + 3 => Ok(Self::Data(state.decode(buf)?)), + 4 => Ok(Self::NoData((*state).decode(buf)?)), + 5 => Ok(Self::Want((*state).decode(buf)?)), + 6 => Ok(Self::Unwant((*state).decode(buf)?)), + 7 => Ok(Self::Bitfield((*state).decode(buf)?)), + 8 => Ok(Self::Range((*state).decode(buf)?)), + 9 => Ok(Self::Extension((*state).decode(buf)?)), + _ => Err(EncodingError::new( + EncodingErrorKind::InvalidData, + &format!("Invalid message type to decode: {typ}"), + )), + }?; + Ok((message, state.start())) + } + + /// Pre-encodes a message to state, returns length + pub(crate) fn preencode(&self, state: &mut HypercoreState) -> Result { + match self { + Self::Open(ref message) => state.0.preencode(message)?, + Self::Close(ref message) => state.0.preencode(message)?, + Self::Synchronize(ref message) => state.0.preencode(message)?, + Self::Request(ref message) => state.preencode(message)?, + Self::Cancel(ref message) => state.0.preencode(message)?, + Self::Data(ref message) => state.preencode(message)?, + Self::NoData(ref message) => state.0.preencode(message)?, + Self::Want(ref message) => state.0.preencode(message)?, + Self::Unwant(ref message) => state.0.preencode(message)?, + Self::Bitfield(ref message) => state.0.preencode(message)?, + Self::Range(ref message) => state.0.preencode(message)?, + Self::Extension(ref message) => state.0.preencode(message)?, + Self::LocalSignal(_) => 0, + }; + Ok(state.end()) + } + + /// Encodes a message to a given buffer, using preencoded state, results size + pub(crate) fn encode( + &self, + state: &mut HypercoreState, + buf: &mut [u8], + ) -> Result { + match self { + Self::Open(ref message) => state.0.encode(message, buf)?, + Self::Close(ref message) => state.0.encode(message, buf)?, + Self::Synchronize(ref message) => state.0.encode(message, buf)?, + Self::Request(ref message) => state.encode(message, buf)?, + Self::Cancel(ref message) => state.0.encode(message, buf)?, + Self::Data(ref message) => state.encode(message, buf)?, + Self::NoData(ref message) => state.0.encode(message, buf)?, + Self::Want(ref message) => state.0.encode(message, buf)?, + Self::Unwant(ref message) => state.0.encode(message, buf)?, + Self::Bitfield(ref message) => state.0.encode(message, buf)?, + Self::Range(ref message) => state.0.encode(message, buf)?, + Self::Extension(ref message) => state.0.encode(message, buf)?, + Self::LocalSignal(_) => 0, + }; + Ok(state.start()) + } +} + +impl fmt::Display for Message { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Open(msg) => write!( + f, + "Open(discovery_key: {}, capability <{}>)", + pretty_fmt(&msg.discovery_key).unwrap(), + msg.capability.as_ref().map_or(0, |c| c.len()) + ), + Self::Data(msg) => write!( + f, + "Data(request: {}, fork: {}, block: {}, hash: {}, seek: {}, upgrade: {})", + msg.request, + msg.fork, + msg.block.is_some(), + msg.hash.is_some(), + msg.seek.is_some(), + msg.upgrade.is_some(), + ), + _ => write!(f, "{:?}", &self), + } + } +} + +/// A message on a channel. +#[derive(Clone)] +pub(crate) struct ChannelMessage { + pub(crate) channel: u64, + pub(crate) message: Message, + state: Option, +} + +impl PartialEq for ChannelMessage { + fn eq(&self, other: &Self) -> bool { + self.channel == other.channel && self.message == other.message + } +} + +impl fmt::Debug for ChannelMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ChannelMessage({}, {})", self.channel, self.message) + } +} + +impl ChannelMessage { + /// Create a new message. + pub(crate) fn new(channel: u64, message: Message) -> Self { + Self { + channel, + message, + state: None, + } + } + + /// Consume self and return (channel, Message). + pub(crate) fn into_split(self) -> (u64, Message) { + (self.channel, self.message) + } + + /// Decodes an open message for a channel message from a buffer. + /// + /// Note: `buf` has to have a valid length, and without the 3 LE + /// bytes in it + pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { + if buf.len() <= 5 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "received too short Open message", + )); + } + + let mut state = State::new_with_start_and_end(0, buf.len()); + let open_msg: Open = state.decode(buf)?; + Ok(( + Self { + channel: open_msg.channel, + message: Message::Open(open_msg), + state: None, + }, + state.start(), + )) + } + + /// Decodes a close message for a channel message from a buffer. + /// + /// Note: `buf` has to have a valid length, and without the 3 LE + /// bytes in it + pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { + if buf.is_empty() { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "received too short Close message", + )); + } + let mut state = State::new_with_start_and_end(0, buf.len()); + let close_msg: Close = state.decode(buf)?; + Ok(( + Self { + channel: close_msg.channel, + message: Message::Close(close_msg), + state: None, + }, + state.start(), + )) + } + + /// Decode a normal channel message from a buffer. + /// + /// Note: `buf` has to have a valid length, and without the 3 LE + /// bytes in it + pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> { + if buf.len() <= 1 { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "received empty message", + )); + } + let mut state = State::from_buffer(buf); + let typ: u64 = state.decode(buf)?; + let (message, length) = Message::decode(&buf[state.start()..], typ)?; + Ok(( + Self { + channel, + message, + state: None, + }, + state.start() + length, + )) + } + + /// Performance optimization for letting calling encoded_len() already do + /// the preencode phase of compact_encoding. + fn prepare_state(&mut self) -> Result<(), EncodingError> { + if self.state.is_none() { + let state = if let Message::Open(_) = self.message { + // Open message doesn't have a type + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 + let mut state = HypercoreState::new(); + self.message.preencode(&mut state)?; + state + } else if let Message::Close(_) = self.message { + // Close message doesn't have a type + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 + let mut state = HypercoreState::new(); + self.message.preencode(&mut state)?; + state + } else { + // The header is the channel id uint followed by message type uint + // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 + let mut state = HypercoreState::new(); + let typ = self.message.typ(); + (*state).preencode(&typ)?; + self.message.preencode(&mut state)?; + state + }; + self.state = Some(state); + } + Ok(()) + } +} + +impl Encoder for ChannelMessage { + fn encoded_len(&mut self) -> Result { + self.prepare_state()?; + Ok(self.state.as_ref().unwrap().end()) + } + + fn encode(&mut self, buf: &mut [u8]) -> Result { + self.prepare_state()?; + let state = self.state.as_mut().unwrap(); + if let Message::Open(_) = self.message { + // Open message is different in that the type byte is missing + self.message.encode(state, buf)?; + } else if let Message::Close(_) = self.message { + // Close message is different in that the type byte is missing + self.message.encode(state, buf)?; + } else { + let typ = self.message.typ(); + state.0.encode(&typ, buf)?; + self.message.encode(state, buf)?; + } + Ok(state.start()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use hypercore::{ + DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, + }; + + macro_rules! message_enc_dec { + ($( $msg:expr ),*) => { + $( + let channel = rand::random::() as u64; + let mut channel_message = ChannelMessage::new(channel, $msg); + let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); + let mut buf = vec![0u8; encoded_len]; + let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); + let decoded = ChannelMessage::decode(&buf[..n], channel).expect("Failed to decode message").0.into_split(); + assert_eq!(channel, decoded.0); + assert_eq!($msg, decoded.1); + )* + } + } + #[test] + fn frame_encode_decode() -> std::io::Result<()> { + let msg = Message::Synchronize(Synchronize { + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, + }); + + let channel = rand::random::() as u64; + let channel_message = ChannelMessage::new(channel, msg); + + let mut frame = Frame::from(channel_message); + let mut buf = vec![0; frame.encoded_len()?]; + frame.encode(&mut buf)?; + let res_frame = Frame::decode_multiple(&buf, &FrameType::Message)?; + assert_eq!(res_frame, frame); + Ok(()) + } + #[test] + fn frame_encode_decode_bar() -> std::io::Result<()> { + let msg = Message::Synchronize(Synchronize { + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, + }); + + //let channel = rand::random::() as u64; + let channel = 42; + let channel_message = ChannelMessage::new(channel, msg); + + let mut frame = Frame::from(channel_message.clone()); + + let mut fbuf = vec![0; frame.encoded_len()?]; + + frame.encode(&mut fbuf)?; + + let fres = Frame::decode_multiple(&fbuf, &FrameType::Message)?; + assert_eq!(fres, frame); + ///assert_eq!(cres, cmvec); + //println!("REG frame buf\t{frame_buf:02X?}"); + //let res_frame = Frame::decode(&frame_buf, &FrameType::Message)?; + //dbg!(res_frame); + //let res_frame = Frame::decode_multiple(&frame_buf, &FrameType::Message)?; + //dbg!(res_frame); + + //let mut vec_frame_buf = vec![0; vec_frame.encoded_len()?]; + //vec_frame.encode(&mut vec_frame_buf)?; + + //assert_eq!(vec_frame_buf, frame_buf); + //println!("VEC frame buf\t{vec_frame_buf:02X?}"); + + //let res_frame = Frame::decode(&vec_frame_buf, &FrameType::Message)?; + //dbg!(res_frame); + //let res_frame = Frame::decode_multiple(&vec_frame_buf, &FrameType::Message)?; + //dbg!(&res_frame); + + //let (msg, _len) = decode_channel_messages(&vec_frame_buf)?; + //assert_eq!(msg, vec![channel_message]); + + //assert_eq!(res_frame, frame); + Ok(()) + } + + #[test] + fn message_encode_decode() { + message_enc_dec! { + Message::Synchronize(Synchronize{ + fork: 0, + can_upgrade: true, + downloading: true, + uploading: true, + length: 5, + remote_length: 0, + }), + Message::Request(Request { + id: 1, + fork: 1, + block: Some(RequestBlock { + index: 5, + nodes: 10, + }), + hash: Some(RequestBlock { + index: 20, + nodes: 0 + }), + seek: Some(RequestSeek { + bytes: 10 + }), + upgrade: Some(RequestUpgrade { + start: 0, + length: 10 + }) + }), + Message::Cancel(Cancel { + request: 1, + }), + Message::Data(Data{ + request: 1, + fork: 5, + block: Some(DataBlock { + index: 5, + nodes: vec![Node::new(1, vec![0x01; 32], 100)], + value: vec![0xFF; 10] + }), + hash: Some(DataHash { + index: 20, + nodes: vec![Node::new(2, vec![0x02; 32], 200)], + }), + seek: Some(DataSeek { + bytes: 10, + nodes: vec![Node::new(3, vec![0x03; 32], 300)], + }), + upgrade: Some(DataUpgrade { + start: 0, + length: 10, + nodes: vec![Node::new(4, vec![0x04; 32], 400)], + additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)], + signature: vec![0xAB; 32] + }) + }), + Message::NoData(NoData { + request: 2, + }), + Message::Want(Want { + start: 0, + length: 100, + }), + Message::Unwant(Unwant { + start: 10, + length: 2, + }), + Message::Bitfield(Bitfield { + start: 20, + bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF], + }), + Message::Range(Range { + drop: true, + start: 12345, + length: 100000 + }), + Message::Extension(Extension { + name: "custom_extension/v1/open".to_string(), + message: vec![0x44, 20] + }) + }; + } +} diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index cb71f50..acbdc05 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -13,8 +13,8 @@ use tracing::instrument; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; -use crate::crypto::{EncryptCipher, Handshake, HandshakeResult}; -use crate::message::{ChannelMessage, Frame, Message}; +use crate::crypto::HandshakeResult; +use crate::message::{ChannelMessage, Message}; use crate::mqueue::MessageIo; use crate::util::{map_channel_err, pretty_hash}; use crate::{ @@ -30,7 +30,6 @@ macro_rules! return_error { } const CHANNEL_CAP: usize = 1000; -const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); /// Options for a Protocol instance. #[derive(Debug)] @@ -112,32 +111,9 @@ impl fmt::Debug for Event { } } -/// Protocol state -#[allow(clippy::large_enum_variant)] -pub(crate) enum State { - NotInitialized, - // The Handshake struct sits behind an option only so that we can .take() - // it out, it's never actually empty when in State::Handshake. - Handshake(Option), - SecretStream(Option), - Established, -} - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - State::NotInitialized => write!(f, "NotInitialized"), - State::Handshake(_) => write!(f, "Handshaking"), - State::SecretStream(_) => write!(f, "SecretStream"), - State::Established => write!(f, "Established"), - } - } -} - /// A Protocol stream. pub struct Protocol { io: MessageIo>>, - state: State, options: Options, handshake: Option, channels: ChannelMap, @@ -153,7 +129,6 @@ impl std::fmt::Debug for Protocol { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Protocol") //.field("io", &self.io) - .field("state", &self.state) .field("options", &self.options) .field("handshake", &self.handshake) .field("channels", &self.channels) @@ -182,7 +157,6 @@ where Protocol { io: MessageIo::new(encrypted_framed_message_channel(options.is_initiator, io)), options, - state: State::NotInitialized, channels: ChannelMap::new(), handshake: None, command_rx, @@ -272,6 +246,7 @@ where if let Some(event) = this.queued_events.pop_front() { Poll::Ready(Ok(event)) } else { + cx.waker().wake_by_ref(); Poll::Pending } } @@ -287,6 +262,7 @@ where /// TODO Poll the keepalive timer and queue a ping message if needed. fn poll_keepalive(&mut self, _cx: &mut Context<'_>) { /* + const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); if Pin::new(&mut self.keepalive).poll(cx).is_ready() { if let State::Established = self.state { // 24 bit header for the empty message, hence the 3 @@ -295,7 +271,7 @@ where } self.keepalive.reset(KEEPALIVE_DURATION); } - */ + */ } // just handles Close and LocalSignal?? @@ -322,6 +298,7 @@ where } /// Poll for inbound messages and processs them. + #[instrument(skip_all)] fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { match self.io.poll_inbound(cx) { @@ -336,6 +313,7 @@ where /// Poll for outbound messages and write them. /// Reads messages from Self::outbound and sends them over io + #[instrument(skip_all)] fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { // if no parking or setup in progress @@ -365,6 +343,7 @@ where Ok(()) } + #[instrument(skip_all)] fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { // let channel_message = ChannelMessage::decode(buf)?; let (remote_id, message) = channel_message.into_split(); @@ -387,6 +366,7 @@ where } /// Open a Channel with the given key. Adding it to our channel map + #[instrument(skip_all)] fn command_open(&mut self, key: Key) -> Result<()> { // Create a new channel. let channel_handle = self.channels.attach_local(key); diff --git a/src/protocol/old.rs b/src/protocol/old.rs index 01af713..2c7d4c5 100644 --- a/src/protocol/old.rs +++ b/src/protocol/old.rs @@ -10,7 +10,7 @@ use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tracing::trace; +use tracing::{instrument, trace}; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; @@ -252,6 +252,7 @@ where self.io } + #[instrument(skip_all)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -357,6 +358,7 @@ where } /// Poll for inbound messages and processs them. + #[instrument(skip_all)] fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { let msg = self.read_state.poll_reader(cx, &mut self.io); @@ -371,6 +373,7 @@ where } /// Poll for outbound messages and write them. + #[instrument(skip_all)] fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { @@ -398,6 +401,7 @@ where } } + #[instrument(skip_all)] fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { match frame { Frame::RawBatch(raw_batch) => { @@ -539,6 +543,7 @@ where } /// Open a Channel with the given key. Adding it to our channel map + #[instrument(skip_all)] fn command_open(&mut self, key: Key) -> Result<()> { // Create a new channel. let channel_handle = self.channels.attach_local(key); diff --git a/src/test_utils.rs b/src/test_utils.rs index b12f440..e67d756 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -73,7 +73,7 @@ impl TwoWay { } } -pub fn log() { +pub(crate) fn log() { use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); START_LOGS.get_or_init(|| { diff --git a/src/util.rs b/src/util.rs index 1350728..21e4c75 100644 --- a/src/util.rs +++ b/src/util.rs @@ -29,7 +29,33 @@ pub(crate) fn map_channel_err(err: async_channel::SendError) -> Error { } pub(crate) const UINT_24_LENGTH: usize = 3; +#[cfg(feature = "uint24")] +mod uint24 { + use super::UINT_24_LENGTH; + pub struct Uint24LE([u8; UINT_24_LENGTH]); + impl Uint24LE { + pub const MAX_USIZE: usize = 16777215; + pub const SIZE: usize = UINT_24_LENGTH; + } + + impl AsRef<[u8; 3]> for Uint24LE { + fn as_ref(&self) -> &[u8; 3] { + &self.0 + } + } + // TODO we are using std::io::Error everywhere so I won't add a new one but this isn't ideal + impl TryFrom for Uint24LE { + type Error = Error; + + fn try_from(n: usize) -> Result { + if n > Self::MAX_USIZE { + todo!() + } + Ok(Self([(n & 255) as u8, (n >> 8) as u8, (n >> 16) as u8])) + } + } +} #[inline] pub(crate) fn wrap_uint24_le(data: &[u8]) -> Vec { let mut buf: Vec = vec![0; 3]; diff --git a/tests/_util.rs b/tests/_util.rs index d1fa197..aec496d 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -4,6 +4,7 @@ use futures_lite::StreamExt; use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; use instant::Duration; use std::io; +use tokio::io::DuplexStream; use tokio::task::JoinHandle; #[allow(unused)] @@ -23,7 +24,16 @@ pub(crate) fn log() { }); } +type TokioDuplex = tokio_util::compat::Compat; + +pub(crate) fn duplex(channel_size: usize) -> (TokioDuplex, TokioDuplex) { + use tokio_util::compat::TokioAsyncReadCompatExt as _; + let (left, right) = tokio::io::duplex(channel_size); + (left.compat(), right.compat()) +} + pub type MemoryProtocol = Protocol>; + pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol)> { let (ar, bw) = sluice::pipe::pipe(); let (br, aw) = sluice::pipe::pipe(); @@ -35,6 +45,15 @@ pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol) Ok((a, b)) } +pub async fn create_pair_memory2() -> io::Result<(Protocol, Protocol)> { + let (left, right) = duplex(1024 * 1024); + let a = ProtocolBuilder::new(true); + let b = ProtocolBuilder::new(false); + let a = a.connect(left); + let b = b.connect(right); + Ok((a, b)) +} + pub fn next_event(mut proto: Protocol) -> JoinHandle<(Protocol, io::Result)> where IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, diff --git a/tests/basic.rs b/tests/basic.rs index 92e4c8b..a102bc0 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -1,5 +1,6 @@ use _util::{ - create_pair_memory, drive_until_channel, event_channel, event_discovery_key, next_event, + create_pair_memory, create_pair_memory2, drive_until_channel, event_channel, + event_discovery_key, next_event, }; use futures_lite::StreamExt; use hypercore_protocol::{discovery_key, Event, Message}; @@ -170,7 +171,6 @@ async fn open_close_channels() -> anyhow::Result<()> { assert_eq!(msg_b, Some(want(0, 10))); eprintln!("all good!"); - Ok(()) } From fae480439eabab3d849946b8d24eeb57fac668de Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 25 Mar 2025 12:31:37 -0400 Subject: [PATCH 053/206] doc comments --- src/reader.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/reader.rs b/src/reader.rs index 5664d56..cc80c5c 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -116,6 +116,7 @@ impl ReadState { // What happens if decrypt fails here? // next call to this func would have same start, corret? // so it'd fail repeatedly? + // Why not just decrypt to the end? for (index, header_len, body_len) in segments { let de = cipher.decrypt( &mut self.buf[self.start + index..end], @@ -144,10 +145,13 @@ impl ReadState { fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) { let (last_index, last_header_len, last_body_len) = last_segment; let total_incoming_length = last_index + last_header_len + last_body_len; + if self.buf.len() < total_incoming_length { // The incoming segments will not fit into the buffer, need to resize it self.buf.resize(total_incoming_length, 0u8); } + + // to-read length let temp = self.buf[self.start..].to_vec(); let len = temp.len(); self.buf[..len].copy_from_slice(&temp[..]); @@ -187,17 +191,21 @@ impl ReadState { } } + // one message within an encrypted frame + // encrypted frame [ u24 header + encoded_frame [ ]] Step::Body { header_len, body_len, } => { let message_len = header_len + body_len; let range = self.start + header_len..self.start + message_len; + // this includes a a frame header let frame = Frame::decode(&self.buf[range], &self.frame_type); self.start += message_len; self.step = Step::Header; return Some(frame); } + // multiple message within an encrypted frame Step::Batch => { let frame = Frame::decode_multiple(&self.buf[self.start..self.end], &self.frame_type); @@ -211,7 +219,9 @@ impl ReadState { } #[allow(clippy::type_complexity)] -// get segments from buff +/// Given a buff get all the segments (starting_index_in_buffer, header_len, buffer_len) +/// returns returns `(true, segments)` if we read all segments, but (false, ..) if there +/// are remaining segments fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> { let mut index: usize = 0; let len = buf.len(); From b019e9f143e8c084dc3cb33743b3687d99964cc2 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 25 Mar 2025 12:43:51 -0400 Subject: [PATCH 054/206] feature gate oldmessage --- Cargo.toml | 2 ++ src/lib.rs | 3 +-- src/message/mod.rs | 11 +++++++++++ src/{message.rs => message/modern.rs} | 0 src/{oldmessage.rs => message/old.rs} | 2 +- 5 files changed, 15 insertions(+), 3 deletions(-) create mode 100644 src/message/mod.rs rename src/{message.rs => message/modern.rs} (100%) rename src/{oldmessage.rs => message/old.rs} (99%) diff --git a/Cargo.toml b/Cargo.toml index 7c15c8b..7aeb9e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,8 @@ tokio-util = { version = "0.7.14", features = ["compat"] } [features] default = ["tokio", "sparse", "protocol"] +#default = ["tokio", "sparse"] +uint24 = [] protocol = [] wasm-bindgen = [ "futures-timer/wasm-bindgen" diff --git a/src/lib.rs b/src/lib.rs index 999aa26..1990b97 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,10 +124,9 @@ mod crypto; mod duplex; mod framing; mod message; +#[cfg(feature = "protocol")] mod mqueue; mod noise; -#[cfg(not(feature = "protocol"))] -mod oldmessage; mod protocol; #[cfg(not(feature = "protocol"))] mod reader; diff --git a/src/message/mod.rs b/src/message/mod.rs new file mode 100644 index 0000000..1526f3a --- /dev/null +++ b/src/message/mod.rs @@ -0,0 +1,11 @@ +#[cfg(feature = "protocol")] +mod modern; + +#[cfg(feature = "protocol")] +pub use modern::*; + +#[cfg(not(feature = "protocol"))] +mod old; + +#[cfg(not(feature = "protocol"))] +pub use old::*; diff --git a/src/message.rs b/src/message/modern.rs similarity index 100% rename from src/message.rs rename to src/message/modern.rs diff --git a/src/oldmessage.rs b/src/message/old.rs similarity index 99% rename from src/oldmessage.rs rename to src/message/old.rs index 8cb2c61..373eea2 100644 --- a/src/oldmessage.rs +++ b/src/message/old.rs @@ -703,7 +703,7 @@ mod tests { let fres = Frame::decode_multiple(&fbuf, &FrameType::Message)?; assert_eq!(fres, frame); - ///assert_eq!(cres, cmvec); + //assert_eq!(cres, cmvec); //println!("REG frame buf\t{frame_buf:02X?}"); //let res_frame = Frame::decode(&frame_buf, &FrameType::Message)?; //dbg!(res_frame); From 24fe4a1f2fbf2e9373f413c29760a6e3f8471490 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 25 Mar 2025 13:19:07 -0400 Subject: [PATCH 055/206] fix all warnings --- src/constants.rs | 12 +- src/crypto/cipher.rs | 168 +++++++++++++------------- src/crypto/handshake.rs | 5 +- src/crypto/mod.rs | 4 + src/message/modern.rs | 257 +++++++++++++++++++--------------------- src/mqueue.rs | 13 +- src/protocol/modern.rs | 5 - src/protocol/old.rs | 5 - src/writer.rs | 4 - tests/basic.rs | 14 +-- 10 files changed, 232 insertions(+), 255 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 77285ee..73d0748 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,15 +1,17 @@ /// Seed for the discovery key hash pub(crate) const DISCOVERY_NS_BUF: &[u8] = b"hypercore"; -/// Default timeout (in seconds) -pub(crate) const DEFAULT_TIMEOUT: u32 = 20; - /// Default keepalive interval (in seconds) pub(crate) const DEFAULT_KEEPALIVE: u32 = 10; +/// v10: Protocol name +pub(crate) const PROTOCOL_NAME: &str = "hypercore/alpha"; + // 16,78MB is the max encrypted wire message size (will be much smaller usually). // This limitation stems from the 24bit header. +#[cfg(not(feature = "protocol"))] pub(crate) const MAX_MESSAGE_SIZE: u64 = 0xFFFFFF; -/// v10: Protocol name -pub(crate) const PROTOCOL_NAME: &str = "hypercore/alpha"; +/// Default timeout (in seconds) +#[cfg(not(feature = "protocol"))] +pub(crate) const DEFAULT_TIMEOUT: u32 = 20; diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index f2dc9b9..8ef6c9e 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -1,5 +1,4 @@ use super::HandshakeResult; -use crate::util::{stat_uint24_le, write_uint24_le, UINT_24_LENGTH}; use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, @@ -11,28 +10,17 @@ use std::io; const STREAM_ID_LENGTH: usize = 32; const KEY_LENGTH: usize = 32; -const HEADER_MSG_LEN: usize = UINT_24_LENGTH + STREAM_ID_LENGTH + Header::BYTES; pub(crate) struct DecryptCipher { pull_stream: PullStream, } -pub(crate) struct EncryptCipher { - push_stream: PushStream, -} - impl std::fmt::Debug for DecryptCipher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "DecryptCipher(crypto_secretstream)") } } -impl std::fmt::Debug for EncryptCipher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "EncryptCipher(crypto_secretstream)") - } -} - impl DecryptCipher { pub(crate) fn from_handshake_rx_and_init_msg( handshake_result: &HandshakeResult, @@ -75,25 +63,6 @@ impl DecryptCipher { let pull_stream = PullStream::init(Header::from(header), &key); Ok(Self { pull_stream }) } - - pub(crate) fn decrypt( - &mut self, - buf: &mut [u8], - header_len: usize, - body_len: usize, - ) -> io::Result { - let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; - let decrypted_len = to_decrypt.len(); - write_uint24_le(decrypted_len, buf); - let decrypted_end = header_len + to_decrypt.len(); - buf[header_len..decrypted_end].copy_from_slice(to_decrypt.as_slice()); - // Set extra bytes in the buffer to 0 - // Why? - let encrypted_end = header_len + body_len; - buf[decrypted_end..encrypted_end].fill(0x00); - Ok(decrypted_end) - } - pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> io::Result<(Vec, Tag)> { let mut to_decrypt = buf.to_vec(); let tag = &self.pull_stream.pull(&mut to_decrypt, &[]).map_err(|err| { @@ -103,63 +72,102 @@ impl DecryptCipher { } } -impl EncryptCipher { - pub(crate) fn from_handshake_tx( - handshake_result: &HandshakeResult, - ) -> std::io::Result<(Self, Vec)> { - let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] - .try_into() - .expect("split_tx with incorrect length"); - let key = Key::from(key); +#[cfg(not(feature = "protocol"))] +mod encrypt_cipher { + use super::*; + use crate::util::{stat_uint24_le, write_uint24_le, UINT_24_LENGTH}; + const HEADER_MSG_LEN: usize = UINT_24_LENGTH + STREAM_ID_LENGTH + Header::BYTES; - let mut header_message: [u8; HEADER_MSG_LEN] = [0; HEADER_MSG_LEN]; - write_uint24_le(STREAM_ID_LENGTH + Header::BYTES, &mut header_message); - write_stream_id( - &handshake_result.handshake_hash, - handshake_result.is_initiator, - &mut header_message[UINT_24_LENGTH..UINT_24_LENGTH + STREAM_ID_LENGTH], - ); + pub(crate) struct EncryptCipher { + push_stream: PushStream, + } - let (header, push_stream) = PushStream::init(OsRng, &key); - let header = header.as_ref(); - header_message[UINT_24_LENGTH + STREAM_ID_LENGTH..].copy_from_slice(header); - let msg = header_message.to_vec(); - Ok((Self { push_stream }, msg)) + impl std::fmt::Debug for EncryptCipher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "EncryptCipher(crypto_secretstream)") + } } - /// Get the length needed for encryption, that includes padding. - pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { - // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe - // extra room. - // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ - plaintext_len + 2 * 15 + impl EncryptCipher { + pub(crate) fn from_handshake_tx( + handshake_result: &HandshakeResult, + ) -> std::io::Result<(Self, Vec)> { + let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] + .try_into() + .expect("split_tx with incorrect length"); + let key = Key::from(key); + + let mut header_message: [u8; HEADER_MSG_LEN] = [0; HEADER_MSG_LEN]; + write_uint24_le(STREAM_ID_LENGTH + Header::BYTES, &mut header_message); + write_stream_id( + &handshake_result.handshake_hash, + handshake_result.is_initiator, + &mut header_message[UINT_24_LENGTH..UINT_24_LENGTH + STREAM_ID_LENGTH], + ); + + let (header, push_stream) = PushStream::init(OsRng, &key); + let header = header.as_ref(); + header_message[UINT_24_LENGTH + STREAM_ID_LENGTH..].copy_from_slice(header); + let msg = header_message.to_vec(); + Ok((Self { push_stream }, msg)) + } + + /// Get the length needed for encryption, that includes padding. + pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { + // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe + // extra room. + // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ + plaintext_len + 2 * 15 + } + + /// Encrypts message in the given buffer to the same buffer, returns number of bytes + /// of total message. + /// NB: we expect the first 3 bytes of the buffer to a size header. + /// The encrypted buffer will also be written prepended with a size header, with it's new size. + pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { + let stat = stat_uint24_le(buf); + if let Some((header_len, body_len)) = stat { + let mut to_encrypt = buf[header_len..header_len + body_len as usize].to_vec(); + self.push_stream + .push(&mut to_encrypt, &[], Tag::Message) + .map_err(|err| { + io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) + })?; + let encrypted_len = to_encrypt.len(); + write_uint24_le(encrypted_len, buf); + buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); + Ok(header_len + encrypted_len) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("Could not encrypt invalid data, len: {}", buf.len()), + )) + } + } } - /// Encrypts message in the given buffer to the same buffer, returns number of bytes - /// of total message. - /// NB: we expect the first 3 bytes of the buffer to a size header. - /// The encrypted buffer will also be written prepended with a size header, with it's new size. - pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { - let stat = stat_uint24_le(buf); - if let Some((header_len, body_len)) = stat { - let mut to_encrypt = buf[header_len..header_len + body_len as usize].to_vec(); - self.push_stream - .push(&mut to_encrypt, &[], Tag::Message) - .map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) - })?; - let encrypted_len = to_encrypt.len(); - write_uint24_le(encrypted_len, buf); - buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); - Ok(header_len + encrypted_len) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Could not encrypt invalid data, len: {}", buf.len()), - )) + impl DecryptCipher { + pub(crate) fn decrypt( + &mut self, + buf: &mut [u8], + header_len: usize, + body_len: usize, + ) -> io::Result { + let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; + let decrypted_len = to_decrypt.len(); + write_uint24_le(decrypted_len, buf); + let decrypted_end = header_len + to_decrypt.len(); + buf[header_len..decrypted_end].copy_from_slice(to_decrypt.as_slice()); + // Set extra bytes in the buffer to 0 + // Why? + let encrypted_end = header_len + body_len; + buf[decrypted_end..encrypted_end].fill(0x00); + Ok(decrypted_end) } } } +#[cfg(not(feature = "protocol"))] +pub use encrypt_cipher::*; // NB: These values come from Javascript-side // @@ -197,7 +205,7 @@ pub(crate) struct RawEncryptCipher { impl std::fmt::Debug for RawEncryptCipher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "EncryptCipher(crypto_secretstream)") + write!(f, "RawEncryptCipher(crypto_secretstream)") } } diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 74a1ada..fc5ad02 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -1,5 +1,4 @@ use super::curve::CurveResolver; -use crate::util::wrap_uint24_le; use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, @@ -111,8 +110,9 @@ impl Handshake { Ok(None) } } + #[cfg(not(feature = "protocol"))] pub(crate) fn start(&mut self) -> Result>> { - Ok(self.start_raw()?.map(|x| wrap_uint24_le(&x))) + Ok(self.start_raw()?.map(|x| crate::util::wrap_uint24_le(&x))) } pub(crate) fn complete(&self) -> bool { @@ -177,6 +177,7 @@ impl Handshake { Ok(tx_buf) } // reads in `msg` without framing bytes, but emits msg WITH framing bytes + #[cfg(not(feature = "protocol"))] pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { Ok(self.read_raw(msg)?.map(|x| wrap_uint24_le(&x))) } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 27f12b4..3de592a 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,5 +1,9 @@ mod cipher; mod curve; mod handshake; +#[cfg(not(feature = "protocol"))] pub(crate) use cipher::{DecryptCipher, EncryptCipher, RawEncryptCipher}; + +#[cfg(feature = "protocol")] +pub(crate) use cipher::{DecryptCipher, RawEncryptCipher}; pub(crate) use handshake::{Handshake, HandshakeResult}; diff --git a/src/message/modern.rs b/src/message/modern.rs index 6d5200c..8b16988 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -10,12 +10,6 @@ use tracing::instrument; const UINT24_HEADER_LEN: usize = 3; -/// The type of a data frame. -#[derive(Debug, Clone, PartialEq)] -pub(crate) enum FrameType { - Message, -} - /// Encode data into a buffer. /// /// This trait is implemented on data frames and their components @@ -200,136 +194,6 @@ pub(crate) fn decode_one_channel_message( impl Frame { /// Decodes a frame from a buffer containing multiple concurrent messages. - pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Message => { - let mut index = 0; - let mut combined_messages: Vec = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - let (frame, length) = Self::decode_message( - &buf[index + header_len..index + header_len + body_len as usize], - )?; - if length != body_len as usize { - tracing::warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - body_len, - length - ); - } - if let Frame::MessageBatch(messages) = frame { - for message in messages { - combined_messages.push(message); - } - } else { - unreachable!("Can not get Raw messages"); - } - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in multi-message chunk", - )); - } - } - Ok(Frame::MessageBatch(combined_messages)) - } - } - } - - /// Decode a frame from a buffer. - pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Message => { - let (frame, _) = Self::decode_message(buf)?; - Ok(frame) - } - } - } - - fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - println!("decode_message {buf:02X?}"); - // buffer length >= 3 or more and starts with 0 is message batch - if buf.len() >= 3 && buf[0] == 0x00 { - if buf[1] == 0x00 { - // Batch of messages - let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); - - // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { - // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() - ), - )); - } - // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; - messages.push(channel_message); - state.add_start(channel_message_length)?; - // After that, if there is an extra 0x00, that means the channel - // changed. This works because of LE encoding, and channels starting - // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; - } - } - Ok((Frame::MessageBatch(messages), state.start())) - } else if buf[1] == 0x01 { - // Open message - let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else if buf[1] == 0x03 { - // Close message - let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid special message", - )) - } - } else if buf.len() >= 2 { - // len >= and - // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok(( - Frame::MessageBatch(vec![channel_message]), - state.start() + length, - )) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), - )) - } - } - fn preencode(&self, state: &mut State) -> Result { match self { Self::RawBatch(raw_batch) => { @@ -989,6 +853,125 @@ mod tests { }), ] } + impl Frame { + pub(crate) fn decode_multiple(buf: &[u8]) -> Result { + let mut index = 0; + let mut combined_messages: Vec = vec![]; + while index < buf.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buf[index] == 0 { + index += 1; + continue; + } + + let stat = stat_uint24_le(&buf[index..]); + if let Some((header_len, body_len)) = stat { + let (frame, length) = Self::decode_message( + &buf[index + header_len..index + header_len + body_len as usize], + )?; + if length != body_len as usize { + tracing::warn!( + "Did not know what to do with all the bytes, got {} but decoded {}. \ + This may be because the peer implements a newer protocol version \ + that has extra fields.", + body_len, + length + ); + } + if let Frame::MessageBatch(messages) = frame { + for message in messages { + combined_messages.push(message); + } + } else { + unreachable!("Can not get Raw messages"); + } + index += header_len + body_len as usize; + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid data in multi-message chunk", + )); + } + } + Ok(Frame::MessageBatch(combined_messages)) + } + + fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { + println!("decode_message {buf:02X?}"); + // buffer length >= 3 or more and starts with 0 is message batch + if buf.len() >= 3 && buf[0] == 0x00 { + if buf[1] == 0x00 { + // Batch of messages + let mut messages: Vec = vec![]; + let mut state = State::new_with_start_and_end(2, buf.len()); + + // First, there is the original channel + let mut current_channel: u64 = state.decode(buf)?; + while state.start() < state.end() { + // Length of the message is inbetween here + let channel_message_length: usize = state.decode(buf)?; + if state.start() + channel_message_length > state.end() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "received invalid message length, {} + {} > {}", + state.start(), + channel_message_length, + state.end() + ), + )); + } + // Then the actual message + let (channel_message, _) = ChannelMessage::decode( + &buf[state.start()..state.start() + channel_message_length], + current_channel, + )?; + messages.push(channel_message); + state.add_start(channel_message_length)?; + // After that, if there is an extra 0x00, that means the channel + // changed. This works because of LE encoding, and channels starting + // from the index 1. + if state.start() < state.end() && buf[state.start()] == 0x00 { + state.add_start(1)?; + current_channel = state.decode(buf)?; + } + } + Ok((Frame::MessageBatch(messages), state.start())) + } else if buf[1] == 0x01 { + // Open message + let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else if buf[1] == 0x03 { + // Close message + let (channel_message, length) = + ChannelMessage::decode_close_message(&buf[2..])?; + Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "received invalid special message", + )) + } + } else if buf.len() >= 2 { + // len >= and + // Single message + let mut state = State::from_buffer(buf); + let channel: u64 = state.decode(buf)?; + let (channel_message, length) = + ChannelMessage::decode(&buf[state.start()..], channel)?; + Ok(( + Frame::MessageBatch(vec![channel_message]), + state.start() + length, + )) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received too short message, {buf:02X?}"), + )) + } + } + } #[test] fn compare_with_frame_encoding_decoding() -> std::io::Result<()> { @@ -1008,7 +991,7 @@ mod tests { assert_eq!(cbuf, fbuf); - let fres = Frame::decode_multiple(&fbuf, &FrameType::Message)?; + let fres = Frame::decode_multiple(&fbuf)?; assert_eq!(fres, frame); let cres_m = decode_many_channel_messages(&cbuf)?.0; assert_eq!(cres_m, cmvec); diff --git a/src/mqueue.rs b/src/mqueue.rs index 39b4c9b..c802d34 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -7,21 +7,16 @@ use std::{ task::{Context, Poll}, }; -use futures::{AsyncRead, AsyncWrite, Sink, Stream}; -use tracing::{debug, error, info, instrument, trace}; +use futures::{Sink, Stream}; +use tracing::{debug, error, instrument, trace}; -use crate::{ - encrypted_framed_message_channel, - message::{decode_many_channel_messages, ChannelMessage, Encoder as _}, -}; +use crate::message::{decode_many_channel_messages, ChannelMessage, Encoder as _}; pub(crate) struct MessageIo { io: IO, write_queue: VecDeque, } -use crate::{framing::Uint24LELengthPrefixedFraming, noise::Encrypted}; - impl>> + Sink> + Send + Unpin + 'static> MessageIo { pub(crate) fn new(io: IO) -> Self { Self { @@ -68,7 +63,7 @@ impl>> + Sink> + Send + Unpin + 'static } Poll::Ready(Err(_e)) => { error!("Error flushing"); - return todo!(); + todo!() } Poll::Pending => { cx.waker().wake_by_ref(); diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index acbdc05..f9bfa80 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -216,11 +216,6 @@ where self.channels.iter().map(|c| c.discovery_key()) } - /// Stop the protocol and return the inner reader and writer. - pub fn release(self) -> MessageIo>> { - self.io - } - #[instrument(skip_all)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); diff --git a/src/protocol/old.rs b/src/protocol/old.rs index 2c7d4c5..b5f44ec 100644 --- a/src/protocol/old.rs +++ b/src/protocol/old.rs @@ -247,11 +247,6 @@ where self.channels.iter().map(|c| c.discovery_key()) } - /// Stop the protocol and return the inner reader and writer. - pub fn release(self) -> IO { - self.io - } - #[instrument(skip_all)] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); diff --git a/src/writer.rs b/src/writer.rs index d91adfb..df89949 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -9,10 +9,6 @@ use std::pin::Pin; use std::task::{Context, Poll}; const BUF_SIZE: usize = 1024 * 64; -// This is the largest size that will fit in u24. -// a message is larger than this we should error. -// also check message is smaller than this when we are encrypting. -const _MAX_MSG_SIZE: usize = 2usize.pow(24) - 1; #[derive(Debug)] pub(crate) enum Step { diff --git a/tests/basic.rs b/tests/basic.rs index a102bc0..5730dbc 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -12,22 +12,20 @@ mod _util; #[tokio::test] async fn basic_protocol() -> anyhow::Result<()> { - let (proto_a, proto_b) = create_pair_memory().await?; + _util::log(); + let (proto_a, proto_b) = create_pair_memory2().await?; - dbg!(); let next_a = next_event(proto_a); - dbg!(); let next_b = next_event(proto_b); - dbg!(); - let (proto_b, event_b) = next_b.await?; - dbg!(); let (mut proto_a, event_a) = next_a.await?; + let (proto_b, event_b) = next_b.await?; + //let (a, b) = join(next_a, next_b).await; - dbg!(); //let (mut proto_a, event_a) = a?; - dbg!(); //let (proto_b, event_b) = b?; + dbg!(&event_a); + dbg!(&event_b); assert!(matches!(event_a, Ok(Event::Handshake(_)))); assert!(matches!(event_b, Ok(Event::Handshake(_)))); From b23bead8bdb5aaa974724b5513bce68afcfd2805 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 1 Apr 2025 16:24:05 -0400 Subject: [PATCH 056/206] Add tracing::instrument --- src/channels.rs | 13 +++++++------ src/crypto/handshake.rs | 4 +++- src/message/modern.rs | 2 +- src/noise.rs | 3 ++- src/protocol/modern.rs | 8 +++++++- src/protocol/old.rs | 12 ++++++++++-- src/writer.rs | 2 ++ 7 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/channels.rs b/src/channels.rs index c2e22f8..8e82116 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -13,7 +13,7 @@ use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::task::Poll; -use tracing::debug; +use tracing::instrument; /// A protocol channel. /// @@ -93,7 +93,6 @@ impl Channel { "Channel is closed", )); } - debug!("TX:\n{message:?}\n"); let message = ChannelMessage::new(self.local_id as u64, message); self.outbound_tx .send(vec![message]) @@ -122,10 +121,7 @@ impl Channel { let messages = messages .iter() - .map(|message| { - debug!("TX:\n{message:?}\n"); - ChannelMessage::new(self.local_id as u64, message.clone()) - }) + .map(|message| ChannelMessage::new(self.local_id as u64, message.clone())) .collect(); self.outbound_tx .send(messages) @@ -249,6 +245,7 @@ impl ChannelHandle { self.remote_state.as_ref().map(|s| s.remote_id) } + #[instrument(skip_all, fields(local_id = local_id))] pub(crate) fn attach_local(&mut self, local_id: usize, key: Key) { let local_state = LocalState { local_id, key }; self.local_state = Some(local_state); @@ -271,11 +268,13 @@ impl ChannelHandle { return Err(error("Channel is not opened from both local and remote")); } // Safe because of the is_connected() check above. + dbg!(&self.local_state, &self.remote_state); let local_state = self.local_state.as_ref().unwrap(); let remote_state = self.remote_state.as_ref().unwrap(); Ok((&local_state.key, remote_state.remote_capability.as_ref())) } + #[instrument(skip_all)] pub(crate) fn open(&mut self, outbound_tx: Sender>) -> Channel { let local_state = self .local_state @@ -433,6 +432,7 @@ impl ChannelMap { self.channels.remove(&hdkey); } + #[instrument(skip(self))] pub(crate) fn prepare_to_verify(&self, local_id: usize) -> Result<(&Key, Option<&Vec>)> { let channel_handle = self .get_local(local_id) @@ -477,6 +477,7 @@ impl ChannelMap { Ok(()) } + #[instrument(skip_all)] fn alloc_local(&mut self) -> usize { let empty_id = self .local_id diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index fc5ad02..8659f09 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -33,6 +33,7 @@ pub(crate) struct HandshakeResult { } impl HandshakeResult { + #[instrument(skip_all)] pub(crate) fn capability(&self, key: &[u8]) -> Option> { Some(replicate_capability( self.is_initiator, @@ -49,6 +50,7 @@ impl HandshakeResult { )) } + #[instrument(skip_all)] pub(crate) fn verify_remote_capability( &self, capability: Option>, @@ -179,7 +181,7 @@ impl Handshake { // reads in `msg` without framing bytes, but emits msg WITH framing bytes #[cfg(not(feature = "protocol"))] pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { - Ok(self.read_raw(msg)?.map(|x| wrap_uint24_le(&x))) + Ok(self.read_raw(msg)?.map(|x| crate::util::wrap_uint24_le(&x))) } pub(crate) fn into_result(&self) -> Result<&HandshakeResult> { diff --git a/src/message/modern.rs b/src/message/modern.rs index 8b16988..9d7346d 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -353,7 +353,7 @@ impl Encoder for Vec { Ok(prencode_channel_messages(self, &mut state)? + UINT24_HEADER_LEN) } - #[instrument] + #[instrument(skip_all)] fn encode(&self, buf: &mut [u8]) -> Result { let mut state = State::new(); let body_len = prencode_channel_messages(self, &mut state)?; diff --git a/src/noise.rs b/src/noise.rs index d2a3a1d..c1fa5d6 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -148,7 +148,7 @@ impl>> + Sink> + Send + Unpin + 'static { type Item = Result>; - #[instrument(skip_all, fields(initiator = %self.is_initiator))] + #[instrument(skip(cx), fields(initiator = %self.is_initiator))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -357,6 +357,7 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { Ok(out) } +#[instrument(skip_all)] fn reset_encrypted( step: &mut Step, maybe_init_message: Option>, diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index f9bfa80..6098ff7 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -216,7 +216,7 @@ where self.channels.iter().map(|c| c.discovery_key()) } - #[instrument(skip_all)] + #[instrument(skip_all, fields(initiator = ?self.is_initiator()))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -352,6 +352,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_command(&mut self, command: Command) -> Result<()> { match command { Command::Open(key) => self.command_open(key), @@ -402,6 +403,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> { let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?; let channel_handle = @@ -418,10 +420,12 @@ where Ok(()) } + #[instrument(skip(self))] fn queue_event(&mut self, event: Event) { self.queued_events.push_back(event); } + #[instrument(skip(self))] fn accept_channel(&mut self, local_id: usize) -> Result<()> { let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; self.verify_remote_capability(remote_capability.cloned(), key)?; @@ -451,6 +455,7 @@ where Ok(()) } + #[instrument(skip_all)] fn capability(&self, key: &[u8]) -> Option> { match self.handshake.as_ref() { Some(handshake) => handshake.capability(key), @@ -458,6 +463,7 @@ where } } + #[instrument(skip_all)] fn verify_remote_capability(&self, capability: Option>, key: &[u8]) -> Result<()> { match self.handshake.as_ref() { Some(handshake) => handshake.verify_remote_capability(capability, key), diff --git a/src/protocol/old.rs b/src/protocol/old.rs index b5f44ec..20c9064 100644 --- a/src/protocol/old.rs +++ b/src/protocol/old.rs @@ -230,6 +230,7 @@ where } /// Give a command to the protocol. + #[instrument(skip(self))] pub async fn command(&mut self, command: Command) -> Result<()> { self.command_tx.send(command).await } @@ -238,6 +239,7 @@ where /// /// Once the other side proofed that it also knows the `key`, the channel is emitted as /// `Event::Channel` on the protocol event stream. + #[instrument(skip(self))] pub async fn open(&mut self, key: Key) -> Result<()> { self.command_tx.open(key).await } @@ -247,7 +249,7 @@ where self.channels.iter().map(|c| c.discovery_key()) } - #[instrument(skip_all)] + #[instrument(skip_all, fields(initiator = ?self.is_initiator()))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); @@ -311,6 +313,7 @@ where } /// Poll commands. + #[instrument(skip_all)] fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { self.on_command(command)?; @@ -515,7 +518,7 @@ where self.state = State::Established; Ok(()) } - + #[instrument(skip_all)] fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { // let channel_message = ChannelMessage::decode(buf)?; let (remote_id, message) = channel_message.into_split(); @@ -529,6 +532,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_command(&mut self, command: Command) -> Result<()> { match command { Command::Open(key) => self.command_open(key), @@ -580,6 +584,7 @@ where Ok(()) } + #[instrument(skip(self))] fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> { let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?; let channel_handle = @@ -596,6 +601,7 @@ where Ok(()) } + #[instrument(skip(self))] fn queue_event(&mut self, event: Event) { self.queued_events.push_back(event); } @@ -607,6 +613,7 @@ where .try_encode_and_enqueue_frame_for_tx(&mut frame) } + #[instrument(skip(self))] fn accept_channel(&mut self, local_id: usize) -> Result<()> { let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; self.verify_remote_capability(remote_capability.cloned(), key)?; @@ -636,6 +643,7 @@ where Ok(()) } + #[instrument(skip_all)] fn capability(&self, key: &[u8]) -> Option> { match self.handshake.as_ref() { Some(handshake) => handshake.capability(key), diff --git a/src/writer.rs b/src/writer.rs index df89949..9a1465b 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -1,5 +1,6 @@ use crate::crypto::EncryptCipher; use crate::message::{Encoder, Frame}; +use tracing::instrument; use futures_lite::{ready, AsyncWrite}; use std::collections::VecDeque; @@ -61,6 +62,7 @@ impl WriteState { self.queue.push_back(frame.into()) } + #[instrument(skip(self))] pub(crate) fn try_encode_and_enqueue_frame_for_tx( &mut self, frame: &mut T, From 16029ee471e2c31315a170d568ce76f1b12fdfcd Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 1 Apr 2025 16:34:30 -0400 Subject: [PATCH 057/206] rm dbg --- src/schema.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index bf35416..08e5221 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -18,9 +18,10 @@ pub struct Open { impl CompactEncoding for State { fn preencode(&mut self, value: &Open) -> Result { - dbg!(self.preencode(&value.channel)?); - dbg!(self.preencode(&value.protocol)?); - dbg!(self.preencode(&value.discovery_key)?); + let start = self.end(); + self.preencode(&value.channel)?; + self.preencode(&value.protocol)?; + self.preencode(&value.discovery_key)?; if value.capability.is_some() { self.add_end(1)?; // flags for future use self.preencode_fixed_32()?; @@ -29,7 +30,6 @@ impl CompactEncoding for State { } fn encode(&mut self, value: &Open, buffer: &mut [u8]) -> Result { - dbg!(); self.encode(&value.channel, buffer)?; self.encode(&value.protocol, buffer)?; self.encode(&value.discovery_key, buffer)?; @@ -370,7 +370,7 @@ pub struct NoData { impl CompactEncoding for State { fn preencode(&mut self, value: &NoData) -> Result { - dbg!(self.preencode(dbg!(&value.request))) + self.preencode(&value.request) } fn encode(&mut self, value: &NoData, buffer: &mut [u8]) -> Result { From 5bf98ed4e524c95be68b1cdf87071eda52d958bb Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:44:24 -0400 Subject: [PATCH 058/206] rm unused --- src/channels.rs | 1 - src/crypto/handshake.rs | 2 +- src/framing.rs | 2 +- src/message/modern.rs | 6 +----- src/mqueue.rs | 1 - src/schema.rs | 1 - 6 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/channels.rs b/src/channels.rs index 8e82116..1b94ece 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -268,7 +268,6 @@ impl ChannelHandle { return Err(error("Channel is not opened from both local and remote")); } // Safe because of the is_connected() check above. - dbg!(&self.local_state, &self.remote_state); let local_state = self.local_state.as_ref().unwrap(); let remote_state = self.remote_state.as_ref().unwrap(); Ok((&local_state.key, remote_state.remote_capability.as_ref())) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 8659f09..0094edb 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -237,7 +237,7 @@ fn map_err(e: SnowError) -> Error { } /// Create a hash used to indicate replication capability. -/// See https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L11 +/// See JavaScript [here](https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L11). fn replicate_capability(is_initiator: bool, key: &[u8], handshake_hash: &[u8]) -> Vec { let seed = if is_initiator { REPLICATE_INITIATOR diff --git a/src/framing.rs b/src/framing.rs index 8b8ae8f..a51cea9 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -40,7 +40,7 @@ impl Uint24LELengthPrefixedFraming where IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, { - /// Build [`LengthPrefixed`] around an [`AsyncWrite`]/[`AsyncRead`] thing. + /// Build [`Uint24LELengthPrefixedFraming`] around an [`AsyncWrite`]/[`AsyncRead`] thing. pub fn new(io: IO) -> Self { Self { io, diff --git a/src/message/modern.rs b/src/message/modern.rs index 9d7346d..11c1491 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -125,7 +125,6 @@ pub(crate) fn decode_one_channel_message( if buf.len() >= 3 && buf[0] == 0x00 { if buf[1] == 0x00 { // Batch of messages - dbg!(); let mut messages: Vec = vec![]; let mut state = State::new_with_start_and_end(2, buf.len()); @@ -162,12 +161,10 @@ pub(crate) fn decode_one_channel_message( } Ok((messages, state.start())) } else if buf[1] == 0x01 { - dbg!(); // Open message let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; Ok((vec![channel_message], length + 2)) } else if buf[1] == 0x03 { - dbg!(); // Close message let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; Ok((vec![channel_message], length + 2)) @@ -178,7 +175,6 @@ pub(crate) fn decode_one_channel_message( )) } } else if buf.len() >= 2 { - dbg!(); // Single message let mut state = State::from_buffer(buf); let channel: u64 = state.decode(buf)?; @@ -316,7 +312,7 @@ fn prencode_channel_messages( std::cmp::Ordering::Equal => { if let Message::Open(_) = &messages[0].message { // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + dbg!(&messages[0].encoded_len()?))?; + state.add_end(2 + &messages[0].encoded_len()?)?; } else if let Message::Close(_) = &messages[0].message { // This is a special case with 0x00, 0x03 intro bytes state.add_end(2 + &messages[0].encoded_len()?)?; diff --git a/src/mqueue.rs b/src/mqueue.rs index c802d34..ea99f42 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -44,7 +44,6 @@ impl>> + Sink> + Send + Unpin + 'static } let mut buf = vec![0; messages.encoded_len()?]; - dbg!(&buf); match messages.encode(&mut buf) { Ok(_) => {} Err(e) => { diff --git a/src/schema.rs b/src/schema.rs index 08e5221..ef58e77 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -18,7 +18,6 @@ pub struct Open { impl CompactEncoding for State { fn preencode(&mut self, value: &Open) -> Result { - let start = self.end(); self.preencode(&value.channel)?; self.preencode(&value.protocol)?; self.preencode(&value.discovery_key)?; From d3bfd07b89d28c56940a11f4d2fc165f39fe8e85 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:50:24 -0400 Subject: [PATCH 059/206] rm unused --- src/crypto/cipher.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 8ef6c9e..53c291f 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -167,7 +167,7 @@ mod encrypt_cipher { } } #[cfg(not(feature = "protocol"))] -pub use encrypt_cipher::*; +pub(crate) use encrypt_cipher::*; // NB: These values come from Javascript-side // From 094caa83c6b5ae26dbe2116e6d36d43b0a6bf07d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:51:09 -0400 Subject: [PATCH 060/206] pub HandshakeResult --- src/crypto/handshake.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 0094edb..72c9da3 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -23,7 +23,7 @@ const REPLICATE_RESPONDER: [u8; 32] = [ ]; #[derive(Debug, Clone, Default)] -pub(crate) struct HandshakeResult { +pub struct HandshakeResult { pub(crate) is_initiator: bool, pub(crate) local_pubkey: Vec, pub(crate) remote_pubkey: Vec, From 7a8b2dadb99529e64fbe5fcf9ea8d7bf4252e5c5 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:52:19 -0400 Subject: [PATCH 061/206] custom debug for Framig --- src/framing.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/framing.rs b/src/framing.rs index a51cea9..7daef38 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -33,7 +33,14 @@ pub struct Uint24LELengthPrefixedFraming { } impl Debug for Uint24LELengthPrefixedFraming { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Format()") + f.debug_struct("Framer") + //.field("io", &self.io) + .field("to_stream.len()", &self.to_stream.len()) + .field("from_sink", &self.from_sink.len()) + .field("last_out_idx", &self.last_out_idx) + .field("last_data_idx", &self.last_data_idx) + .field("step", &self.step) + .finish() } } impl Uint24LELengthPrefixedFraming From a7fac6deab1fab5378e195dc48bb9d72ad2888f7 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:52:47 -0400 Subject: [PATCH 062/206] feature gates --- src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 1990b97..c13ccae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -142,8 +142,9 @@ pub mod schema; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; pub use framing::Uint24LELengthPrefixedFraming; -pub use noise::{encrypted_framed_message_channel, Encrypted}; +pub use noise::{encrypted_framed_message_channel, Encrypted, Event as NoiseEvent}; // Export the needed types for Channel::take_receiver, and Channel::local_sender() +#[cfg(feature = "protocol")] pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, }; From a189b047f6d7bb6a0b43e9becab7153cf1fe7990 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:54:35 -0400 Subject: [PATCH 063/206] rm println --- src/message/old.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/message/old.rs b/src/message/old.rs index 373eea2..d4afd64 100644 --- a/src/message/old.rs +++ b/src/message/old.rs @@ -163,7 +163,6 @@ impl Frame { } fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - println!("decode_message {buf:02X?}"); // buffer length >= 3 or more and starts with 0 is message batch if buf.len() >= 3 && buf[0] == 0x00 { if buf[1] == 0x00 { From 9db631bf5a25171908e3869a1f03aad1621ea582 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 11:56:24 -0400 Subject: [PATCH 064/206] rm print --- src/message/modern.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index 11c1491..2d8e732 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -894,7 +894,6 @@ mod tests { } fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - println!("decode_message {buf:02X?}"); // buffer length >= 3 or more and starts with 0 is message batch if buf.len() >= 3 && buf[0] == 0x00 { if buf[1] == 0x00 { From a7325077854fb1319000c981f844bccc9908c580 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 12:18:18 -0400 Subject: [PATCH 065/206] refactor mqueue to pass through non byte messages --- src/mqueue.rs | 126 ++++++++++++++++++++++++++++---------------------- 1 file changed, 72 insertions(+), 54 deletions(-) diff --git a/src/mqueue.rs b/src/mqueue.rs index ea99f42..7e1fffd 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -2,22 +2,60 @@ use std::{ collections::VecDeque, + fmt::Debug, io::Result, pin::Pin, task::{Context, Poll}, }; use futures::{Sink, Stream}; -use tracing::{debug, error, instrument, trace}; +use tracing::{error, instrument}; -use crate::message::{decode_many_channel_messages, ChannelMessage, Encoder as _}; +use crate::{ + message::{decode_many_channel_messages, ChannelMessage, Encoder as _}, + noise::EncryptionInfo, + NoiseEvent, +}; + +#[derive(Debug)] +pub(crate) enum MqueueEvent { + Meta(EncryptionInfo), + Message(Result>), +} + +impl From for MqueueEvent { + fn from(e: NoiseEvent) -> Self { + match e { + NoiseEvent::Meta(einf) => Self::Meta(einf), + NoiseEvent::Decrypted(dec_res) => { + match dec_res { + Ok(encoded) => match decode_many_channel_messages(&encoded) { + //assert_eq!(_n_read, encoded.len()); } + Ok((messsages, _n_read)) => Self::Message(Ok(messsages)), + Err(e) => Self::Message(Err(e)), + }, + Err(e) => Self::Message(Err(e)), + } + } + } + } +} pub(crate) struct MessageIo { io: IO, write_queue: VecDeque, } -impl>> + Sink> + Send + Unpin + 'static> MessageIo { +impl Debug for MessageIo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MessageIo") + //.field("io", &self.io) + .field("write_queue", &self.write_queue) + .finish() + } +} + +impl + Sink> + Send + Unpin + 'static> MessageIo { pub(crate) fn new(io: IO) -> Self { Self { io, @@ -48,26 +86,24 @@ impl>> + Sink> + Send + Unpin + 'static Ok(_) => {} Err(e) => { error!(error = ?e, "error encoding messages"); + // TODO this would probably be a programming error. + // if so, this sholud just be an unwrap/expect return Poll::Ready(Err(e.into())); } } if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf) { - error!("error in start_send"); todo!() } match Sink::poll_flush(Pin::new(&mut self.io), cx) { - Poll::Ready(Ok(())) => { - debug!("flushed"); - } Poll::Ready(Err(_e)) => { - error!("Error flushing"); todo!() } Poll::Pending => { cx.waker().wake_by_ref(); return Poll::Pending; } + _ => {} } } @@ -79,46 +115,21 @@ impl>> + Sink> + Send + Unpin + 'static } } - pub(crate) fn poll_inbound( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> { - match Pin::new(&mut self.io).poll_next(cx) { - Poll::Ready(Some(Ok(encoded))) => { - match decode_many_channel_messages(&encoded) { - Ok((messsages, n_read)) => { - assert_eq!(n_read, encoded.len()); // I think this is always true - Poll::Ready(Ok(messsages)) - } - Err(_) => todo!(), - } - } - Poll::Ready(Some(Err(_e))) => todo!(), - Poll::Ready(None) => todo!(), - Poll::Pending => Poll::Pending, - } + pub(crate) fn poll_inbound(&mut self, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.io) + .poll_next(cx) + .map(|opt| opt.map(MqueueEvent::from)) } } -impl>> + Sink> + Send + Unpin + 'static> Stream +impl + Sink> + Send + Unpin + 'static> Stream for MessageIo { - type Item = Result>; + type Item = MqueueEvent; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let out_res = self.poll_outbound(cx); - match out_res { - Poll::Ready(res) => match res { - Ok(okres) => trace!(res = ?okres, "MessageIo poll_outbound"), - Err(e) => error!(error = ?e, "MessageIo error in poll_outbound"), - }, - Poll::Pending => trace!("MessageIo poll_outbound Pending"), - } - - let in_res = self.poll_inbound(cx); - trace!(poll_inbound = ?in_res, "MessageIo"); - - in_res.map(Some) + let _ = self.poll_outbound(cx); + self.poll_inbound(cx) } } @@ -134,7 +145,7 @@ mod test { schema::NoData, test_utils::log, Encrypted, Uint24LELengthPrefixedFraming, }; - use super::MessageIo; + use super::{MessageIo, MqueueEvent}; pub(crate) fn encrypted_and_framed< BytesTxRx: AsyncRead + AsyncWrite + Send + Unpin + 'static, >( @@ -154,6 +165,13 @@ mod test { } } + fn take_messages(e: Option) -> Option> { + match e { + Some(MqueueEvent::Message(Result::Ok(out))) => Some(out), + _ => None, + } + } + #[tokio::test] async fn mqueue() -> Result<()> { log(); @@ -167,19 +185,19 @@ mod test { left.enqueue(ltorm.clone()); right.enqueue(rtolm.clone()); - match select(left.next(), right.next()).await { - futures::future::Either::Left((m, _)) => { - if let Some(Ok(res)) = m { - assert_eq!(res, vec![rtolm]); - } else { - panic!(); + loop { + match select(left.next(), right.next()).await { + futures::future::Either::Left((m, _)) => { + if let Some(m) = take_messages(m) { + assert_eq!(m, vec![rtolm]); + break; + } } - } - futures::future::Either::Right((m, _)) => { - if let Some(Ok(res)) = m { - assert_eq!(res, vec![ltorm]); - } else { - panic!(); + futures::future::Either::Right((m, _)) => { + if let Some(m) = take_messages(m) { + assert_eq!(m, vec![rtolm]); + break; + } } } } From c5888f569dff7d7b08385abd39d7d921c940abb1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 12:29:14 -0400 Subject: [PATCH 066/206] use tracing-tree for viewing logs in tests --- Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 7aeb9e3..4ceb1f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,8 @@ sluice = "0.5.4" futures = "0.3.13" log = "0.4" test-log = { version = "0.2.11", default-features = false, features = ["trace"] } -tracing-subscriber = { version = "0.3.16", features = ["env-filter", "fmt"] } +tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } +tracing-tree = "0.4.0" tokio-util = { version = "0.7.14", features = ["compat"] } [features] From d21d4a9972b80feb3ead499396d6d501a0d4ef22 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 12:31:05 -0400 Subject: [PATCH 067/206] use tracing-tree --- src/test_utils.rs | 26 +++++++++++++++++--------- tests/_util.rs | 26 +++++++++++++++++--------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index e67d756..3f687ea 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -74,17 +74,25 @@ impl TwoWay { } pub(crate) fn log() { - use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); START_LOGS.get_or_init(|| { - tracing_subscriber::fmt() - .with_target(true) - .with_line_number(true) - // print when instrumented funtion enters - .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) - .with_file(true) - .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable - .without_time() + use tracing_subscriber::{ + layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter, + }; + let env_filter = EnvFilter::from_default_env(); // Reads `RUST_LOG` environment variable + + // Create the hierarchical layer from tracing_tree + let tree_layer = tracing_tree::HierarchicalLayer::new(2) // 2 spaces per indent level + .with_targets(true) + .with_bracketed_fields(true) + .with_indent_lines(true) + .with_span_modes(true) + .with_thread_ids(false) + .with_thread_names(false); + + tracing_subscriber::registry() + .with(env_filter) + .with(tree_layer) .init(); }); } diff --git a/tests/_util.rs b/tests/_util.rs index aec496d..e2cc679 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -9,17 +9,25 @@ use tokio::task::JoinHandle; #[allow(unused)] pub(crate) fn log() { - use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); START_LOGS.get_or_init(|| { - tracing_subscriber::fmt() - .with_target(true) - .with_line_number(true) - // print when instrumented funtion enters - .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) - .with_file(true) - .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable - .without_time() + use tracing_subscriber::{ + layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter, + }; + let env_filter = EnvFilter::from_default_env(); // Reads `RUST_LOG` environment variable + + // Create the hierarchical layer from tracing_tree + let tree_layer = tracing_tree::HierarchicalLayer::new(2) // 2 spaces per indent level + .with_targets(true) + .with_bracketed_fields(true) + .with_indent_lines(true) + .with_span_modes(true) + .with_thread_ids(false) + .with_thread_names(false); + + tracing_subscriber::registry() + .with(env_filter) + .with(tree_layer) .init(); }); } From 3062f0781dbc54870f7096026ae1d62a67fea04d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 12:33:59 -0400 Subject: [PATCH 068/206] rm unused async --- tests/_util.rs | 4 ++-- tests/basic.rs | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/_util.rs b/tests/_util.rs index e2cc679..d15be38 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -42,7 +42,7 @@ pub(crate) fn duplex(channel_size: usize) -> (TokioDuplex, TokioDuplex) { pub type MemoryProtocol = Protocol>; -pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol)> { +pub fn create_pair_memory() -> (MemoryProtocol, MemoryProtocol) { let (ar, bw) = sluice::pipe::pipe(); let (br, aw) = sluice::pipe::pipe(); @@ -50,7 +50,7 @@ pub async fn create_pair_memory() -> io::Result<(MemoryProtocol, MemoryProtocol) let b = ProtocolBuilder::new(false); let a = a.connect_rw(ar, aw); let b = b.connect_rw(br, bw); - Ok((a, b)) + (a, b) } pub async fn create_pair_memory2() -> io::Result<(Protocol, Protocol)> { diff --git a/tests/basic.rs b/tests/basic.rs index 5730dbc..280e5be 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -24,8 +24,6 @@ async fn basic_protocol() -> anyhow::Result<()> { //let (mut proto_a, event_a) = a?; //let (proto_b, event_b) = b?; - dbg!(&event_a); - dbg!(&event_b); assert!(matches!(event_a, Ok(Event::Handshake(_)))); assert!(matches!(event_b, Ok(Event::Handshake(_)))); @@ -84,7 +82,7 @@ async fn basic_protocol() -> anyhow::Result<()> { #[tokio::test] async fn open_close_channels() -> anyhow::Result<()> { - let (mut proto_a, mut proto_b) = create_pair_memory().await?; + let (mut proto_a, mut proto_b) = create_pair_memory(); let key1 = [0u8; 32]; let key2 = [1u8; 32]; From 294184057d90ccb626b054c8082797f8d4cb5041 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 12:52:11 -0400 Subject: [PATCH 069/206] expose handshake result, refactor, fix deadlock --- src/noise.rs | 401 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 268 insertions(+), 133 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index c1fa5d6..40ce6ac 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -7,7 +7,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tracing::{debug, error, info, instrument, trace, warn}; +use tracing::{debug, error, instrument, trace, warn}; use crate::{ crypto::{DecryptCipher, Handshake, HandshakeResult, RawEncryptCipher}, @@ -30,6 +30,53 @@ pub(crate) enum Step { SecretStream((RawEncryptCipher, HandshakeResult)), Established((RawEncryptCipher, DecryptCipher, HandshakeResult)), } + +impl Step { + fn established(&self) -> bool { + matches!(self, Step::Established(_)) + } +} + +impl std::fmt::Display for Step { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Step::NotInitialized => "NotInitialized", + Step::Handshake(_) => "Handshake", + Step::SecretStream(_) => "SecretStream", + Step::Established(_) => "Established", + } + ) + } +} + +#[derive(Debug)] +/// Encryption related info +pub enum EncryptionInfo { + Handshake(HandshakeResult), +} +#[derive(Debug)] +/// Decrypted messages and encryption related events +pub enum Event { + /// Events related to the encryption stream + Meta(EncryptionInfo), + /// A decrypted message + Decrypted(Result>), +} + +impl From>> for Event { + fn from(value: Result>) -> Self { + Self::Decrypted(value) + } +} +impl From for Event { + fn from(value: HandshakeResult) -> Self { + Self::Meta(EncryptionInfo::Handshake(value)) + } +} + /// Wrap a stream with encryption pub struct Encrypted { io: IO, @@ -38,7 +85,7 @@ pub struct Encrypted { encrypted_tx: VecDeque>, encrypted_rx: VecDeque>>, plain_tx: VecDeque>, - plain_rx: VecDeque>>, + plain_rx: VecDeque, flush: bool, } @@ -62,7 +109,7 @@ where } /// Wether an encrypted connection has been established. pub fn encryption_established(&self) -> bool { - matches!(self.step, Step::Established(_)) + self.step.established() } } @@ -85,7 +132,7 @@ impl< #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> std::result::Result<(), Self::Error> { - info!(initiator = %self.is_initiator, "enqueue plain_tx\n{item:?}"); + trace!(initiator = %self.is_initiator, "enqueue plain_tx"); self.plain_tx.push_back(item); Ok(()) } @@ -95,6 +142,10 @@ impl< self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { + // The flow here can be understood as reading from the encrypted side moving those messages + // through to the plaintext side, then reading new plaintext messages and moving them to + // the encrypted side. + // We do this repeatedly until there's nothing else to do let Encrypted { io, step, @@ -107,30 +158,37 @@ impl< .. } = self.get_mut(); - poll_encrypted_side_io(io, cx, encrypted_tx, encrypted_rx, *is_initiator, flush); - - if let Step::Established((encryptor, decryptor, ..)) = step { - poll_do_encrypt_and_decrypt( - encryptor, - decryptor, + loop { + poll_message_throughput( + io, + cx, + step, encrypted_tx, encrypted_rx, - plain_tx, plain_rx, + plain_tx, *is_initiator, flush, ); + poll_outgoing_encrypted_messages(io, cx, encrypted_tx, *is_initiator, flush); - if *flush { - cx.waker().wake_by_ref(); - Poll::Pending - } else { - Poll::Ready(Ok(())) + // check if we've done all possible work + if did_as_much_as_possible( + io, + cx, + step, + encrypted_tx, + encrypted_rx, + plain_tx, + *is_initiator, + ) { + if !step.established() || !encrypted_tx.is_empty() || *flush { + trace!(not_established = !step.established(), tx_msgs_waiting = !encrypted_tx.is_empty(), flush = ?flush, "not done flushing"); + cx.waker().wake_by_ref(); + return Poll::Pending; + } + return Poll::Ready(Ok(())); } - } else { - poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator, flush); - cx.waker().wake_by_ref(); - Poll::Pending } } @@ -143,10 +201,32 @@ impl< } } +/// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. +fn did_as_much_as_possible< + IO: Stream>> + Sink> + Send + Unpin + 'static, +>( + io: &mut IO, + cx: &mut Context<'_>, + step: &mut Step, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, + plain_tx: &mut VecDeque>, + is_initiator: bool, +) -> bool { + // No incoming encrypted messages available. + poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator).is_pending() + // We're unable to send any anymore encrypted/setup messages either because we have none or the `Sink` is unavailable. + && (encrypted_tx.is_empty() || Sink::poll_ready(Pin::new(io), cx).is_pending()) + // No encrypted messages waiting to be decrypted. + && encrypted_rx.is_empty() + // No plaint text messages waiting to be enccrypted or we're still setting up + && (plain_tx.is_empty() || !step.established()) +} + impl>> + Sink> + Send + Unpin + 'static> Stream for Encrypted { - type Item = Result>; + type Item = Event; #[instrument(skip(cx), fields(initiator = %self.is_initiator))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -162,39 +242,71 @@ impl>> + Sink> + Send + Unpin + 'static .. } = self.get_mut(); - poll_encrypted_side_io(io, cx, encrypted_tx, encrypted_rx, *is_initiator, flush); - - if let Step::Established((encryptor, decryptor, ..)) = step { - poll_do_encrypt_and_decrypt( - encryptor, - decryptor, - encrypted_tx, - encrypted_rx, - plain_tx, - plain_rx, - *is_initiator, - flush, - ); - // emit any messages that are ready + if poll_message_throughput( + io, + cx, + step, + encrypted_tx, + encrypted_rx, + plain_rx, + plain_tx, + *is_initiator, + flush, + ) { if let Some(msg) = plain_rx.pop_front() { - trace!(initiator = %is_initiator, "plain rx emit"); Poll::Ready(Some(msg)) } else { Poll::Pending } } else { - poll_setup(step, encrypted_tx, encrypted_rx, *is_initiator, flush); cx.waker().wake_by_ref(); Poll::Pending } } } +/// Handle all message throughput. Sends, encrypts and decrypts messages +/// Returns `true` `step` is already [`Step::Established`]. +fn poll_message_throughput< + IO: Stream>> + Sink> + Send + Unpin + 'static, +>( + io: &mut IO, + cx: &mut Context<'_>, + step: &mut Step, + encrypted_tx: &mut VecDeque>, + encrypted_rx: &mut VecDeque>>, + plain_rx: &mut VecDeque, + plain_tx: &mut VecDeque>, + is_initiator: bool, + flush: &mut bool, +) -> bool { + poll_outgoing_encrypted_messages(io, cx, encrypted_tx, is_initiator, flush); + let _ = poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator); + if let Step::Established((encryptor, decryptor, ..)) = step { + // decrypt incomming msgs + poll_decrypt(decryptor, encrypted_rx, plain_rx, is_initiator); + // encrypt any pending plaintext outgoinng messages + poll_encrypt(encryptor, encrypted_tx, plain_tx, is_initiator, flush); + true + } else { + poll_setup( + step, + encrypted_tx, + encrypted_rx, + plain_rx, + is_initiator, + flush, + ); + false + } +} + #[instrument(skip_all, fields(initiator = %is_initiator))] fn poll_setup( step: &mut Step, encrypted_tx: &mut VecDeque>, encrypted_rx: &mut VecDeque>>, + plain_rx: &mut VecDeque, is_initiator: bool, flush: &mut bool, ) { @@ -205,27 +317,25 @@ fn poll_setup( // Still setting up if let Ok(Some(msg)) = maybe_init(step, is_initiator) { // queue the init message to send first - info!(initiator = %is_initiator,"queue initial msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue initial msg"); encrypted_tx.push_front(msg); } // TODO handle error - loop { - match encrypted_rx.pop_front() { - None => { - break; - } - Some(Err(e)) => { + while let Some(enc_res) = encrypted_rx.pop_front() { + match enc_res { + Err(e) => { error!("Recieved an error during setup encryption setup: {e:?}"); break; } - Some(Ok(incoming_msg)) => { - info!(initiator = %is_initiator, "recieved setup msg"); + Ok(incoming_msg) => { + trace!(initiator = %is_initiator, "encrypted_rx dequeue recieved setup msg"); if let Ok(msgs) = match handle_setup_message( step, &incoming_msg, is_initiator, encrypted_tx, encrypted_rx, + plain_rx, flush, ) { Ok(x) => Ok(x), @@ -235,31 +345,34 @@ fn poll_setup( } } { for msg in msgs.into_iter().rev() { - info!(initiator = %is_initiator,"queue more setup msg\n{msg:?}"); + trace!(initiator = %is_initiator,"queue more setup msg"); encrypted_tx.push_front(msg); } } } } + + if step.established() { + return; + } } } #[instrument(skip_all, fields(initiator = %is_initiator))] /// Fills `encrypted_rx` and drains `encrypted_tx`. -fn poll_encrypted_side_io< +fn poll_outgoing_encrypted_messages< IO: Stream>> + Sink> + Send + Unpin + 'static, >( io: &mut IO, cx: &mut Context<'_>, encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, is_initiator: bool, flush: &mut bool, ) { // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - info!(initiator = %is_initiator, msg_len = encrypted_out.len(), "enc tx send msg\n{encrypted_out:?}"); + trace!(initiator = %is_initiator, msg_len = encrypted_out.len(), "TX message"); if let Err(_e) = Sink::start_send(Pin::new(io), encrypted_out) { error!("Error polling encyrpted side io") } @@ -273,74 +386,86 @@ fn poll_encrypted_side_io< match Sink::poll_flush(Pin::new(io), cx) { Poll::Ready(Ok(())) => { *flush = false; - info!(initiator = %is_initiator, "flushed good"); + trace!(initiator = %is_initiator, "all flushed"); } Poll::Ready(Err(_e)) => { error!(initiator = %is_initiator, "Error sending encrypted msg") } Poll::Pending => { - // More confusing docs - // https://docs.rs/futures-sink/0.3.30/futures_sink/trait.Sink.html#tymethod.poll_flush - // It says: - // "Returns Poll::Pending if there is more work left to do, in which case the - // current task is scheduled (via cx.waker().wake_by_ref()) to wake up when - // poll_flush should be called again." - // Does this mean, each time this task wakes up again from this code path that - // I must trigger another poll_flush? But how would I know i need more - // flushing? - debug!("flush not completed"); + // flush not complete try again later *flush = true; } } } +} + +fn poll_incomming_encrypted_messages< + IO: Stream>> + Sink> + Send + Unpin + 'static, +>( + io: &mut IO, + cx: &mut Context<'_>, + encrypted_rx: &mut VecDeque>>, + is_initiator: bool, +) -> Poll<()> { // pull in any incomming encrypted messages - loop { - match Stream::poll_next(Pin::new(io), cx) { - Poll::Pending => break, - Poll::Ready(None) => break, - Poll::Ready(Some(encrypted_msg)) => { - trace!(initiator = %is_initiator, "enc rx queue\n{encrypted_msg:?}"); - encrypted_rx.push_back(encrypted_msg); - } - } + let mut got_some = false; + while let Poll::Ready(Some(encrypted_msg)) = Stream::poll_next(Pin::new(io), cx) { + trace!(initiator = %is_initiator, "RX message"); + encrypted_rx.push_back(encrypted_msg); + got_some = true; + } + if got_some { + Poll::Ready(()) + } else { + Poll::Pending } } -/// Process messages waiting to be encrypted or decrypted -// TODO sholud this return a Result? #[instrument(skip_all)] -fn poll_do_encrypt_and_decrypt( - encryptor: &mut RawEncryptCipher, +fn poll_decrypt( decryptor: &mut DecryptCipher, - encrypted_tx: &mut VecDeque>, encrypted_rx: &mut VecDeque>>, - plain_tx: &mut VecDeque>, - plain_rx: &mut VecDeque>>, + plain_rx: &mut VecDeque, is_initiator: bool, - flush: &mut bool, ) { // decrypt any incromming encrypted messages // TODO handle error - while let Some(Ok(incoming_msg)) = encrypted_rx.pop_front() { - info!(initiator = %is_initiator, "enc rx decrypting\n{incoming_msg:?}"); - match decryptor.decrypt_buf(&incoming_msg) { - Ok((plain_msg, _tag)) => { - info!(initiator = %is_initiator, "plain rx queue"); - plain_rx.push_back(Ok(plain_msg)); + while let Some(incoming_msg_res) = encrypted_rx.pop_front() { + match incoming_msg_res { + Ok(incoming_msg) => { + trace!(initiator = %is_initiator, "encrypted_rx dequeue decrypt"); + match decryptor.decrypt_buf(&incoming_msg) { + Ok((plain_msg, _tag)) => { + trace!(initiator = %is_initiator, "plain rx queue"); + plain_rx.push_back(Event::from(Ok(plain_msg))); + } + Err(e) => { + error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") + } + } } Err(e) => { error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") } } } +} +#[instrument(skip_all)] +fn poll_encrypt( + encryptor: &mut RawEncryptCipher, + encrypted_tx: &mut VecDeque>, + plain_tx: &mut VecDeque>, + is_initiator: bool, + flush: &mut bool, +) { // encrypt any pending plaintext outgoinng messages while let Some(plain_out) = plain_tx.pop_front() { let enc_out = match encryptor.encrypt(&plain_out) { Ok(x) => x, Err(_e) => todo!("We failed to encrypt our own message...?"), }; - trace!(initiator = %is_initiator, encrypted_msg_length = enc_out.len(), "enqueue new encrypted message from plain tx queue\n{enc_out:?}"); + trace!(initiator = %is_initiator, encrypted_msg_length = enc_out.len(), "enqueue new encrypted message from plain tx queue"); encrypted_tx.push_back(enc_out); *flush = true; } @@ -365,6 +490,7 @@ fn reset_encrypted( encrypted_rx: &mut VecDeque>>, flush: &mut bool, ) { + error!("Encrypted RESET"); *step = Step::NotInitialized; encrypted_tx.clear(); encrypted_rx.clear(); @@ -382,6 +508,7 @@ fn handle_setup_message( is_initiator: bool, encrypted_tx: &mut VecDeque>, encrypted_rx: &mut VecDeque>>, + plain_rx: &mut VecDeque, flush: &mut bool, ) -> Result>> { // this would only happen after reset with a bad message. @@ -401,7 +528,7 @@ fn handle_setup_message( Step::Handshake(_) => { let mut out = vec![]; if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { - trace!("Read in handshake msg\n{msg:?}"); + trace!("RX handshake msg"); if let Some(response) = match handshake.read_raw(msg) { Ok(x) => x, Err(e) => { @@ -418,15 +545,15 @@ fn handle_setup_message( return Err(e); } } { - info!( + trace!( initiator = %is_initiator, - "read message and emitting response {response:?}", + "read message and emitting response", ); out.push(response); } if handshake.complete() { - debug!(initiator = %is_initiator, "HS complete. Making result"); + debug!(initiator = %is_initiator, "Handshake completed"); let handshake_result = match handshake.into_result() { Ok(x) => x, Err(e) => { @@ -456,6 +583,7 @@ fn handle_setup_message( if let Step::SecretStream((enc_cipher, hs_result)) = replace(step, Step::NotInitialized) { let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; + plain_rx.push_back(Event::from(hs_result.clone())); *step = Step::Established((enc_cipher, dec_cipher, hs_result)); debug!(initiator = %is_initiator, "Step changed to {step}"); } @@ -465,38 +593,23 @@ fn handle_setup_message( } } -impl std::fmt::Display for Step { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Step::NotInitialized => "NotInitialized", - Step::Handshake(_) => "Handshake", - Step::SecretStream(_) => "SecretStream", - Step::Established(_) => "Established", - } - ) - } -} - impl std::fmt::Debug for Encrypted { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Encrypted") //.field("io", &self.io) //.field("step", &self.step) - .field("is_initiator", &self.is_initiator) - .field("encrypted_tx", &self.encrypted_tx) - .field("encrypted_rx", &self.encrypted_rx) - .field("plain_tx", &self.plain_tx) - .field("plain_rx", &self.plain_rx) - .field("flush", &self.flush) + .field("initiator", &self.is_initiator) + .field("encrypted_tx.len()", &self.encrypted_tx.len()) + .field("encrypted_rx", &self.encrypted_rx.len()) + .field("plain_tx", &self.plain_tx.len()) + .field("plain_rx", &self.plain_rx.len()) + //.field("flush", &self.flush) .finish() } } #[cfg(test)] -mod tset { +mod test { use crate::{ framing::test::duplex, test_utils::create_result_connected, Uint24LELengthPrefixedFraming, @@ -505,6 +618,12 @@ mod tset { use super::*; use futures::{future::join, SinkExt, StreamExt}; + fn inner(e: Option) -> Vec { + if let Some(Event::Decrypted(Ok(x))) = e { + return x; + } + panic!() + } #[tokio::test] async fn encrypted() -> Result<()> { let hello = b"hello".to_vec(); @@ -513,26 +632,31 @@ mod tset { let mut left = Encrypted::new(true, lc); let mut right = Encrypted::new(false, rc); - let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; - assert_eq!(receieved.unwrap()?, hello); + let (_sent, _received) = join(left.send(hello.clone()), right.next()).await; + let (_sent, received) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(inner(received), hello); assert!(left.encryption_established()); + assert!(right.encryption_established()); // NB: we cannot totally finish 'left.send' until the other side becomes active // because the handshake with the other side ('right') must complete // before the 'hello' message is sent. So we poll both the send and receive concurrently. - let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; + let (_sent, received) = join(left.send(hello.clone()), right.next()).await; + // right recieves left's message - assert_eq!(receieved.unwrap()?, hello); + assert_eq!(inner(received), hello); // now that the encrypted channel is established, we don't need to spawn. right.send(world.clone()).await.unwrap(); // left recieves right's message - assert_eq!(left.next().await.unwrap()?, world); + left.next().await; + assert_eq!(inner(left.next().await), world); Ok(()) } + #[tokio::test] async fn encrypted_many() -> Result<()> { let hello = b"hello".to_vec(); @@ -547,15 +671,17 @@ mod tset { let mut left = Encrypted::new(true, lc); let mut right = Encrypted::new(false, rc); - let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; - assert_eq!(receieved.unwrap()?, hello); + let (_sent, _received) = join(left.send(hello.clone()), right.next()).await; + let (_sent, received) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(inner(received), hello); for d in &data { right.send(d.to_vec()).await?; } let mut result = vec![]; + let _ = left.next().await; for _ in &data { - result.push(left.next().await.unwrap()?); + result.push(inner(left.next().await)); } assert_eq!(result, data); Ok(()) @@ -572,8 +698,8 @@ mod tset { let mut left = Encrypted::new(true, left); let mut right = Encrypted::new(false, right); - let (_sent, receieved) = join(left.send(hello.clone()), right.next()).await; - assert_eq!(receieved.unwrap()?, hello); + let (_sent, _received) = join(left.send(hello.clone()), right.next()).await; + assert_eq!(inner(right.next().await), hello); let data = vec![ b"yolo".to_vec(), @@ -587,9 +713,10 @@ mod tset { for d in &data { right.send(d.to_vec()).await?; } + let _ = left.next().await; let mut result = vec![]; for _ in &data { - result.push(left.next().await.unwrap()?); + result.push(inner(left.next().await)); } assert_eq!(result, data); @@ -599,20 +726,22 @@ mod tset { } let mut result = vec![]; for _ in &data { - result.push(right.next().await.unwrap()?); + result.push(inner(right.next().await)); } assert_eq!(result, data); // send both ways + let mut res = vec![]; for d in &data { left.send(d.to_vec()).await?; right.send(d.to_vec()).await?; + res.push(d.to_vec()); } let mut left_result = vec![]; let mut right_result = vec![]; for _ in &data { - right_result.push(right.next().await.unwrap()?); - left_result.push(left.next().await.unwrap()?); + right_result.push(inner(right.next().await)); + left_result.push(inner(left.next().await)); } assert_eq!(right_result, data); assert_eq!(left_result, data); @@ -647,13 +776,16 @@ mod tset { // send a bad message to init side. It should reset, and emit new init msg init_side_messages.send(b"bad msg".to_vec()).await?; - let new_init_msg = init_side_messages.next().await.unwrap()?; - other_side_messages.send(new_init_msg).await?; - let new_response = other_side_messages.next().await.unwrap()?; - init_side_messages.send(new_response).await?; - let final_setup_message = init_side_messages.next().await.unwrap()?; - other_side_messages.send(final_setup_message).await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; + init_side_messages + .send(other_side_messages.next().await.unwrap()?) + .await?; + other_side_messages + .send(init_side_messages.next().await.unwrap()?) + .await?; // exchange one more message then we're set up init_side_messages @@ -676,8 +808,11 @@ mod tset { assert!(left.encryption_established()); assert!(right.encryption_established()); - assert_eq!(right.next().await.unwrap()?, b"hello"); - assert_eq!(left.next().await.unwrap()?, b"other hello"); + let _ = right.next().await; + let _ = left.next().await; + + assert_eq!(inner(right.next().await), b"hello"); + assert_eq!(inner(left.next().await), b"other hello"); Ok(()) } From 4a35d1ffa54d20c041d3c1cd19ed9a1295592d01 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 2 Apr 2025 13:13:57 -0400 Subject: [PATCH 070/206] wait for setup to handle commands --- src/protocol/modern.rs | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/src/protocol/modern.rs b/src/protocol/modern.rs index 6098ff7..42bd5d6 100644 --- a/src/protocol/modern.rs +++ b/src/protocol/modern.rs @@ -9,13 +9,14 @@ use std::io::{self, Error, ErrorKind, Result}; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use tracing::instrument; +use tracing::{debug, error, instrument, warn}; use crate::channels::{Channel, ChannelMap}; use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; use crate::crypto::HandshakeResult; use crate::message::{ChannelMessage, Message}; -use crate::mqueue::MessageIo; +use crate::mqueue::{MessageIo, MqueueEvent}; +use crate::noise::EncryptionInfo; use crate::util::{map_channel_err, pretty_hash}; use crate::{ encrypted_framed_message_channel, schema::*, Encrypted, Uint24LELengthPrefixedFraming, @@ -229,7 +230,9 @@ where return_error!(this.poll_inbound_read(cx)); // Check for commands, but only once the connection is established. - return_error!(this.poll_commands(cx)); + if this.options.noise && this.handshake.is_some() { + return_error!(this.poll_commands(cx)); + } // Poll the keepalive timer. this.poll_keepalive(cx); @@ -241,7 +244,6 @@ where if let Some(event) = this.queued_events.pop_front() { Poll::Ready(Ok(event)) } else { - cx.waker().wake_by_ref(); Poll::Pending } } @@ -249,7 +251,10 @@ where /// Poll commands. fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { - self.on_command(command)?; + if let Err(e) = self.on_command(command) { + error!(error = ?e, "Error handling command"); + return Err(e); + } } Ok(()) } @@ -297,10 +302,21 @@ where fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { match self.io.poll_inbound(cx) { - Poll::Ready(Ok(messages)) => { - self.on_inbound_channel_messages(messages)?; - } - Poll::Ready(Err(e)) => return Err(e), + Poll::Ready(opt) => match opt { + Some(e) => match e { + MqueueEvent::Meta(einf) => match einf { + EncryptionInfo::Handshake(hs_res) => { + let remote_pubkey = parse_key(&hs_res.remote_pubkey)?; + self.handshake = Some(hs_res); + debug!(handshake = ?self.handshake, "set Protocol::handshake"); + self.queue_event(Event::Handshake(remote_pubkey)) + } + }, + MqueueEvent::Message(msgs) => self.on_inbound_channel_messages(msgs?)?, + }, + + None => return Ok(()), + }, Poll::Pending => return Ok(()), } } @@ -313,6 +329,7 @@ where loop { // if no parking or setup in progress if let Poll::Ready(Err(e)) = self.io.poll_outbound(cx) { + error!(err = ?e, "error from poll_outbound"); return Err(e); } // send messages outbound_rx @@ -489,7 +506,7 @@ where } } -/// Send [Command](Command)s to the [Protocol](Protocol). +/// Send [`Command`]s to the [`Protocol`]. #[derive(Clone, Debug)] pub struct CommandTx(Sender); From 35a6bac4742756d31e17832246016052af6fcc51 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 9 Apr 2025 12:34:18 -0500 Subject: [PATCH 071/206] move unused framed stuff into tests and do rename --- src/message/modern.rs | 310 +++++++++++++++++++++--------------------- src/mqueue.rs | 4 +- 2 files changed, 157 insertions(+), 157 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index 2d8e732..e70bed2 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -42,43 +42,7 @@ impl Encoder for &[u8] { } } -/// A frame of data, either a buffer or a message. -#[derive(Clone, PartialEq)] -pub(crate) enum Frame { - /// A raw batch binary buffer. Used in the handshaking phase. - RawBatch(Vec>), - /// Message batch, containing one or more channel messsages. Used for everything after the handshake. - MessageBatch(Vec), -} - -impl fmt::Debug for Frame { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), - Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), - } - } -} - -impl From for Frame { - fn from(m: ChannelMessage) -> Self { - Self::MessageBatch(vec![m]) - } -} - -impl From> for Frame { - fn from(m: Vec) -> Self { - Self::MessageBatch(m) - } -} - -impl From> for Frame { - fn from(m: Vec) -> Self { - Self::RawBatch(vec![m]) - } -} - -pub(crate) fn decode_many_channel_messages( +pub(crate) fn decode_framed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { let mut index = 0; @@ -93,7 +57,7 @@ pub(crate) fn decode_many_channel_messages( let stat = stat_uint24_le(&buf[index..]); if let Some((header_len, body_len)) = stat { - let (msgs, length) = decode_one_channel_message( + let (msgs, length) = decode_unframed_channel_messages( &buf[index + header_len..index + header_len + body_len as usize], )?; if length != body_len as usize { @@ -119,7 +83,7 @@ pub(crate) fn decode_many_channel_messages( Ok((combined_messages, index)) } // bad name bc it returns many. More like, decode unframed channel messages -pub(crate) fn decode_one_channel_message( +pub(crate) fn decode_unframed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { if buf.len() >= 3 && buf[0] == 0x00 { @@ -188,121 +152,6 @@ pub(crate) fn decode_one_channel_message( } } -impl Frame { - /// Decodes a frame from a buffer containing multiple concurrent messages. - fn preencode(&self, state: &mut State) -> Result { - match self { - Self::RawBatch(raw_batch) => { - for raw in raw_batch { - state.add_end(raw.as_slice().encoded_len()?)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(messages) => { - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else { - (*state).preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; - let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.add_end(1)?; - state.preencode(&message.channel)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; - } - } - } - } - Ok(state.end()) - } -} - -impl Encoder for Frame { - fn encoded_len(&self) -> Result { - let body_len = self.preencode(&mut State::new())?; - match self { - Self::RawBatch(_) => Ok(body_len), - Self::MessageBatch(_) => Ok(3 + body_len), - } - } - - fn encode(&self, buf: &mut [u8]) -> Result { - let mut state = State::new(); - let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; - let body_len = self.preencode(&mut state)?; - let len = body_len + header_len; - if buf.len() < len { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - match self { - Self::RawBatch(ref raw_batch) => { - for raw in raw_batch { - raw.as_slice().encode(buf)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(ref messages) => { - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else { - state.encode(&messages[0].channel, buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; - let mut current_channel: u64 = messages[0].channel; - state.encode(¤t_channel, buf)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; - } - } - } - }; - Ok(len) - } -} - fn prencode_channel_messages( messages: &[ChannelMessage], state: &mut State, @@ -688,6 +537,156 @@ mod tests { )* } } + /// A frame of data, either a buffer or a message. + #[derive(Clone, PartialEq)] + pub(crate) enum Frame { + /// A raw batch binary buffer. Used in the handshaking phase. + RawBatch(Vec>), + /// Message batch, containing one or more channel messsages. Used for everything after the handshake. + MessageBatch(Vec), + } + + impl fmt::Debug for Frame { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), + Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), + } + } + } + + impl From for Frame { + fn from(m: ChannelMessage) -> Self { + Self::MessageBatch(vec![m]) + } + } + + impl From> for Frame { + fn from(m: Vec) -> Self { + Self::MessageBatch(m) + } + } + + impl From> for Frame { + fn from(m: Vec) -> Self { + Self::RawBatch(vec![m]) + } + } + + impl Frame { + /// Decodes a frame from a buffer containing multiple concurrent messages. + fn preencode(&self, state: &mut State) -> Result { + match self { + Self::RawBatch(raw_batch) => { + for raw in raw_batch { + state.add_end(raw.as_slice().encoded_len()?)?; + } + } + #[allow(clippy::comparison_chain)] + Self::MessageBatch(messages) => { + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.add_end(2 + &messages[0].encoded_len()?)?; + } else { + (*state).preencode(&messages[0].channel)?; + state.add_end(messages[0].encoded_len()?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.add_end(2)?; + let mut current_channel: u64 = messages[0].channel; + state.preencode(¤t_channel)?; + for message in messages.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.add_end(1)?; + state.preencode(&message.channel)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.preencode(&message_length)?; + state.add_end(message_length)?; + } + } + } + } + Ok(state.end()) + } + } + + impl Encoder for Frame { + fn encoded_len(&self) -> Result { + let body_len = self.preencode(&mut State::new())?; + match self { + Self::RawBatch(_) => Ok(body_len), + Self::MessageBatch(_) => Ok(3 + body_len), + } + } + + fn encode(&self, buf: &mut [u8]) -> Result { + let mut state = State::new(); + let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; + let body_len = self.preencode(&mut state)?; + let len = body_len + header_len; + if buf.len() < len { + return Err(EncodingError::new( + EncodingErrorKind::Overflow, + &format!("Length does not fit buffer, {} > {}", len, buf.len()), + )); + } + match self { + Self::RawBatch(ref raw_batch) => { + for raw in raw_batch { + raw.as_slice().encode(buf)?; + } + } + #[allow(clippy::comparison_chain)] + Self::MessageBatch(ref messages) => { + write_uint24_le(body_len, buf); + let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + if messages.len() == 1 { + if let Message::Open(_) = &messages[0].message { + // This is a special case with 0x00, 0x01 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(1_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else if let Message::Close(_) = &messages[0].message { + // This is a special case with 0x00, 0x03 intro bytes + state.encode(&(0_u8), buf)?; + state.encode(&(3_u8), buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } else { + state.encode(&messages[0].channel, buf)?; + state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; + } + } else if messages.len() > 1 { + // Two intro bytes 0x00 0x00, then channel id, then lengths + state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + let mut current_channel: u64 = messages[0].channel; + state.encode(¤t_channel, buf)?; + for message in messages.iter() { + if message.channel != current_channel { + // Channel changed, need to add a 0x00 in between and then the new + // channel + state.encode(&(0_u8), buf)?; + state.encode(&message.channel, buf)?; + current_channel = message.channel; + } + let message_length = message.encoded_len()?; + state.encode(&message_length, buf)?; + state.add_start(message.encode(&mut buf[state.start()..])?)?; + } + } + } + }; + Ok(len) + } + } #[test] fn message_encode_decode() { @@ -849,6 +848,7 @@ mod tests { }), ] } + impl Frame { pub(crate) fn decode_multiple(buf: &[u8]) -> Result { let mut index = 0; @@ -988,7 +988,7 @@ mod tests { let fres = Frame::decode_multiple(&fbuf)?; assert_eq!(fres, frame); - let cres_m = decode_many_channel_messages(&cbuf)?.0; + let cres_m = decode_framed_channel_messages(&cbuf)?.0; assert_eq!(cres_m, cmvec); } Ok(()) diff --git a/src/mqueue.rs b/src/mqueue.rs index 7e1fffd..cd86caf 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -12,7 +12,7 @@ use futures::{Sink, Stream}; use tracing::{error, instrument}; use crate::{ - message::{decode_many_channel_messages, ChannelMessage, Encoder as _}, + message::{decode_framed_channel_messages, ChannelMessage, Encoder as _}, noise::EncryptionInfo, NoiseEvent, }; @@ -29,7 +29,7 @@ impl From for MqueueEvent { NoiseEvent::Meta(einf) => Self::Meta(einf), NoiseEvent::Decrypted(dec_res) => { match dec_res { - Ok(encoded) => match decode_many_channel_messages(&encoded) { + Ok(encoded) => match decode_framed_channel_messages(&encoded) { //assert_eq!(_n_read, encoded.len()); } Ok((messsages, _n_read)) => Self::Message(Ok(messsages)), Err(e) => Self::Message(Err(e)), From 483e9ffdb592189bc1a1908d6d173b58c99ce832 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 9 Apr 2025 14:46:24 -0500 Subject: [PATCH 072/206] impl CompactEncodable for Schema --- src/schema.rs | 408 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 406 insertions(+), 2 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index ef58e77..328e1f6 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,6 +1,10 @@ -use hypercore::encoding::{CompactEncoding, EncodingError, HypercoreState, State}; +use hypercore::encoding::{ + take_array, take_array_mut, write_array, write_slice, CompactEncodable, CompactEncoding, + EncodingError, HypercoreState, State, +}; use hypercore::{ - DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, RequestUpgrade, + chain_encoded_bytes, decode, sum_encoded_size, DataBlock, DataHash, DataSeek, DataUpgrade, + Proof, RequestBlock, RequestSeek, RequestUpgrade, }; /// Open message @@ -16,6 +20,55 @@ pub struct Open { pub capability: Option>, } +impl CompactEncodable for Open { + fn encoded_size(&self) -> Result { + let out = sum_encoded_size!(self, channel, protocol, discovery_key); + if self.capability.is_some() { + return Ok( + out + + 1 // flags for future use + + 32, // TODO capabalilities buff should always be 32 bytes, but it's a vec + ); + } + Ok(out) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let rest = chain_encoded_bytes!(self, buffer, channel, protocol, discovery_key); + if let Some(cap) = &self.capability { + let (_, rest) = take_array_mut::<1>(rest)?; + return write_slice(cap, rest); + } + Ok(rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (channel, rest) = u64::decode(buffer)?; + let (protocol, rest) = String::decode(rest)?; + let (discovery_key, rest) = >::decode(rest)?; + // TODO this is a CLEAR bug it assumes nothing is encoded after this message + let (capability, rest) = if !rest.is_empty() { + let (_, rest) = take_array::<1>(rest)?; + let (capability, rest) = take_array::<32>(rest)?; + (Some(capability.to_vec()), rest) + } else { + (None, rest) + }; + Ok(( + Open { + channel, + protocol, + discovery_key, + capability, + }, + rest, + )) + } +} + impl CompactEncoding for State { fn preencode(&mut self, value: &Open) -> Result { self.preencode(&value.channel)?; @@ -43,6 +96,7 @@ impl CompactEncoding for State { let channel: u64 = self.decode(buffer)?; let protocol: String = self.decode(buffer)?; let discovery_key: Vec = self.decode(buffer)?; + // TODO This is a BUG!!! when anything is encoded **after** Open message let capability: Option> = if self.start() < self.end() { self.add_start(1)?; // flags for future use let capability: Vec = self.decode_fixed_32(buffer)?.to_vec(); @@ -66,6 +120,22 @@ pub struct Close { pub channel: u64, } +impl CompactEncodable for Close { + fn encoded_size(&self) -> Result { + Ok(self.channel.encoded_size()?) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.channel.encoded_bytes(buffer) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(Close, buffer, {channel: u64}) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &Close) -> Result { self.preencode(&value.channel) @@ -98,6 +168,50 @@ pub struct Synchronize { pub can_upgrade: bool, } +impl CompactEncodable for Synchronize { + fn encoded_size(&self) -> Result { + Ok(1 + sum_encoded_size!(self, fork, length, remote_length)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.can_upgrade { 1 } else { 0 }; + flags |= if self.uploading { 2 } else { 0 }; + flags |= if self.downloading { 4 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + Ok(chain_encoded_bytes!( + self, + rest, + fork, + length, + remote_length + )) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let (fork, rest) = u64::decode(rest)?; + let (length, rest) = u64::decode(rest)?; + let (remote_length, rest) = u64::decode(rest)?; + let can_upgrade = flags & 1 != 0; + let uploading = flags & 2 != 0; + let downloading = flags & 4 != 0; + Ok(( + Synchronize { + fork, + length, + remote_length, + can_upgrade, + uploading, + downloading, + }, + rest, + )) + } +} + impl CompactEncoding for State { fn preencode(&mut self, value: &Synchronize) -> Result { self.add_end(1)?; // flags @@ -152,6 +266,85 @@ pub struct Request { pub upgrade: Option, } +macro_rules! maybe_decode { + ($cond:expr, $type:ty, $buf:ident) => { + if $cond { + let (result, rest) = <$type>::decode($buf)?; + (Some(result), rest) + } else { + (None, $buf) + } + }; +} + +impl CompactEncodable for Request { + fn encoded_size(&self) -> Result { + let mut out = 1; // flags + out += sum_encoded_size!(self, id, fork); + if let Some(block) = &self.block { + out += block.encoded_size()?; + } + if let Some(hash) = &self.hash { + out += hash.encoded_size()?; + } + if let Some(seek) = &self.seek { + out += seek.encoded_size()?; + } + if let Some(upgrade) = &self.upgrade { + out += upgrade.encoded_size()?; + } + Ok(out) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; + flags |= if self.hash.is_some() { 2 } else { 0 }; + flags |= if self.seek.is_some() { 4 } else { 0 }; + flags |= if self.upgrade.is_some() { 8 } else { 0 }; + let mut rest = write_array(&[flags], buffer)?; + chain_encoded_bytes!(self, rest, id, fork); + + if let Some(block) = &self.block { + rest = block.encoded_bytes(rest)?; + } + if let Some(hash) = &self.hash { + rest = hash.encoded_bytes(rest)?; + } + if let Some(seek) = &self.seek { + rest = seek.encoded_bytes(rest)?; + } + if let Some(upgrade) = &self.upgrade { + rest = upgrade.encoded_bytes(rest)?; + } + Ok(rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let (id, rest) = u64::decode(rest)?; + let (fork, rest) = u64::decode(rest)?; + + let (block, rest) = maybe_decode!(flags & 1 != 0, RequestBlock, rest); + let (hash, rest) = maybe_decode!(flags & 2 != 0, RequestBlock, rest); + let (seek, rest) = maybe_decode!(flags & 4 != 0, RequestSeek, rest); + let (upgrade, rest) = maybe_decode!(flags & 8 != 0, RequestUpgrade, rest); + Ok(( + Request { + id, + fork, + block, + hash, + seek, + upgrade, + }, + rest, + )) + } +} + impl CompactEncoding for HypercoreState { fn preencode(&mut self, value: &Request) -> Result { self.add_end(1)?; // flags @@ -237,6 +430,23 @@ pub struct Cancel { pub request: u64, } +impl CompactEncodable for Cancel { + fn encoded_size(&self) -> Result { + self.request.encoded_size() + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.request.encoded_bytes(buffer) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (request, rest) = u64::decode(buffer)?; + Ok((Cancel { request }, rest)) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &Cancel) -> Result { self.preencode(&value.request) @@ -269,6 +479,74 @@ pub struct Data { pub upgrade: Option, } +macro_rules! opt_encoded_size { + ($opt:expr, $sum:ident) => { + if let Some(thing) = $opt { + $sum += thing.encoded_size()?; + } + }; +} + +macro_rules! opt_encoded_bytes { + ($opt:expr, $buf:ident) => { + if let Some(thing) = $opt { + thing.encoded_bytes($buf)? + } else { + $buf + } + }; +} +impl CompactEncodable for Data { + fn encoded_size(&self) -> Result { + let mut out = 1; // flags + out += sum_encoded_size!(self, request, fork); + opt_encoded_size!(&self.block, out); + opt_encoded_size!(&self.hash, out); + opt_encoded_size!(&self.seek, out); + opt_encoded_size!(&self.upgrade, out); + Ok(out) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; + flags |= if self.hash.is_some() { 2 } else { 0 }; + flags |= if self.seek.is_some() { 4 } else { 0 }; + flags |= if self.upgrade.is_some() { 8 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + chain_encoded_bytes!(self, rest, request, fork); + + let rest = opt_encoded_bytes!(&self.block, rest); + let rest = opt_encoded_bytes!(&self.hash, rest); + let rest = opt_encoded_bytes!(&self.seek, rest); + let rest = opt_encoded_bytes!(&self.upgrade, rest); + Ok(rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let (request, rest) = u64::decode(rest)?; + let (fork, rest) = u64::decode(rest)?; + let (block, rest) = maybe_decode!(flags & 1 != 0, DataBlock, rest); + let (hash, rest) = maybe_decode!(flags & 2 != 0, DataHash, rest); + let (seek, rest) = maybe_decode!(flags & 4 != 0, DataSeek, rest); + let (upgrade, rest) = maybe_decode!(flags & 8 != 0, DataUpgrade, rest); + Ok(( + Data { + request, + fork, + block, + hash, + seek, + upgrade, + }, + rest, + )) + } +} + impl CompactEncoding for HypercoreState { fn preencode(&mut self, value: &Data) -> Result { self.add_end(1)?; // flags @@ -367,6 +645,22 @@ pub struct NoData { pub request: u64, } +impl CompactEncodable for NoData { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self, request)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(chain_encoded_bytes!(self, buffer, request)) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(NoData, buffer, { request: u64 }) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &NoData) -> Result { self.preencode(&value.request) @@ -390,6 +684,23 @@ pub struct Want { /// Length pub length: u64, } + +impl CompactEncodable for Want { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self, start, length)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(chain_encoded_bytes!(self, buffer, start, length)) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(Self, buffer, { start: u64, length: u64 }) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &Want) -> Result { self.preencode(&value.start)?; @@ -416,6 +727,24 @@ pub struct Unwant { /// Length pub length: u64, } + +impl CompactEncodable for Unwant { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self, start, length)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(chain_encoded_bytes!(self, buffer, start, length)) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(Self, buffer, { start: u64, length: u64 }) + } +} + impl CompactEncoding for State { fn preencode(&mut self, value: &Unwant) -> Result { self.preencode(&value.start)?; @@ -442,6 +771,22 @@ pub struct Bitfield { /// Bitfield in 32 bit chunks beginning from `start` pub bitfield: Vec, } +impl CompactEncodable for Bitfield { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self, start, bitfield)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(chain_encoded_bytes!(self, buffer, start, bitfield)) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(Self, buffer, { start: u64, bitfield: Vec }) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &Bitfield) -> Result { self.preencode(&value.start)?; @@ -473,6 +818,49 @@ pub struct Range { pub length: u64, } +impl CompactEncodable for Range { + fn encoded_size(&self) -> Result { + let mut out = 1 + sum_encoded_size!(self, start); + if self.length != 1 { + out += self.length.encoded_size()?; + } + Ok(out) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let mut flags: u8 = if self.drop { 1 } else { 0 }; + flags |= if self.length == 1 { 2 } else { 0 }; + let rest = write_array(&[flags], buffer)?; + let rest = self.start.encoded_bytes(rest)?; + if self.length != 1 { + return self.length.encoded_bytes(rest); + } + Ok(rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let ([flags], rest) = take_array::<1>(buffer)?; + let (start, rest) = u64::decode(rest)?; + let drop = flags & 1 != 0; + let (length, rest) = if flags & 2 != 0 { + (1, rest) + } else { + u64::decode(rest)? + }; + Ok(( + Range { + drop, + length, + start, + }, + rest, + )) + } +} + impl CompactEncoding for State { fn preencode(&mut self, value: &Range) -> Result { self.add_end(1)?; // flags @@ -519,6 +907,22 @@ pub struct Extension { /// Message content, use empty vector for no data. pub message: Vec, } +impl CompactEncodable for Extension { + fn encoded_size(&self) -> Result { + Ok(sum_encoded_size!(self, name, message)) + } + + fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + Ok(chain_encoded_bytes!(self, buffer, name, message)) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + decode!(Self, buffer, { name: String, message: Vec }) + } +} impl CompactEncoding for State { fn preencode(&mut self, value: &Extension) -> Result { self.preencode(&value.name)?; From 350f03a6fe5b8891403d2aba22ae6f2172ae152a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 14 Apr 2025 17:19:09 -0400 Subject: [PATCH 073/206] encoded_bytes renamed to encode --- src/schema.rs | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index 328e1f6..676348e 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -33,7 +33,7 @@ impl CompactEncodable for Open { Ok(out) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let rest = chain_encoded_bytes!(self, buffer, channel, protocol, discovery_key); if let Some(cap) = &self.capability { let (_, rest) = take_array_mut::<1>(rest)?; @@ -125,8 +125,8 @@ impl CompactEncodable for Close { Ok(self.channel.encoded_size()?) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - self.channel.encoded_bytes(buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.channel.encode(buffer) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -173,7 +173,7 @@ impl CompactEncodable for Synchronize { Ok(1 + sum_encoded_size!(self, fork, length, remote_length)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let mut flags: u8 = if self.can_upgrade { 1 } else { 0 }; flags |= if self.uploading { 2 } else { 0 }; flags |= if self.downloading { 4 } else { 0 }; @@ -296,7 +296,7 @@ impl CompactEncodable for Request { Ok(out) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; flags |= if self.hash.is_some() { 2 } else { 0 }; flags |= if self.seek.is_some() { 4 } else { 0 }; @@ -305,16 +305,16 @@ impl CompactEncodable for Request { chain_encoded_bytes!(self, rest, id, fork); if let Some(block) = &self.block { - rest = block.encoded_bytes(rest)?; + rest = block.encode(rest)?; } if let Some(hash) = &self.hash { - rest = hash.encoded_bytes(rest)?; + rest = hash.encode(rest)?; } if let Some(seek) = &self.seek { - rest = seek.encoded_bytes(rest)?; + rest = seek.encode(rest)?; } if let Some(upgrade) = &self.upgrade { - rest = upgrade.encoded_bytes(rest)?; + rest = upgrade.encode(rest)?; } Ok(rest) } @@ -435,8 +435,8 @@ impl CompactEncodable for Cancel { self.request.encoded_size() } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - self.request.encoded_bytes(buffer) + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + self.request.encode(buffer) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -490,7 +490,7 @@ macro_rules! opt_encoded_size { macro_rules! opt_encoded_bytes { ($opt:expr, $buf:ident) => { if let Some(thing) = $opt { - thing.encoded_bytes($buf)? + thing.encode($buf)? } else { $buf } @@ -507,7 +507,7 @@ impl CompactEncodable for Data { Ok(out) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let mut flags: u8 = if self.block.is_some() { 1 } else { 0 }; flags |= if self.hash.is_some() { 2 } else { 0 }; flags |= if self.seek.is_some() { 4 } else { 0 }; @@ -650,7 +650,7 @@ impl CompactEncodable for NoData { Ok(sum_encoded_size!(self, request)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { Ok(chain_encoded_bytes!(self, buffer, request)) } @@ -690,7 +690,7 @@ impl CompactEncodable for Want { Ok(sum_encoded_size!(self, start, length)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { Ok(chain_encoded_bytes!(self, buffer, start, length)) } @@ -733,7 +733,7 @@ impl CompactEncodable for Unwant { Ok(sum_encoded_size!(self, start, length)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { Ok(chain_encoded_bytes!(self, buffer, start, length)) } @@ -776,7 +776,7 @@ impl CompactEncodable for Bitfield { Ok(sum_encoded_size!(self, start, bitfield)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { Ok(chain_encoded_bytes!(self, buffer, start, bitfield)) } @@ -827,13 +827,13 @@ impl CompactEncodable for Range { Ok(out) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let mut flags: u8 = if self.drop { 1 } else { 0 }; flags |= if self.length == 1 { 2 } else { 0 }; let rest = write_array(&[flags], buffer)?; - let rest = self.start.encoded_bytes(rest)?; + let rest = self.start.encode(rest)?; if self.length != 1 { - return self.length.encoded_bytes(rest); + return self.length.encode(rest); } Ok(rest) } @@ -912,7 +912,7 @@ impl CompactEncodable for Extension { Ok(sum_encoded_size!(self, name, message)) } - fn encoded_bytes<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { Ok(chain_encoded_bytes!(self, buffer, name, message)) } From bf0d309f68d99c3d291bb798ab3fc158ec3d9c12 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 22 Apr 2025 14:16:53 -0400 Subject: [PATCH 074/206] use new CompactEncoding in schema.rs --- src/schema.rs | 465 +++++--------------------------------------------- 1 file changed, 40 insertions(+), 425 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index 676348e..8a9a0a2 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,10 +1,10 @@ use hypercore::encoding::{ - take_array, take_array_mut, write_array, write_slice, CompactEncodable, CompactEncoding, - EncodingError, HypercoreState, State, + map_encode, sum_encoded_size, take_array, take_array_mut, write_array, write_slice, + CompactEncoding, EncodingError, }; use hypercore::{ - chain_encoded_bytes, decode, sum_encoded_size, DataBlock, DataHash, DataSeek, DataUpgrade, - Proof, RequestBlock, RequestSeek, RequestUpgrade, + decode, DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, + RequestUpgrade, }; /// Open message @@ -20,9 +20,9 @@ pub struct Open { pub capability: Option>, } -impl CompactEncodable for Open { +impl CompactEncoding for Open { fn encoded_size(&self) -> Result { - let out = sum_encoded_size!(self, channel, protocol, discovery_key); + let out = sum_encoded_size!(self.channel, self.protocol, self.discovery_key); if self.capability.is_some() { return Ok( out @@ -34,7 +34,7 @@ impl CompactEncodable for Open { } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - let rest = chain_encoded_bytes!(self, buffer, channel, protocol, discovery_key); + let rest = map_encode!(buffer, self.channel, self.protocol, self.discovery_key); if let Some(cap) = &self.capability { let (_, rest) = take_array_mut::<1>(rest)?; return write_slice(cap, rest); @@ -69,50 +69,6 @@ impl CompactEncodable for Open { } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Open) -> Result { - self.preencode(&value.channel)?; - self.preencode(&value.protocol)?; - self.preencode(&value.discovery_key)?; - if value.capability.is_some() { - self.add_end(1)?; // flags for future use - self.preencode_fixed_32()?; - } - Ok(self.end()) - } - - fn encode(&mut self, value: &Open, buffer: &mut [u8]) -> Result { - self.encode(&value.channel, buffer)?; - self.encode(&value.protocol, buffer)?; - self.encode(&value.discovery_key, buffer)?; - if let Some(capability) = &value.capability { - self.add_start(1)?; // flags for future use - self.encode_fixed_32(capability, buffer)?; - } - Ok(self.start()) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let channel: u64 = self.decode(buffer)?; - let protocol: String = self.decode(buffer)?; - let discovery_key: Vec = self.decode(buffer)?; - // TODO This is a BUG!!! when anything is encoded **after** Open message - let capability: Option> = if self.start() < self.end() { - self.add_start(1)?; // flags for future use - let capability: Vec = self.decode_fixed_32(buffer)?.to_vec(); - Some(capability) - } else { - None - }; - Ok(Open { - channel, - protocol, - discovery_key, - capability, - }) - } -} - /// Close message #[derive(Debug, Clone, PartialEq)] pub struct Close { @@ -120,7 +76,7 @@ pub struct Close { pub channel: u64, } -impl CompactEncodable for Close { +impl CompactEncoding for Close { fn encoded_size(&self) -> Result { Ok(self.channel.encoded_size()?) } @@ -136,20 +92,6 @@ impl CompactEncodable for Close { decode!(Close, buffer, {channel: u64}) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Close) -> Result { - self.preencode(&value.channel) - } - - fn encode(&mut self, value: &Close, buffer: &mut [u8]) -> Result { - self.encode(&value.channel, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let channel: u64 = self.decode(buffer)?; - Ok(Close { channel }) - } -} /// Synchronize message. Type 0. #[derive(Debug, Clone, PartialEq)] @@ -168,22 +110,22 @@ pub struct Synchronize { pub can_upgrade: bool, } -impl CompactEncodable for Synchronize { +impl CompactEncoding for Synchronize { fn encoded_size(&self) -> Result { - Ok(1 + sum_encoded_size!(self, fork, length, remote_length)) + Ok(1 + sum_encoded_size!(self.fork, self.length, self.remote_length)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let mut flags: u8 = if self.can_upgrade { 1 } else { 0 }; flags |= if self.uploading { 2 } else { 0 }; flags |= if self.downloading { 4 } else { 0 }; + dbg!(flags); let rest = write_array(&[flags], buffer)?; - Ok(chain_encoded_bytes!( - self, + Ok(map_encode!( rest, - fork, - length, - remote_length + self.fork, + self.length, + self.remote_length )) } @@ -192,6 +134,7 @@ impl CompactEncodable for Synchronize { Self: Sized, { let ([flags], rest) = take_array::<1>(buffer)?; + dbg!(flags); let (fork, rest) = u64::decode(rest)?; let (length, rest) = u64::decode(rest)?; let (remote_length, rest) = u64::decode(rest)?; @@ -212,43 +155,6 @@ impl CompactEncodable for Synchronize { } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Synchronize) -> Result { - self.add_end(1)?; // flags - self.preencode(&value.fork)?; - self.preencode(&value.length)?; - self.preencode(&value.remote_length) - } - - fn encode(&mut self, value: &Synchronize, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.can_upgrade { 1 } else { 0 }; - flags |= if value.uploading { 2 } else { 0 }; - flags |= if value.downloading { 4 } else { 0 }; - self.encode(&flags, buffer)?; - self.encode(&value.fork, buffer)?; - self.encode(&value.length, buffer)?; - self.encode(&value.remote_length, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.decode(buffer)?; - let fork: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - let remote_length: u64 = self.decode(buffer)?; - let can_upgrade = flags & 1 != 0; - let uploading = flags & 2 != 0; - let downloading = flags & 4 != 0; - Ok(Synchronize { - fork, - length, - remote_length, - can_upgrade, - uploading, - downloading, - }) - } -} - /// Request message. Type 1. #[derive(Debug, Clone, PartialEq)] pub struct Request { @@ -277,10 +183,10 @@ macro_rules! maybe_decode { }; } -impl CompactEncodable for Request { +impl CompactEncoding for Request { fn encoded_size(&self) -> Result { let mut out = 1; // flags - out += sum_encoded_size!(self, id, fork); + out += sum_encoded_size!(self.id, self.fork); if let Some(block) = &self.block { out += block.encoded_size()?; } @@ -302,7 +208,7 @@ impl CompactEncodable for Request { flags |= if self.seek.is_some() { 4 } else { 0 }; flags |= if self.upgrade.is_some() { 8 } else { 0 }; let mut rest = write_array(&[flags], buffer)?; - chain_encoded_bytes!(self, rest, id, fork); + rest = map_encode!(rest, self.id, self.fork); if let Some(block) = &self.block { rest = block.encode(rest)?; @@ -345,84 +251,6 @@ impl CompactEncodable for Request { } } -impl CompactEncoding for HypercoreState { - fn preencode(&mut self, value: &Request) -> Result { - self.add_end(1)?; // flags - self.0.preencode(&value.id)?; - self.0.preencode(&value.fork)?; - if let Some(block) = &value.block { - self.preencode(block)?; - } - if let Some(hash) = &value.hash { - self.preencode(hash)?; - } - if let Some(seek) = &value.seek { - self.preencode(seek)?; - } - if let Some(upgrade) = &value.upgrade { - self.preencode(upgrade)?; - } - Ok(self.end()) - } - - fn encode(&mut self, value: &Request, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.block.is_some() { 1 } else { 0 }; - flags |= if value.hash.is_some() { 2 } else { 0 }; - flags |= if value.seek.is_some() { 4 } else { 0 }; - flags |= if value.upgrade.is_some() { 8 } else { 0 }; - self.0.encode(&flags, buffer)?; - self.0.encode(&value.id, buffer)?; - self.0.encode(&value.fork, buffer)?; - if let Some(block) = &value.block { - self.encode(block, buffer)?; - } - if let Some(hash) = &value.hash { - self.encode(hash, buffer)?; - } - if let Some(seek) = &value.seek { - self.encode(seek, buffer)?; - } - if let Some(upgrade) = &value.upgrade { - self.encode(upgrade, buffer)?; - } - Ok(self.start()) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.0.decode(buffer)?; - let id: u64 = self.0.decode(buffer)?; - let fork: u64 = self.0.decode(buffer)?; - let block: Option = if flags & 1 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let hash: Option = if flags & 2 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let seek: Option = if flags & 4 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let upgrade: Option = if flags & 8 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - Ok(Request { - id, - fork, - block, - hash, - seek, - upgrade, - }) - } -} - /// Cancel message for a [Request]. Type 2 #[derive(Debug, Clone, PartialEq)] pub struct Cancel { @@ -430,7 +258,7 @@ pub struct Cancel { pub request: u64, } -impl CompactEncodable for Cancel { +impl CompactEncoding for Cancel { fn encoded_size(&self) -> Result { self.request.encoded_size() } @@ -447,20 +275,6 @@ impl CompactEncodable for Cancel { Ok((Cancel { request }, rest)) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Cancel) -> Result { - self.preencode(&value.request) - } - - fn encode(&mut self, value: &Cancel, buffer: &mut [u8]) -> Result { - self.encode(&value.request, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let request: u64 = self.decode(buffer)?; - Ok(Cancel { request }) - } -} /// Data message responding to received [Request]. Type 3. #[derive(Debug, Clone, PartialEq)] @@ -496,10 +310,10 @@ macro_rules! opt_encoded_bytes { } }; } -impl CompactEncodable for Data { +impl CompactEncoding for Data { fn encoded_size(&self) -> Result { let mut out = 1; // flags - out += sum_encoded_size!(self, request, fork); + out += sum_encoded_size!(self.request, self.fork); opt_encoded_size!(&self.block, out); opt_encoded_size!(&self.hash, out); opt_encoded_size!(&self.seek, out); @@ -513,7 +327,7 @@ impl CompactEncodable for Data { flags |= if self.seek.is_some() { 4 } else { 0 }; flags |= if self.upgrade.is_some() { 8 } else { 0 }; let rest = write_array(&[flags], buffer)?; - chain_encoded_bytes!(self, rest, request, fork); + let rest = map_encode!(rest, self.request, self.fork); let rest = opt_encoded_bytes!(&self.block, rest); let rest = opt_encoded_bytes!(&self.hash, rest); @@ -547,84 +361,6 @@ impl CompactEncodable for Data { } } -impl CompactEncoding for HypercoreState { - fn preencode(&mut self, value: &Data) -> Result { - self.add_end(1)?; // flags - self.0.preencode(&value.request)?; - self.0.preencode(&value.fork)?; - if let Some(block) = &value.block { - self.preencode(block)?; - } - if let Some(hash) = &value.hash { - self.preencode(hash)?; - } - if let Some(seek) = &value.seek { - self.preencode(seek)?; - } - if let Some(upgrade) = &value.upgrade { - self.preencode(upgrade)?; - } - Ok(self.end()) - } - - fn encode(&mut self, value: &Data, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.block.is_some() { 1 } else { 0 }; - flags |= if value.hash.is_some() { 2 } else { 0 }; - flags |= if value.seek.is_some() { 4 } else { 0 }; - flags |= if value.upgrade.is_some() { 8 } else { 0 }; - self.0.encode(&flags, buffer)?; - self.0.encode(&value.request, buffer)?; - self.0.encode(&value.fork, buffer)?; - if let Some(block) = &value.block { - self.encode(block, buffer)?; - } - if let Some(hash) = &value.hash { - self.encode(hash, buffer)?; - } - if let Some(seek) = &value.seek { - self.encode(seek, buffer)?; - } - if let Some(upgrade) = &value.upgrade { - self.encode(upgrade, buffer)?; - } - Ok(self.start()) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.0.decode(buffer)?; - let request: u64 = self.0.decode(buffer)?; - let fork: u64 = self.0.decode(buffer)?; - let block: Option = if flags & 1 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let hash: Option = if flags & 2 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let seek: Option = if flags & 4 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - let upgrade: Option = if flags & 8 != 0 { - Some(self.decode(buffer)?) - } else { - None - }; - Ok(Data { - request, - fork, - block, - hash, - seek, - upgrade, - }) - } -} - impl Data { /// Transform Data message into a Proof emptying fields pub fn into_proof(&mut self) -> Proof { @@ -645,13 +381,13 @@ pub struct NoData { pub request: u64, } -impl CompactEncodable for NoData { +impl CompactEncoding for NoData { fn encoded_size(&self) -> Result { - Ok(sum_encoded_size!(self, request)) + Ok(sum_encoded_size!(self.request)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - Ok(chain_encoded_bytes!(self, buffer, request)) + Ok(map_encode!(buffer, self.request)) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -661,20 +397,6 @@ impl CompactEncodable for NoData { decode!(NoData, buffer, { request: u64 }) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &NoData) -> Result { - self.preencode(&value.request) - } - - fn encode(&mut self, value: &NoData, buffer: &mut [u8]) -> Result { - self.encode(&value.request, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let request: u64 = self.decode(buffer)?; - Ok(NoData { request }) - } -} /// Want message. Type 5. #[derive(Debug, Clone, PartialEq)] @@ -685,13 +407,13 @@ pub struct Want { pub length: u64, } -impl CompactEncodable for Want { +impl CompactEncoding for Want { fn encoded_size(&self) -> Result { - Ok(sum_encoded_size!(self, start, length)) + Ok(sum_encoded_size!(self.start, self.length)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - Ok(chain_encoded_bytes!(self, buffer, start, length)) + Ok(map_encode!(buffer, self.start, self.length)) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -701,23 +423,6 @@ impl CompactEncodable for Want { decode!(Self, buffer, { start: u64, length: u64 }) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Want) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.length) - } - - fn encode(&mut self, value: &Want, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.length, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - Ok(Want { start, length }) - } -} /// Un-want message. Type 6. #[derive(Debug, Clone, PartialEq)] @@ -728,13 +433,13 @@ pub struct Unwant { pub length: u64, } -impl CompactEncodable for Unwant { +impl CompactEncoding for Unwant { fn encoded_size(&self) -> Result { - Ok(sum_encoded_size!(self, start, length)) + Ok(sum_encoded_size!(self.start, self.length)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - Ok(chain_encoded_bytes!(self, buffer, start, length)) + Ok(map_encode!(buffer, self.start, self.length)) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -745,24 +450,6 @@ impl CompactEncodable for Unwant { } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Unwant) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.length) - } - - fn encode(&mut self, value: &Unwant, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.length, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let length: u64 = self.decode(buffer)?; - Ok(Unwant { start, length }) - } -} - /// Bitfield message. Type 7. #[derive(Debug, Clone, PartialEq)] pub struct Bitfield { @@ -771,13 +458,13 @@ pub struct Bitfield { /// Bitfield in 32 bit chunks beginning from `start` pub bitfield: Vec, } -impl CompactEncodable for Bitfield { +impl CompactEncoding for Bitfield { fn encoded_size(&self) -> Result { - Ok(sum_encoded_size!(self, start, bitfield)) + Ok(sum_encoded_size!(self.start, self.bitfield)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - Ok(chain_encoded_bytes!(self, buffer, start, bitfield)) + Ok(map_encode!(buffer, self.start, self.bitfield)) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -787,23 +474,6 @@ impl CompactEncodable for Bitfield { decode!(Self, buffer, { start: u64, bitfield: Vec }) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Bitfield) -> Result { - self.preencode(&value.start)?; - self.preencode(&value.bitfield) - } - - fn encode(&mut self, value: &Bitfield, buffer: &mut [u8]) -> Result { - self.encode(&value.start, buffer)?; - self.encode(&value.bitfield, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let start: u64 = self.decode(buffer)?; - let bitfield: Vec = self.decode(buffer)?; - Ok(Bitfield { start, bitfield }) - } -} /// Range message. Type 8. /// Notifies Peer's that the Sender has a range of contiguous blocks. @@ -818,9 +488,9 @@ pub struct Range { pub length: u64, } -impl CompactEncodable for Range { +impl CompactEncoding for Range { fn encoded_size(&self) -> Result { - let mut out = 1 + sum_encoded_size!(self, start); + let mut out = 1 + sum_encoded_size!(self.start); if self.length != 1 { out += self.length.encoded_size()?; } @@ -861,44 +531,6 @@ impl CompactEncodable for Range { } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Range) -> Result { - self.add_end(1)?; // flags - self.preencode(&value.start)?; - if value.length != 1 { - self.preencode(&value.length)?; - } - Ok(self.end()) - } - - fn encode(&mut self, value: &Range, buffer: &mut [u8]) -> Result { - let mut flags: u8 = if value.drop { 1 } else { 0 }; - flags |= if value.length == 1 { 2 } else { 0 }; - self.encode(&flags, buffer)?; - self.encode(&value.start, buffer)?; - if value.length != 1 { - self.encode(&value.length, buffer)?; - } - Ok(self.end()) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let flags: u8 = self.decode(buffer)?; - let start: u64 = self.decode(buffer)?; - let drop = flags & 1 != 0; - let length: u64 = if flags & 2 != 0 { - 1 - } else { - self.decode(buffer)? - }; - Ok(Range { - drop, - length, - start, - }) - } -} - /// Extension message. Type 9. Use this for custom messages in your application. #[derive(Debug, Clone, PartialEq)] pub struct Extension { @@ -907,13 +539,13 @@ pub struct Extension { /// Message content, use empty vector for no data. pub message: Vec, } -impl CompactEncodable for Extension { +impl CompactEncoding for Extension { fn encoded_size(&self) -> Result { - Ok(sum_encoded_size!(self, name, message)) + Ok(sum_encoded_size!(self.name, self.message)) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - Ok(chain_encoded_bytes!(self, buffer, name, message)) + Ok(map_encode!(buffer, self.name, self.message)) } fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> @@ -923,20 +555,3 @@ impl CompactEncodable for Extension { decode!(Self, buffer, { name: String, message: Vec }) } } -impl CompactEncoding for State { - fn preencode(&mut self, value: &Extension) -> Result { - self.preencode(&value.name)?; - self.preencode_raw_buffer(&value.message) - } - - fn encode(&mut self, value: &Extension, buffer: &mut [u8]) -> Result { - self.encode(&value.name, buffer)?; - self.encode_raw_buffer(&value.message, buffer) - } - - fn decode(&mut self, buffer: &[u8]) -> Result { - let name: String = self.decode(buffer)?; - let message: Vec = self.decode_raw_buffer(buffer)?; - Ok(Extension { name, message }) - } -} From 1700d900514abc1097e4a38d9dcae4a9d50a231c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 23 Apr 2025 13:44:51 -0400 Subject: [PATCH 075/206] make test easier to debug --- src/mqueue.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mqueue.rs b/src/mqueue.rs index cd86caf..de15e37 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -161,7 +161,9 @@ mod test { fn new_msg(channel: u64) -> ChannelMessage { ChannelMessage { channel, - message: crate::Message::NoData(NoData { request: channel }), + message: crate::Message::NoData(NoData { + request: channel + 1, + }), } } From 1d511598860605ccbd862bcb01b223f7a09adda1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 23 Apr 2025 13:45:05 -0400 Subject: [PATCH 076/206] lint --- src/schema.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/schema.rs b/src/schema.rs index 8a9a0a2..41ae6aa 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -78,7 +78,7 @@ pub struct Close { impl CompactEncoding for Close { fn encoded_size(&self) -> Result { - Ok(self.channel.encoded_size()?) + self.channel.encoded_size() } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { From a3395c9050f48004d6b0e152bd311a12d05bb837 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 23 Apr 2025 13:50:58 -0400 Subject: [PATCH 077/206] rename old encoder trait to fix name collisio --- src/message/modern.rs | 521 +++++++++--------------------------------- src/mqueue.rs | 2 +- 2 files changed, 106 insertions(+), 417 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index e70bed2..b9b6792 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -21,7 +21,7 @@ pub(crate) trait Encoder: Sized + fmt::Debug { /// Encodes the message to a buffer. /// /// An error will be returned if the buffer does not have sufficient capacity. - fn encode(&self, buf: &mut [u8]) -> Result; + fn encoder_encode(&self, buf: &mut [u8]) -> Result; } impl Encoder for &[u8] { @@ -29,7 +29,7 @@ impl Encoder for &[u8] { Ok(self.len()) } - fn encode(&self, buf: &mut [u8]) -> Result { + fn encoder_encode(&self, buf: &mut [u8]) -> Result { let len = self.encoded_len()?; if len > buf.len() { return Err(EncodingError::new( @@ -199,7 +199,7 @@ impl Encoder for Vec { } #[instrument(skip_all)] - fn encode(&self, buf: &mut [u8]) -> Result { + fn encoder_encode(&self, buf: &mut [u8]) -> Result { let mut state = State::new(); let body_len = prencode_channel_messages(self, &mut state)?; write_uint24_le(body_len, buf); @@ -393,6 +393,17 @@ impl fmt::Debug for ChannelMessage { } } +impl fmt::Display for ChannelMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "ChannelMessage {{ channel {}, message {} }}", + self.channel, + self.message.name() + ) + } +} + impl ChannelMessage { /// Create a new message. pub(crate) fn new(channel: u64, message: Message) -> Self { @@ -409,21 +420,21 @@ impl ChannelMessage { /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { - if buf.len() <= 5 { + let og_len = buf.len(); + if og_len <= 5 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "received too short Open message", )); } - let mut state = State::new_with_start_and_end(0, buf.len()); - let open_msg: Open = state.decode(buf)?; + let (open_msg, buf) = Open::decode(buf)?; Ok(( Self { channel: open_msg.channel, message: Message::Open(open_msg), }, - state.start(), + og_len - buf.len(), )) } @@ -432,86 +443,104 @@ impl ChannelMessage { /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { + let og_len = buf.len(); if buf.is_empty() { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "received too short Close message", )); } - let mut state = State::new_with_start_and_end(0, buf.len()); - let close_msg: Close = state.decode(buf)?; + let (close, buf) = Close::decode(buf)?; Ok(( Self { - channel: close_msg.channel, - message: Message::Close(close_msg), + channel: close.channel, + message: Message::Close(close), }, - state.start(), + og_len - buf.len(), )) } + #[instrument(err, skip_all)] + pub(crate) fn decode_from_channel_and_message( + buf: &[u8], + ) -> Result<(Self, &[u8]), EncodingError> { + //::decode(buf) + let (channel, buf) = u64::decode(buf)?; + let (message, buf) = ::decode(buf)?; + Ok((Self { channel, message }, buf)) + } /// Decode a normal channel message from a buffer. /// /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it - pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> { + pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> { if buf.len() <= 1 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "received empty message", )); } - let mut state = State::from_buffer(buf); - let typ: u64 = state.decode(buf)?; - let (message, length) = Message::decode(&buf[state.start()..], typ)?; - Ok((Self { channel, message }, state.start() + length)) + let (message, buf) = ::decode(buf)?; + Ok((Self { channel, message }, buf)) } /// Performance optimization for letting calling encoded_len() already do /// the preencode phase of compact_encoding. - fn prepare_state(&self) -> Result { + fn prepare_state(&self) -> Result { Ok(if let Message::Open(_) = self.message { // Open message doesn't have a type // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state + self.message.encoded_size()? } else if let Message::Close(_) = self.message { // Close message doesn't have a type // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state + self.message.encoded_size()? } else { // The header is the channel id uint followed by message type uint // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - let mut state = HypercoreState::new(); let typ = self.message.typ(); - (*state).preencode(&typ)?; - self.message.preencode(&mut state)?; - state + typ.encoded_size()? + self.message.encoded_size()? }) } } +/// NB: currently this is just for a standalone channel message. ChannelMessages in a vec decode & +/// encode differently +impl CompactEncoding for ChannelMessage { + fn encoded_size(&self) -> Result { + Ok(self.channel.encoded_size()? + self.message.encoded_size()?) + } + + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let rest = self.channel.encode(buffer)?; + ::encode(&self.message, rest) + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + ChannelMessage::decode_from_channel_and_message(buffer) + } +} impl Encoder for ChannelMessage { fn encoded_len(&self) -> Result { - Ok(self.prepare_state()?.end()) + self.prepare_state() } - fn encode(&self, buf: &mut [u8]) -> Result { - let mut state = self.prepare_state()?; - if let Message::Open(_) = self.message { + #[instrument(skip_all)] + fn encoder_encode(&self, buf: &mut [u8]) -> Result { + let before = buf.len(); + let rest = if let Message::Open(_) = self.message { // Open message is different in that the type byte is missing - self.message.encode(&mut state, buf)?; + ::encode(&self.message, buf)? } else if let Message::Close(_) = self.message { // Close message is different in that the type byte is missing - self.message.encode(&mut state, buf)?; + ::encode(&self.message, buf)? } else { - let typ = self.message.typ(); - state.0.encode(&typ, buf)?; - self.message.encode(&mut state, buf)?; - } - Ok(state.start()) + ::encode(&self.message, buf)? + }; + Ok(before - rest.len()) } } @@ -528,168 +557,49 @@ mod tests { $( let channel = rand::random::() as u64; let channel_message = ChannelMessage::new(channel, $msg); - let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); + let encoded_len = channel_message.encoded_len()?; let mut buf = vec![0u8; encoded_len]; - let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); - let decoded = ChannelMessage::decode(&buf[..n], channel).expect("Failed to decode message").0.into_split(); - assert_eq!(channel, decoded.0); - assert_eq!($msg, decoded.1); + let rest = ::encode(&channel_message, &mut buf)?; + assert!(rest.is_empty()); + let (decoded, rest) = ::decode(&buf)?; + assert!(rest.is_empty()); + assert_eq!(decoded, channel_message); )* } } - /// A frame of data, either a buffer or a message. - #[derive(Clone, PartialEq)] - pub(crate) enum Frame { - /// A raw batch binary buffer. Used in the handshaking phase. - RawBatch(Vec>), - /// Message batch, containing one or more channel messsages. Used for everything after the handshake. - MessageBatch(Vec), - } - - impl fmt::Debug for Frame { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), - Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), - } - } - } - - impl From for Frame { - fn from(m: ChannelMessage) -> Self { - Self::MessageBatch(vec![m]) - } - } - - impl From> for Frame { - fn from(m: Vec) -> Self { - Self::MessageBatch(m) - } - } - - impl From> for Frame { - fn from(m: Vec) -> Self { - Self::RawBatch(vec![m]) - } - } - - impl Frame { - /// Decodes a frame from a buffer containing multiple concurrent messages. - fn preencode(&self, state: &mut State) -> Result { - match self { - Self::RawBatch(raw_batch) => { - for raw in raw_batch { - state.add_end(raw.as_slice().encoded_len()?)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(messages) => { - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else { - (*state).preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; - let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.add_end(1)?; - state.preencode(&message.channel)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; - } - } - } - } - Ok(state.end()) - } - } - - impl Encoder for Frame { - fn encoded_len(&self) -> Result { - let body_len = self.preencode(&mut State::new())?; - match self { - Self::RawBatch(_) => Ok(body_len), - Self::MessageBatch(_) => Ok(3 + body_len), - } - } - - fn encode(&self, buf: &mut [u8]) -> Result { - let mut state = State::new(); - let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; - let body_len = self.preencode(&mut state)?; - let len = body_len + header_len; - if buf.len() < len { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - match self { - Self::RawBatch(ref raw_batch) => { - for raw in raw_batch { - raw.as_slice().encode(buf)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(ref messages) => { - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else { - state.encode(&messages[0].channel, buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; - let mut current_channel: u64 = messages[0].channel; - state.encode(¤t_channel, buf)?; - for message in messages.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; - } - } - } - }; - Ok(len) - } + #[test] + fn boo() -> Result<(), EncodingError> { + let m = Message::Cancel(Cancel { request: 1 }); + let m = Message::Request(Request { + id: 1, + fork: 1, + block: Some(RequestBlock { + index: 5, + nodes: 10, + }), + hash: Some(RequestBlock { + index: 20, + nodes: 0, + }), + seek: Some(RequestSeek { bytes: 10 }), + upgrade: Some(RequestUpgrade { + start: 0, + length: 10, + }), + }); + let channel = rand::random::() as u64; + let channel_message = ChannelMessage::new(channel, m); + let encoded_len = channel_message.encoded_len()?; + let mut buf = vec![0u8; encoded_len]; + let rest = ::encode(&channel_message, &mut buf)?; + assert!(rest.is_empty()); + let (decoded, rest) = ::decode(&buf)?; + assert!(rest.is_empty()); + assert_eq!(decoded, channel_message); + Ok(()) } - #[test] - fn message_encode_decode() { + fn message_encode_decode() -> Result<(), EncodingError> { message_enc_dec! { Message::Synchronize(Synchronize{ fork: 0, @@ -770,227 +680,6 @@ mod tests { message: vec![0x44, 20] }) }; - } - - fn message_test_data() -> Vec { - vec![ - Message::Synchronize(Synchronize { - fork: 0, - can_upgrade: true, - downloading: true, - uploading: true, - length: 5, - remote_length: 0, - }), - Message::Request(Request { - id: 1, - fork: 1, - block: Some(RequestBlock { - index: 5, - nodes: 10, - }), - hash: Some(RequestBlock { - index: 20, - nodes: 0, - }), - seek: Some(RequestSeek { bytes: 10 }), - upgrade: Some(RequestUpgrade { - start: 0, - length: 10, - }), - }), - Message::Cancel(Cancel { request: 1 }), - Message::Data(Data { - request: 1, - fork: 5, - block: Some(DataBlock { - index: 5, - nodes: vec![Node::new(1, vec![0x01; 32], 100)], - value: vec![0xFF; 10], - }), - hash: Some(DataHash { - index: 20, - nodes: vec![Node::new(2, vec![0x02; 32], 200)], - }), - seek: Some(DataSeek { - bytes: 10, - nodes: vec![Node::new(3, vec![0x03; 32], 300)], - }), - upgrade: Some(DataUpgrade { - start: 0, - length: 10, - nodes: vec![Node::new(4, vec![0x04; 32], 400)], - additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)], - signature: vec![0xAB; 32], - }), - }), - Message::NoData(NoData { request: 2 }), - Message::Want(Want { - start: 0, - length: 100, - }), - Message::Unwant(Unwant { - start: 10, - length: 2, - }), - Message::Bitfield(Bitfield { - start: 20, - bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF], - }), - Message::Range(Range { - drop: true, - start: 12345, - length: 100000, - }), - Message::Extension(Extension { - name: "custom_extension/v1/open".to_string(), - message: vec![0x44, 20], - }), - ] - } - - impl Frame { - pub(crate) fn decode_multiple(buf: &[u8]) -> Result { - let mut index = 0; - let mut combined_messages: Vec = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - let (frame, length) = Self::decode_message( - &buf[index + header_len..index + header_len + body_len as usize], - )?; - if length != body_len as usize { - tracing::warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - body_len, - length - ); - } - if let Frame::MessageBatch(messages) = frame { - for message in messages { - combined_messages.push(message); - } - } else { - unreachable!("Can not get Raw messages"); - } - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in multi-message chunk", - )); - } - } - Ok(Frame::MessageBatch(combined_messages)) - } - - fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - // buffer length >= 3 or more and starts with 0 is message batch - if buf.len() >= 3 && buf[0] == 0x00 { - if buf[1] == 0x00 { - // Batch of messages - let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); - - // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { - // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() - ), - )); - } - // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; - messages.push(channel_message); - state.add_start(channel_message_length)?; - // After that, if there is an extra 0x00, that means the channel - // changed. This works because of LE encoding, and channels starting - // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; - } - } - Ok((Frame::MessageBatch(messages), state.start())) - } else if buf[1] == 0x01 { - // Open message - let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else if buf[1] == 0x03 { - // Close message - let (channel_message, length) = - ChannelMessage::decode_close_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid special message", - )) - } - } else if buf.len() >= 2 { - // len >= and - // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = - ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok(( - Frame::MessageBatch(vec![channel_message]), - state.start() + length, - )) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), - )) - } - } - } - - #[test] - fn compare_with_frame_encoding_decoding() -> std::io::Result<()> { - let channel = 42; - for msg in message_test_data() { - let channel_message = ChannelMessage::new(channel, msg); - let frame = Frame::from(channel_message.clone()); - let cmvec = vec![channel_message.clone()]; - - let mut fbuf = vec![0; frame.encoded_len()?]; - let mut cbuf = vec![0; cmvec.encoded_len()?]; - - assert_eq!(cbuf, fbuf); - - frame.encode(&mut fbuf)?; - cmvec.encode(&mut cbuf)?; - - assert_eq!(cbuf, fbuf); - - let fres = Frame::decode_multiple(&fbuf)?; - assert_eq!(fres, frame); - let cres_m = decode_framed_channel_messages(&cbuf)?.0; - assert_eq!(cres_m, cmvec); - } Ok(()) } } diff --git a/src/mqueue.rs b/src/mqueue.rs index de15e37..9be4ab7 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -82,7 +82,7 @@ impl + Sink> + Send + Unpin + 'static> Mes } let mut buf = vec![0; messages.encoded_len()?]; - match messages.encode(&mut buf) { + match messages.encoder_encode(&mut buf) { Ok(_) => {} Err(e) => { error!(error = ?e, "error encoding messages"); From a9bab5696148048295829f8a530265fe0f9d6c6c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 23 Apr 2025 13:51:56 -0400 Subject: [PATCH 078/206] cleaning up messages --- src/message/modern.rs | 333 ++++++++++++++++++++++-------------------- 1 file changed, 175 insertions(+), 158 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index b9b6792..05e37be 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -1,12 +1,12 @@ use crate::schema::*; use crate::util::{stat_uint24_le, write_uint24_le}; use hypercore::encoding::{ - CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State, + decode_usize, take_array, write_array, CompactEncoding, EncodingError, EncodingErrorKind, }; use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; -use tracing::instrument; +use tracing::{instrument, trace}; const UINT24_HEADER_LEN: usize = 3; @@ -86,44 +86,41 @@ pub(crate) fn decode_framed_channel_messages( pub(crate) fn decode_unframed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { - if buf.len() >= 3 && buf[0] == 0x00 { + let og_len = buf.len(); + if og_len >= 3 && buf[0] == 0x00 { if buf[1] == 0x00 { + let (_, mut buf) = take_array::<2>(buf)?; // Batch of messages let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { + let mut current_channel; + (current_channel, buf) = u64::decode(buf)?; + while !buf.is_empty() { // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { + let channel_message_length; + (channel_message_length, buf) = decode_usize(buf)?; + if channel_message_length > buf.len() { return Err(io::Error::new( io::ErrorKind::InvalidData, format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() + "received invalid message length: [{channel_message_length}] but we have [{}] remaining bytes. Initial buffer size [{og_len}]", + buf.len() ), )); } // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; + let channel_message; + (channel_message, buf) = ChannelMessage::decode(buf, current_channel)?; messages.push(channel_message); - state.add_start(channel_message_length)?; // After that, if there is an extra 0x00, that means the channel // changed. This works because of LE encoding, and channels starting // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; + if !buf.is_empty() && buf[0] == 0x00 { + (current_channel, buf) = u64::decode(buf)?; } } - Ok((messages, state.start())) + Ok((messages, og_len - buf.len())) } else if buf[1] == 0x01 { // Open message let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; @@ -139,11 +136,11 @@ pub(crate) fn decode_unframed_channel_messages( )) } } else if buf.len() >= 2 { + trace!("Decoding single ChannelMessage"); // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok((vec![channel_message], state.start() + length)) + let og_len = buf.len(); + let (channel_message, buf) = ChannelMessage::decode_from_channel_and_message(buf)?; + Ok((vec![channel_message], og_len - buf.len())) } else { Err(io::Error::new( io::ErrorKind::InvalidData, @@ -152,92 +149,84 @@ pub(crate) fn decode_unframed_channel_messages( } } -fn prencode_channel_messages( - messages: &[ChannelMessage], - state: &mut State, -) -> Result { - match messages.len().cmp(&1) { - std::cmp::Ordering::Less => {} +fn prencode_channel_messages(messages: &[ChannelMessage]) -> Result { + Ok(match messages.len().cmp(&1) { + std::cmp::Ordering::Less => 0, std::cmp::Ordering::Equal => { if let Message::Open(_) = &messages[0].message { // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; + 2 + &messages[0].encoded_len()? } else if let Message::Close(_) = &messages[0].message { // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; + 2 + &messages[0].encoded_len()? } else { - state.preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; + messages[0].channel.encoded_size()? + messages[0].encoded_size()? } } std::cmp::Ordering::Greater => { // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; + let mut out = 2; let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; + out += current_channel.encoded_size()?; for message in messages.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel - state.add_end(1)?; - state.preencode(&message.channel)?; + out += 1 + message.channel.encoded_size()?; current_channel = message.channel; } let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; + out += message.encoded_size()? + message_length; } + out } - }; - Ok(state.end()) + }) } impl Encoder for Vec { fn encoded_len(&self) -> Result { - let mut state = State::new(); - Ok(prencode_channel_messages(self, &mut state)? + UINT24_HEADER_LEN) + Ok(prencode_channel_messages(self)? + UINT24_HEADER_LEN) } #[instrument(skip_all)] fn encoder_encode(&self, buf: &mut [u8]) -> Result { - let mut state = State::new(); - let body_len = prencode_channel_messages(self, &mut state)?; - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); + let body_len = prencode_channel_messages(self)?; + let mut u24_bytes = [0, 0, 0]; + write_uint24_le(body_len, u24_bytes.as_mut_slice()); + let mut buf = write_array(&u24_bytes, buf)?; + // skip the u24 we just wrote match self.len().cmp(&1) { std::cmp::Ordering::Less => {} std::cmp::Ordering::Equal => { + trace!("Encoding single ChannelMessage {}", self[0]); if let Message::Open(_) = &self[0].message { // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + buf = write_array(&[0, 1], buf)?; + self[0].encode(buf)?; } else if let Message::Close(_) = &self[0].message { // This is a special case with 0x00, 0x03 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + buf = write_array(&[0, 3], buf)?; + self[0].encode(buf)?; } else { - state.encode(&self[0].channel, buf)?; - state.add_start(self[0].encode(&mut buf[state.start()..])?)?; + self[0].encode(buf)?; } } std::cmp::Ordering::Greater => { // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; + buf = write_array(&[0, 0], buf)?; let mut current_channel: u64 = self[0].channel; - state.encode(¤t_channel, buf)?; + buf = current_channel.encode(buf)?; for message in self.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; + buf = write_array(&[0], buf)?; + buf = message.channel.encode(buf)?; current_channel = message.channel; } let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; + buf = (message_length as u32).encode(buf)?; + buf = message.encode(buf)?; } } } @@ -265,6 +254,113 @@ pub enum Message { LocalSignal((String, Vec)), } +macro_rules! message_from { + ($($val:ident),+) => { + $( + impl From<$val> for Message { + fn from(value: $val) -> Self { + Message::$val(value) + } + } + )* + } +} +message_from!( + Open, + Close, + Synchronize, + Request, + Cancel, + Data, + NoData, + Want, + Unwant, + Bitfield, + Range, + Extension +); + +macro_rules! decode_message { + ($type:ty, $buf:expr) => {{ + let (x, rest) = <$type>::decode($buf)?; + (Message::from(x), rest) + }}; +} + +impl CompactEncoding for Message { + fn encoded_size(&self) -> Result { + let typ_size = if let Self::Open(_) | Self::Close(_) = &self { + 0 + } else { + self.typ().encoded_size()? + }; + let msg_size = match self { + Self::LocalSignal(_) => Ok(0), + Self::Open(x) => x.encoded_size(), + Self::Close(x) => x.encoded_size(), + Self::Synchronize(x) => x.encoded_size(), + Self::Request(x) => x.encoded_size(), + Self::Cancel(x) => x.encoded_size(), + Self::Data(x) => x.encoded_size(), + Self::NoData(x) => x.encoded_size(), + Self::Want(x) => x.encoded_size(), + Self::Unwant(x) => x.encoded_size(), + Self::Bitfield(x) => x.encoded_size(), + Self::Range(x) => x.encoded_size(), + Self::Extension(x) => x.encoded_size(), + }?; + Ok(typ_size + msg_size) + } + + #[instrument(skip_all, fields(name = self.name()))] + fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + let rest = if let Self::Open(_) | Self::Close(_) = &self { + buffer + } else { + self.typ().encode(buffer)? + }; + match self { + Self::Open(x) => x.encode(rest), + Self::Close(x) => x.encode(rest), + Self::Synchronize(x) => x.encode(rest), + Self::Request(x) => x.encode(rest), + Self::Cancel(x) => x.encode(rest), + Self::Data(x) => x.encode(rest), + Self::NoData(x) => x.encode(rest), + Self::Want(x) => x.encode(rest), + Self::Unwant(x) => x.encode(rest), + Self::Bitfield(x) => x.encode(rest), + Self::Range(x) => x.encode(rest), + Self::Extension(x) => x.encode(rest), + Self::LocalSignal(_) => unimplemented!("do not encode LocalSignal"), + } + } + + fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> + where + Self: Sized, + { + let (typ, rest) = u64::decode(buffer)?; + Ok(match typ { + 0 => decode_message!(Synchronize, rest), + 1 => decode_message!(Request, rest), + 2 => decode_message!(Cancel, rest), + 3 => decode_message!(Data, rest), + 4 => decode_message!(NoData, rest), + 5 => decode_message!(Want, rest), + 6 => decode_message!(Unwant, rest), + 7 => decode_message!(Bitfield, rest), + 8 => decode_message!(Range, rest), + 9 => decode_message!(Extension, rest), + _ => { + return Err(EncodingError::new( + EncodingErrorKind::InvalidData, + &format!("Invalid message type to decode: {typ}"), + )) + } + }) + } +} impl Message { /// Wire type of this message. pub(crate) fn typ(&self) -> u64 { @@ -282,71 +378,23 @@ impl Message { value => unimplemented!("{} does not have a type", value), } } - - /// Decode a message from a buffer based on type. - pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> { - let mut state = HypercoreState::from_buffer(buf); - let message = match typ { - 0 => Ok(Self::Synchronize((*state).decode(buf)?)), - 1 => Ok(Self::Request(state.decode(buf)?)), - 2 => Ok(Self::Cancel((*state).decode(buf)?)), - 3 => Ok(Self::Data(state.decode(buf)?)), - 4 => Ok(Self::NoData((*state).decode(buf)?)), - 5 => Ok(Self::Want((*state).decode(buf)?)), - 6 => Ok(Self::Unwant((*state).decode(buf)?)), - 7 => Ok(Self::Bitfield((*state).decode(buf)?)), - 8 => Ok(Self::Range((*state).decode(buf)?)), - 9 => Ok(Self::Extension((*state).decode(buf)?)), - _ => Err(EncodingError::new( - EncodingErrorKind::InvalidData, - &format!("Invalid message type to decode: {typ}"), - )), - }?; - Ok((message, state.start())) - } - - /// Pre-encodes a message to state, returns length - pub(crate) fn preencode(&self, state: &mut HypercoreState) -> Result { + /// Get the name of the message + pub fn name(&self) -> &'static str { match self { - Self::Open(ref message) => state.0.preencode(message)?, - Self::Close(ref message) => state.0.preencode(message)?, - Self::Synchronize(ref message) => state.0.preencode(message)?, - Self::Request(ref message) => state.preencode(message)?, - Self::Cancel(ref message) => state.0.preencode(message)?, - Self::Data(ref message) => state.preencode(message)?, - Self::NoData(ref message) => state.0.preencode(message)?, - Self::Want(ref message) => state.0.preencode(message)?, - Self::Unwant(ref message) => state.0.preencode(message)?, - Self::Bitfield(ref message) => state.0.preencode(message)?, - Self::Range(ref message) => state.0.preencode(message)?, - Self::Extension(ref message) => state.0.preencode(message)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.end()) - } - - /// Encodes a message to a given buffer, using preencoded state, results size - pub(crate) fn encode( - &self, - state: &mut HypercoreState, - buf: &mut [u8], - ) -> Result { - match self { - Self::Open(ref message) => state.0.encode(message, buf)?, - Self::Close(ref message) => state.0.encode(message, buf)?, - Self::Synchronize(ref message) => state.0.encode(message, buf)?, - Self::Request(ref message) => state.encode(message, buf)?, - Self::Cancel(ref message) => state.0.encode(message, buf)?, - Self::Data(ref message) => state.encode(message, buf)?, - Self::NoData(ref message) => state.0.encode(message, buf)?, - Self::Want(ref message) => state.0.encode(message, buf)?, - Self::Unwant(ref message) => state.0.encode(message, buf)?, - Self::Bitfield(ref message) => state.0.encode(message, buf)?, - Self::Range(ref message) => state.0.encode(message, buf)?, - Self::Extension(ref message) => state.0.encode(message, buf)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.start()) + Message::Open(_) => "Open", + Message::Close(_) => "Close", + Message::Synchronize(_) => "Synchronize", + Message::Request(_) => "Request", + Message::Cancel(_) => "Cancel", + Message::Data(_) => "Data", + Message::NoData(_) => "NoData", + Message::Want(_) => "Want", + Message::Unwant(_) => "Unwant", + Message::Bitfield(_) => "Bitfield", + Message::Range(_) => "Range", + Message::Extension(_) => "Extension", + Message::LocalSignal(_) => "LocalSignal", + } } } @@ -568,37 +616,6 @@ mod tests { } } #[test] - fn boo() -> Result<(), EncodingError> { - let m = Message::Cancel(Cancel { request: 1 }); - let m = Message::Request(Request { - id: 1, - fork: 1, - block: Some(RequestBlock { - index: 5, - nodes: 10, - }), - hash: Some(RequestBlock { - index: 20, - nodes: 0, - }), - seek: Some(RequestSeek { bytes: 10 }), - upgrade: Some(RequestUpgrade { - start: 0, - length: 10, - }), - }); - let channel = rand::random::() as u64; - let channel_message = ChannelMessage::new(channel, m); - let encoded_len = channel_message.encoded_len()?; - let mut buf = vec![0u8; encoded_len]; - let rest = ::encode(&channel_message, &mut buf)?; - assert!(rest.is_empty()); - let (decoded, rest) = ::decode(&buf)?; - assert!(rest.is_empty()); - assert_eq!(decoded, channel_message); - Ok(()) - } - #[test] fn message_encode_decode() -> Result<(), EncodingError> { message_enc_dec! { Message::Synchronize(Synchronize{ From 80c29ee629a34177fa0577726ed5baddb86b9e1c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 24 Apr 2025 14:19:09 -0400 Subject: [PATCH 079/206] removing Encoder trait --- src/message/modern.rs | 74 ++++++++++++++----------------------------- 1 file changed, 24 insertions(+), 50 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index 05e37be..d3b8551 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -2,11 +2,12 @@ use crate::schema::*; use crate::util::{stat_uint24_le, write_uint24_le}; use hypercore::encoding::{ decode_usize, take_array, write_array, CompactEncoding, EncodingError, EncodingErrorKind, + VecEncodable, }; use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; -use tracing::{instrument, trace}; +use tracing::{instrument, trace, warn}; const UINT24_HEADER_LEN: usize = 3; @@ -24,24 +25,6 @@ pub(crate) trait Encoder: Sized + fmt::Debug { fn encoder_encode(&self, buf: &mut [u8]) -> Result; } -impl Encoder for &[u8] { - fn encoded_len(&self) -> Result { - Ok(self.len()) - } - - fn encoder_encode(&self, buf: &mut [u8]) -> Result { - let len = self.encoded_len()?; - if len > buf.len() { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - buf[..len].copy_from_slice(&self[..]); - Ok(len) - } -} - pub(crate) fn decode_framed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { @@ -61,12 +44,11 @@ pub(crate) fn decode_framed_channel_messages( &buf[index + header_len..index + header_len + body_len as usize], )?; if length != body_len as usize { - tracing::warn!( + warn!( "Did not know what to do with all the bytes, got {} but decoded {}. \ This may be because the peer implements a newer protocol version \ that has extra fields.", - body_len, - length + body_len, length ); } for message in msgs { @@ -150,32 +132,24 @@ pub(crate) fn decode_unframed_channel_messages( } fn prencode_channel_messages(messages: &[ChannelMessage]) -> Result { - Ok(match messages.len().cmp(&1) { - std::cmp::Ordering::Less => 0, - std::cmp::Ordering::Equal => { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - 2 + &messages[0].encoded_len()? - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - 2 + &messages[0].encoded_len()? - } else { - messages[0].channel.encoded_size()? + messages[0].encoded_size()? - } - } - std::cmp::Ordering::Greater => { - // Two intro bytes 0x00 0x00, then channel id, then lengths + Ok(match messages { + [] => 0, + [msg] => match msg.message { + Message::Open(_) | Message::Close(_) => 2 + msg.encoded_size()?, + _ => msg.encoded_size()?, + }, + msgs => { let mut out = 2; let mut current_channel: u64 = messages[0].channel; out += current_channel.encoded_size()?; - for message in messages.iter() { + for message in msgs.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel out += 1 + message.channel.encoded_size()?; current_channel = message.channel; } - let message_length = message.encoded_len()?; + let message_length = message.message.encoded_size()?; out += message.encoded_size()? + message_length; } out @@ -224,7 +198,7 @@ impl Encoder for Vec { buf = message.channel.encode(buf)?; current_channel = message.channel; } - let message_length = message.encoded_len()?; + let message_length = message.message.encoded_size()?; buf = (message_length as u32).encode(buf)?; buf = message.encode(buf)?; } @@ -535,19 +509,17 @@ impl ChannelMessage { /// Performance optimization for letting calling encoded_len() already do /// the preencode phase of compact_encoding. fn prepare_state(&self) -> Result { - Ok(if let Message::Open(_) = self.message { - // Open message doesn't have a type + Ok(match self.message { + // Open & Close message doesn't have a type byte // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - self.message.encoded_size()? - } else if let Message::Close(_) = self.message { - // Close message doesn't have a type // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - self.message.encoded_size()? - } else { + Message::Open(_) | Message::Close(_) => self.message.encoded_size()?, // The header is the channel id uint followed by message type uint // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - let typ = self.message.typ(); - typ.encoded_size()? + self.message.encoded_size()? + _ => { + let typ = self.message.typ(); + typ.encoded_size()? + self.message.encoded_size()? + } }) } } @@ -597,7 +569,8 @@ mod tests { use super::*; use hypercore::{ - DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, + encoding::to_encoded_bytes, DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, + RequestSeek, RequestUpgrade, }; macro_rules! message_enc_dec { @@ -615,6 +588,7 @@ mod tests { )* } } + #[test] fn message_encode_decode() -> Result<(), EncodingError> { message_enc_dec! { From a57a885f39c0568b5c49eb17ceef5d18246603ed Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 24 Apr 2025 16:05:12 -0400 Subject: [PATCH 080/206] wip impl VecEncodable for CompactEncoding --- src/message/modern.rs | 136 +++++++++++++++++++++++++++++------------- src/util.rs | 1 + 2 files changed, 94 insertions(+), 43 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index d3b8551..bdc6164 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -1,8 +1,8 @@ use crate::schema::*; use crate::util::{stat_uint24_le, write_uint24_le}; use hypercore::encoding::{ - decode_usize, take_array, write_array, CompactEncoding, EncodingError, EncodingErrorKind, - VecEncodable, + decode_usize, take_array, take_array_mut, write_array, CompactEncoding, EncodingError, + EncodingErrorKind, VecEncodable, }; use pretty_hash::fmt as pretty_fmt; use std::fmt; @@ -10,6 +10,10 @@ use std::io; use tracing::{instrument, trace, warn}; const UINT24_HEADER_LEN: usize = 3; +const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; +const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0, 3]; +const MULTI_MESSAGE_PREFIX: [u8; 2] = [0, 0]; +const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0]; /// Encode data into a buffer. /// @@ -64,12 +68,12 @@ pub(crate) fn decode_framed_channel_messages( } Ok((combined_messages, index)) } -// bad name bc it returns many. More like, decode unframed channel messages pub(crate) fn decode_unframed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { let og_len = buf.len(); if og_len >= 3 && buf[0] == 0x00 { + // batch of NOT open/close messages if buf[1] == 0x00 { let (_, mut buf) = take_array::<2>(buf)?; // Batch of messages @@ -157,6 +161,19 @@ fn prencode_channel_messages(messages: &[ChannelMessage]) -> Result Result<&mut [u8], EncodingError> { + let (header, rest) = take_array_mut::(buf)?; + write_uint24_le(n, header); + Ok(rest) +} + +/// decode a u24 from `buffer` as a `usize` +fn decode_u24(buffer: &[u8]) -> Result<(usize, &[u8]), EncodingError> { + let (u24_bytes, rest) = take_array::(buffer)?; + let (_, out) = stat_uint24_le(&u24_bytes).expect("input garunteed to be long enough"); + Ok((out as usize, rest)) +} + impl Encoder for Vec { fn encoded_len(&self) -> Result { Ok(prencode_channel_messages(self)? + UINT24_HEADER_LEN) @@ -165,9 +182,7 @@ impl Encoder for Vec { #[instrument(skip_all)] fn encoder_encode(&self, buf: &mut [u8]) -> Result { let body_len = prencode_channel_messages(self)?; - let mut u24_bytes = [0, 0, 0]; - write_uint24_le(body_len, u24_bytes.as_mut_slice()); - let mut buf = write_array(&u24_bytes, buf)?; + let mut buf = checked_write_uint24_le(body_len, buf)?; // skip the u24 we just wrote match self.len().cmp(&1) { std::cmp::Ordering::Less => {} @@ -505,23 +520,6 @@ impl ChannelMessage { let (message, buf) = ::decode(buf)?; Ok((Self { channel, message }, buf)) } - - /// Performance optimization for letting calling encoded_len() already do - /// the preencode phase of compact_encoding. - fn prepare_state(&self) -> Result { - Ok(match self.message { - // Open & Close message doesn't have a type byte - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - Message::Open(_) | Message::Close(_) => self.message.encoded_size()?, - // The header is the channel id uint followed by message type uint - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - _ => { - let typ = self.message.typ(); - typ.encoded_size()? + self.message.encoded_size()? - } - }) - } } /// NB: currently this is just for a standalone channel message. ChannelMessages in a vec decode & @@ -543,24 +541,77 @@ impl CompactEncoding for ChannelMessage { ChannelMessage::decode_from_channel_and_message(buffer) } } -impl Encoder for ChannelMessage { - fn encoded_len(&self) -> Result { - self.prepare_state() + +impl VecEncodable for ChannelMessage { + fn vec_encoded_size(vec: &[Self]) -> Result + where + Self: Sized, + { + Ok(prencode_channel_messages(vec)? + UINT24_HEADER_LEN) } - #[instrument(skip_all)] - fn encoder_encode(&self, buf: &mut [u8]) -> Result { - let before = buf.len(); - let rest = if let Message::Open(_) = self.message { - // Open message is different in that the type byte is missing - ::encode(&self.message, buf)? - } else if let Message::Close(_) = self.message { - // Close message is different in that the type byte is missing - ::encode(&self.message, buf)? - } else { - ::encode(&self.message, buf)? - }; - Ok(before - rest.len()) + fn vec_encode<'a>(vec: &[Self], buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> + where + Self: Sized, + { + let body_len = prencode_channel_messages(&vec)?; + let mut buffer = checked_write_uint24_le(body_len, buffer)?; + match vec { + [] => Ok(buffer), + [msg] => { + buffer = match msg.message { + Message::Open(_) => write_array(&OPEN_MESSAGE_PREFIX, buffer)?, + Message::Close(_) => write_array(&CLOSE_MESSAGE_PREFIX, buffer)?, + _ => msg.channel.encode(buffer)?, + }; + msg.message.encode(buffer) + } + msgs => { + buffer = write_array(&MULTI_MESSAGE_PREFIX, buffer)?; + let mut current_channel: u64 = msgs[0].channel; + buffer = current_channel.encode(buffer)?; + for msg in msgs { + if msg.channel != current_channel { + buffer = write_array(&CHANNEL_CHANGE_SEPERATOR, buffer)?; + buffer = msg.channel.encode(buffer)?; + current_channel = msg.channel; + } + let msg_len = msg.message.encoded_size()?; + buffer = (msg_len as u32).encode(buffer)?; + buffer = msg.message.encode(buffer)?; + } + Ok(buffer) + } + } + } + + fn vec_decode(buffer: &[u8]) -> Result<(Vec, &[u8]), EncodingError> + where + Self: Sized, + { + let mut index = 0; + let mut combined_messages: Vec = vec![]; + while index < buffer.len() { + // There might be zero bytes in between, and with LE, the next message will + // start with a non-zero + if buffer[index] == 0 { + index += 1; + continue; + } + let (frame_len, next_frame_start) = decode_u24(&buffer[index..])?; + let (msgs, length) = decode_unframed_channel_messages(&next_frame_start[..frame_len]) + .map_err(|e| EncodingError::external(&format!("{e}")))?; + if length != frame_len { + warn!( + "Did not know what to do with all the bytes, got {frame_len} but decoded {length}. \ + This may be because the peer implements a newer protocol version \ + that has extra fields.", + ); + } + combined_messages.extend(msgs); + index += UINT24_HEADER_LEN + frame_len; + } + todo!() } } @@ -569,8 +620,7 @@ mod tests { use super::*; use hypercore::{ - encoding::to_encoded_bytes, DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, - RequestSeek, RequestUpgrade, + DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, }; macro_rules! message_enc_dec { @@ -578,8 +628,8 @@ mod tests { $( let channel = rand::random::() as u64; let channel_message = ChannelMessage::new(channel, $msg); - let encoded_len = channel_message.encoded_len()?; - let mut buf = vec![0u8; encoded_len]; + let encoded_size = channel_message.encoded_size()?; + let mut buf = vec![0u8; encoded_size]; let rest = ::encode(&channel_message, &mut buf)?; assert!(rest.is_empty()); let (decoded, rest) = ::decode(&buf)?; diff --git a/src/util.rs b/src/util.rs index 21e4c75..579a0fd 100644 --- a/src/util.rs +++ b/src/util.rs @@ -73,6 +73,7 @@ pub(crate) fn write_uint24_le(n: usize, buf: &mut [u8]) { } #[inline] +/// Read uint24 from the given `buffer` as a `u64` pub(crate) fn stat_uint24_le(buffer: &[u8]) -> Option<(usize, u64)> { if buffer.len() >= 3 { let len = From 564c0060c6fbcdd52b59108166562a3535bab211 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:28:17 -0400 Subject: [PATCH 081/206] Only use ChanMsg::channel when not Open & Close --- src/message/modern.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/message/modern.rs b/src/message/modern.rs index bdc6164..ae0dd22 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -526,11 +526,21 @@ impl ChannelMessage { /// encode differently impl CompactEncoding for ChannelMessage { fn encoded_size(&self) -> Result { - Ok(self.channel.encoded_size()? + self.message.encoded_size()?) + let channel_size = if let Message::Open(_) | Message::Close(_) = &self.message { + 0 + } else { + self.channel.encoded_size()? + }; + + Ok(channel_size + self.message.encoded_size()?) } fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { - let rest = self.channel.encode(buffer)?; + let rest = if let Message::Open(_) | Message::Close(_) = &self.message { + buffer + } else { + self.channel.encode(buffer)? + }; ::encode(&self.message, rest) } From 861ee29bd60254128681f7860d162e88880a604f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:28:55 -0400 Subject: [PATCH 082/206] Add #[instrument] --- src/message/modern.rs | 1 + src/schema.rs | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/src/message/modern.rs b/src/message/modern.rs index ae0dd22..23524b2 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -456,6 +456,7 @@ impl ChannelMessage { /// /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it + #[instrument(skip_all, err)] pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { let og_len = buf.len(); if og_len <= 5 { diff --git a/src/schema.rs b/src/schema.rs index 41ae6aa..c58a40b 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -6,6 +6,7 @@ use hypercore::{ decode, DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, RequestUpgrade, }; +use tracing::instrument; /// Open message #[derive(Debug, Clone, PartialEq)] @@ -21,6 +22,7 @@ pub struct Open { } impl CompactEncoding for Open { + #[instrument(skip_all, ret, err)] fn encoded_size(&self) -> Result { let out = sum_encoded_size!(self.channel, self.protocol, self.discovery_key); if self.capability.is_some() { @@ -33,6 +35,7 @@ impl CompactEncoding for Open { Ok(out) } + #[instrument(skip_all)] fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { let rest = map_encode!(buffer, self.channel, self.protocol, self.discovery_key); if let Some(cap) = &self.capability { @@ -42,6 +45,7 @@ impl CompactEncoding for Open { Ok(rest) } + #[instrument(skip_all, err)] fn decode(buffer: &[u8]) -> Result<(Self, &[u8]), EncodingError> where Self: Sized, From a62ce55b58cfef5000e88f180930d66bb2a63651 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:33:38 -0400 Subject: [PATCH 083/206] rm old stuff --- src/message/old.rs | 813 -------------------------------------------- src/protocol/old.rs | 706 -------------------------------------- src/reader.rs | 246 -------------- src/writer.rs | 198 ----------- 4 files changed, 1963 deletions(-) delete mode 100644 src/message/old.rs delete mode 100644 src/protocol/old.rs delete mode 100644 src/reader.rs delete mode 100644 src/writer.rs diff --git a/src/message/old.rs b/src/message/old.rs deleted file mode 100644 index d4afd64..0000000 --- a/src/message/old.rs +++ /dev/null @@ -1,813 +0,0 @@ -use crate::schema::*; -use crate::util::{stat_uint24_le, write_uint24_le}; -use hypercore::encoding::{ - CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State, -}; -use pretty_hash::fmt as pretty_fmt; -use std::fmt; -use std::io; - -/// The type of a data frame. -#[derive(Debug, Clone, PartialEq)] -pub(crate) enum FrameType { - Raw, - Message, -} - -/// Encode data into a buffer. -/// -/// This trait is implemented on data frames and their components -/// (channel messages, messages, and individual message types through prost). -pub(crate) trait Encoder: Sized + fmt::Debug { - /// Calculates the length that the encoded message needs. - fn encoded_len(&mut self) -> Result; - - /// Encodes the message to a buffer. - /// - /// An error will be returned if the buffer does not have sufficient capacity. - fn encode(&mut self, buf: &mut [u8]) -> Result; -} - -impl Encoder for &[u8] { - fn encoded_len(&mut self) -> Result { - Ok(self.len()) - } - - fn encode(&mut self, buf: &mut [u8]) -> Result { - let len = self.encoded_len()?; - if len > buf.len() { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - buf[..len].copy_from_slice(&self[..]); - Ok(len) - } -} - -/// A frame of data, either a buffer or a message. -#[derive(Clone, PartialEq)] -pub(crate) enum Frame { - /// A raw batch binary buffer. Used in the handshaking phase. - RawBatch(Vec>), - /// Message batch, containing one or more channel messsages. Used for everything after the handshake. - MessageBatch(Vec), -} - -impl fmt::Debug for Frame { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()), - Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"), - } - } -} - -impl From for Frame { - fn from(m: ChannelMessage) -> Self { - Self::MessageBatch(vec![m]) - } -} - -impl From> for Frame { - fn from(m: Vec) -> Self { - Self::RawBatch(vec![m]) - } -} - -impl Frame { - /// Decodes a frame from a buffer containing multiple concurrent messages. - pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Raw => { - let mut index = 0; - let mut raw_batch: Vec> = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - raw_batch.push( - buf[index + header_len..index + header_len + body_len as usize] - .to_vec(), - ); - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in raw batch", - )); - } - } - Ok(Frame::RawBatch(raw_batch)) - } - FrameType::Message => { - let mut index = 0; - let mut combined_messages: Vec = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - let (frame, length) = Self::decode_message( - &buf[index + header_len..index + header_len + body_len as usize], - )?; - if length != body_len as usize { - tracing::warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - body_len, - length - ); - } - if let Frame::MessageBatch(messages) = frame { - for message in messages { - combined_messages.push(message); - } - } else { - unreachable!("Can not get Raw messages"); - } - index += header_len + body_len as usize; - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in multi-message chunk", - )); - } - } - Ok(Frame::MessageBatch(combined_messages)) - } - } - } - - /// Decode a frame from a buffer. - pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result { - match frame_type { - FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])), - FrameType::Message => { - let (frame, _) = Self::decode_message(buf)?; - Ok(frame) - } - } - } - - fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> { - // buffer length >= 3 or more and starts with 0 is message batch - if buf.len() >= 3 && buf[0] == 0x00 { - if buf[1] == 0x00 { - // Batch of messages - let mut messages: Vec = vec![]; - let mut state = State::new_with_start_and_end(2, buf.len()); - - // First, there is the original channel - let mut current_channel: u64 = state.decode(buf)?; - while state.start() < state.end() { - // Length of the message is inbetween here - let channel_message_length: usize = state.decode(buf)?; - if state.start() + channel_message_length > state.end() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - format!( - "received invalid message length, {} + {} > {}", - state.start(), - channel_message_length, - state.end() - ), - )); - } - // Then the actual message - let (channel_message, _) = ChannelMessage::decode( - &buf[state.start()..state.start() + channel_message_length], - current_channel, - )?; - messages.push(channel_message); - state.add_start(channel_message_length)?; - // After that, if there is an extra 0x00, that means the channel - // changed. This works because of LE encoding, and channels starting - // from the index 1. - if state.start() < state.end() && buf[state.start()] == 0x00 { - state.add_start(1)?; - current_channel = state.decode(buf)?; - } - } - Ok((Frame::MessageBatch(messages), state.start())) - } else if buf[1] == 0x01 { - // Open message - let (channel_message, length) = ChannelMessage::decode_open_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else if buf[1] == 0x03 { - // Close message - let (channel_message, length) = ChannelMessage::decode_close_message(&buf[2..])?; - Ok((Frame::MessageBatch(vec![channel_message]), length + 2)) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid special message", - )) - } - } else if buf.len() >= 2 { - // len >= and - // Single message - let mut state = State::from_buffer(buf); - let channel: u64 = state.decode(buf)?; - let (channel_message, length) = ChannelMessage::decode(&buf[state.start()..], channel)?; - Ok(( - Frame::MessageBatch(vec![channel_message]), - state.start() + length, - )) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), - )) - } - } - - fn preencode(&mut self, state: &mut State) -> Result { - match self { - Self::RawBatch(raw_batch) => { - for raw in raw_batch { - state.add_end(raw.as_slice().encoded_len()?)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(messages) => { - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.add_end(2 + &messages[0].encoded_len()?)?; - } else { - (*state).preencode(&messages[0].channel)?; - state.add_end(messages[0].encoded_len()?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.add_end(2)?; - let mut current_channel: u64 = messages[0].channel; - state.preencode(¤t_channel)?; - for message in messages.iter_mut() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.add_end(1)?; - state.preencode(&message.channel)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.preencode(&message_length)?; - state.add_end(message_length)?; - } - } - } - } - Ok(state.end()) - } -} - -impl Encoder for Frame { - fn encoded_len(&mut self) -> Result { - let body_len = self.preencode(&mut State::new())?; - match self { - Self::RawBatch(_) => Ok(body_len), - Self::MessageBatch(_) => Ok(3 + body_len), - } - } - - fn encode(&mut self, buf: &mut [u8]) -> Result { - let mut state = State::new(); - let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 }; - let body_len = self.preencode(&mut state)?; - let len = body_len + header_len; - if buf.len() < len { - return Err(EncodingError::new( - EncodingErrorKind::Overflow, - &format!("Length does not fit buffer, {} > {}", len, buf.len()), - )); - } - match self { - Self::RawBatch(ref raw_batch) => { - for raw in raw_batch { - raw.as_slice().encode(buf)?; - } - } - #[allow(clippy::comparison_chain)] - Self::MessageBatch(ref mut messages) => { - write_uint24_le(body_len, buf); - let buf = buf.get_mut(3..).expect("Buffer should be over 3 bytes"); - if messages.len() == 1 { - if let Message::Open(_) = &messages[0].message { - // This is a special case with 0x00, 0x01 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(1_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else if let Message::Close(_) = &messages[0].message { - // This is a special case with 0x00, 0x03 intro bytes - state.encode(&(0_u8), buf)?; - state.encode(&(3_u8), buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } else { - state.encode(&messages[0].channel, buf)?; - state.add_start(messages[0].encode(&mut buf[state.start()..])?)?; - } - } else if messages.len() > 1 { - // Two intro bytes 0x00 0x00, then channel id, then lengths - state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?; - let mut current_channel: u64 = messages[0].channel; - state.encode(¤t_channel, buf)?; - for message in messages.iter_mut() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - state.encode(&(0_u8), buf)?; - state.encode(&message.channel, buf)?; - current_channel = message.channel; - } - let message_length = message.encoded_len()?; - state.encode(&message_length, buf)?; - state.add_start(message.encode(&mut buf[state.start()..])?)?; - } - } - } - }; - Ok(len) - } -} - -/// A protocol message. -#[derive(Debug, Clone, PartialEq)] -#[allow(missing_docs)] -pub enum Message { - Open(Open), - Close(Close), - Synchronize(Synchronize), - Request(Request), - Cancel(Cancel), - Data(Data), - NoData(NoData), - Want(Want), - Unwant(Unwant), - Bitfield(Bitfield), - Range(Range), - Extension(Extension), - /// A local signalling message never sent over the wire - LocalSignal((String, Vec)), -} - -impl Message { - /// Wire type of this message. - pub(crate) fn typ(&self) -> u64 { - match self { - Self::Synchronize(_) => 0, - Self::Request(_) => 1, - Self::Cancel(_) => 2, - Self::Data(_) => 3, - Self::NoData(_) => 4, - Self::Want(_) => 5, - Self::Unwant(_) => 6, - Self::Bitfield(_) => 7, - Self::Range(_) => 8, - Self::Extension(_) => 9, - value => unimplemented!("{} does not have a type", value), - } - } - - /// Decode a message from a buffer based on type. - pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> { - let mut state = HypercoreState::from_buffer(buf); - let message = match typ { - 0 => Ok(Self::Synchronize((*state).decode(buf)?)), - 1 => Ok(Self::Request(state.decode(buf)?)), - 2 => Ok(Self::Cancel((*state).decode(buf)?)), - 3 => Ok(Self::Data(state.decode(buf)?)), - 4 => Ok(Self::NoData((*state).decode(buf)?)), - 5 => Ok(Self::Want((*state).decode(buf)?)), - 6 => Ok(Self::Unwant((*state).decode(buf)?)), - 7 => Ok(Self::Bitfield((*state).decode(buf)?)), - 8 => Ok(Self::Range((*state).decode(buf)?)), - 9 => Ok(Self::Extension((*state).decode(buf)?)), - _ => Err(EncodingError::new( - EncodingErrorKind::InvalidData, - &format!("Invalid message type to decode: {typ}"), - )), - }?; - Ok((message, state.start())) - } - - /// Pre-encodes a message to state, returns length - pub(crate) fn preencode(&self, state: &mut HypercoreState) -> Result { - match self { - Self::Open(ref message) => state.0.preencode(message)?, - Self::Close(ref message) => state.0.preencode(message)?, - Self::Synchronize(ref message) => state.0.preencode(message)?, - Self::Request(ref message) => state.preencode(message)?, - Self::Cancel(ref message) => state.0.preencode(message)?, - Self::Data(ref message) => state.preencode(message)?, - Self::NoData(ref message) => state.0.preencode(message)?, - Self::Want(ref message) => state.0.preencode(message)?, - Self::Unwant(ref message) => state.0.preencode(message)?, - Self::Bitfield(ref message) => state.0.preencode(message)?, - Self::Range(ref message) => state.0.preencode(message)?, - Self::Extension(ref message) => state.0.preencode(message)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.end()) - } - - /// Encodes a message to a given buffer, using preencoded state, results size - pub(crate) fn encode( - &self, - state: &mut HypercoreState, - buf: &mut [u8], - ) -> Result { - match self { - Self::Open(ref message) => state.0.encode(message, buf)?, - Self::Close(ref message) => state.0.encode(message, buf)?, - Self::Synchronize(ref message) => state.0.encode(message, buf)?, - Self::Request(ref message) => state.encode(message, buf)?, - Self::Cancel(ref message) => state.0.encode(message, buf)?, - Self::Data(ref message) => state.encode(message, buf)?, - Self::NoData(ref message) => state.0.encode(message, buf)?, - Self::Want(ref message) => state.0.encode(message, buf)?, - Self::Unwant(ref message) => state.0.encode(message, buf)?, - Self::Bitfield(ref message) => state.0.encode(message, buf)?, - Self::Range(ref message) => state.0.encode(message, buf)?, - Self::Extension(ref message) => state.0.encode(message, buf)?, - Self::LocalSignal(_) => 0, - }; - Ok(state.start()) - } -} - -impl fmt::Display for Message { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Open(msg) => write!( - f, - "Open(discovery_key: {}, capability <{}>)", - pretty_fmt(&msg.discovery_key).unwrap(), - msg.capability.as_ref().map_or(0, |c| c.len()) - ), - Self::Data(msg) => write!( - f, - "Data(request: {}, fork: {}, block: {}, hash: {}, seek: {}, upgrade: {})", - msg.request, - msg.fork, - msg.block.is_some(), - msg.hash.is_some(), - msg.seek.is_some(), - msg.upgrade.is_some(), - ), - _ => write!(f, "{:?}", &self), - } - } -} - -/// A message on a channel. -#[derive(Clone)] -pub(crate) struct ChannelMessage { - pub(crate) channel: u64, - pub(crate) message: Message, - state: Option, -} - -impl PartialEq for ChannelMessage { - fn eq(&self, other: &Self) -> bool { - self.channel == other.channel && self.message == other.message - } -} - -impl fmt::Debug for ChannelMessage { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "ChannelMessage({}, {})", self.channel, self.message) - } -} - -impl ChannelMessage { - /// Create a new message. - pub(crate) fn new(channel: u64, message: Message) -> Self { - Self { - channel, - message, - state: None, - } - } - - /// Consume self and return (channel, Message). - pub(crate) fn into_split(self) -> (u64, Message) { - (self.channel, self.message) - } - - /// Decodes an open message for a channel message from a buffer. - /// - /// Note: `buf` has to have a valid length, and without the 3 LE - /// bytes in it - pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { - if buf.len() <= 5 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "received too short Open message", - )); - } - - let mut state = State::new_with_start_and_end(0, buf.len()); - let open_msg: Open = state.decode(buf)?; - Ok(( - Self { - channel: open_msg.channel, - message: Message::Open(open_msg), - state: None, - }, - state.start(), - )) - } - - /// Decodes a close message for a channel message from a buffer. - /// - /// Note: `buf` has to have a valid length, and without the 3 LE - /// bytes in it - pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { - if buf.is_empty() { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "received too short Close message", - )); - } - let mut state = State::new_with_start_and_end(0, buf.len()); - let close_msg: Close = state.decode(buf)?; - Ok(( - Self { - channel: close_msg.channel, - message: Message::Close(close_msg), - state: None, - }, - state.start(), - )) - } - - /// Decode a normal channel message from a buffer. - /// - /// Note: `buf` has to have a valid length, and without the 3 LE - /// bytes in it - pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> { - if buf.len() <= 1 { - return Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "received empty message", - )); - } - let mut state = State::from_buffer(buf); - let typ: u64 = state.decode(buf)?; - let (message, length) = Message::decode(&buf[state.start()..], typ)?; - Ok(( - Self { - channel, - message, - state: None, - }, - state.start() + length, - )) - } - - /// Performance optimization for letting calling encoded_len() already do - /// the preencode phase of compact_encoding. - fn prepare_state(&mut self) -> Result<(), EncodingError> { - if self.state.is_none() { - let state = if let Message::Open(_) = self.message { - // Open message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L41 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else if let Message::Close(_) = self.message { - // Close message doesn't have a type - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L162 - let mut state = HypercoreState::new(); - self.message.preencode(&mut state)?; - state - } else { - // The header is the channel id uint followed by message type uint - // https://github.com/mafintosh/protomux/blob/43d5192f31e7a7907db44c11afef3195b7797508/index.js#L179 - let mut state = HypercoreState::new(); - let typ = self.message.typ(); - (*state).preencode(&typ)?; - self.message.preencode(&mut state)?; - state - }; - self.state = Some(state); - } - Ok(()) - } -} - -impl Encoder for ChannelMessage { - fn encoded_len(&mut self) -> Result { - self.prepare_state()?; - Ok(self.state.as_ref().unwrap().end()) - } - - fn encode(&mut self, buf: &mut [u8]) -> Result { - self.prepare_state()?; - let state = self.state.as_mut().unwrap(); - if let Message::Open(_) = self.message { - // Open message is different in that the type byte is missing - self.message.encode(state, buf)?; - } else if let Message::Close(_) = self.message { - // Close message is different in that the type byte is missing - self.message.encode(state, buf)?; - } else { - let typ = self.message.typ(); - state.0.encode(&typ, buf)?; - self.message.encode(state, buf)?; - } - Ok(state.start()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use hypercore::{ - DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, - }; - - macro_rules! message_enc_dec { - ($( $msg:expr ),*) => { - $( - let channel = rand::random::() as u64; - let mut channel_message = ChannelMessage::new(channel, $msg); - let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length"); - let mut buf = vec![0u8; encoded_len]; - let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message"); - let decoded = ChannelMessage::decode(&buf[..n], channel).expect("Failed to decode message").0.into_split(); - assert_eq!(channel, decoded.0); - assert_eq!($msg, decoded.1); - )* - } - } - #[test] - fn frame_encode_decode() -> std::io::Result<()> { - let msg = Message::Synchronize(Synchronize { - fork: 0, - can_upgrade: true, - downloading: true, - uploading: true, - length: 5, - remote_length: 0, - }); - - let channel = rand::random::() as u64; - let channel_message = ChannelMessage::new(channel, msg); - - let mut frame = Frame::from(channel_message); - let mut buf = vec![0; frame.encoded_len()?]; - frame.encode(&mut buf)?; - let res_frame = Frame::decode_multiple(&buf, &FrameType::Message)?; - assert_eq!(res_frame, frame); - Ok(()) - } - #[test] - fn frame_encode_decode_bar() -> std::io::Result<()> { - let msg = Message::Synchronize(Synchronize { - fork: 0, - can_upgrade: true, - downloading: true, - uploading: true, - length: 5, - remote_length: 0, - }); - - //let channel = rand::random::() as u64; - let channel = 42; - let channel_message = ChannelMessage::new(channel, msg); - - let mut frame = Frame::from(channel_message.clone()); - - let mut fbuf = vec![0; frame.encoded_len()?]; - - frame.encode(&mut fbuf)?; - - let fres = Frame::decode_multiple(&fbuf, &FrameType::Message)?; - assert_eq!(fres, frame); - //assert_eq!(cres, cmvec); - //println!("REG frame buf\t{frame_buf:02X?}"); - //let res_frame = Frame::decode(&frame_buf, &FrameType::Message)?; - //dbg!(res_frame); - //let res_frame = Frame::decode_multiple(&frame_buf, &FrameType::Message)?; - //dbg!(res_frame); - - //let mut vec_frame_buf = vec![0; vec_frame.encoded_len()?]; - //vec_frame.encode(&mut vec_frame_buf)?; - - //assert_eq!(vec_frame_buf, frame_buf); - //println!("VEC frame buf\t{vec_frame_buf:02X?}"); - - //let res_frame = Frame::decode(&vec_frame_buf, &FrameType::Message)?; - //dbg!(res_frame); - //let res_frame = Frame::decode_multiple(&vec_frame_buf, &FrameType::Message)?; - //dbg!(&res_frame); - - //let (msg, _len) = decode_channel_messages(&vec_frame_buf)?; - //assert_eq!(msg, vec![channel_message]); - - //assert_eq!(res_frame, frame); - Ok(()) - } - - #[test] - fn message_encode_decode() { - message_enc_dec! { - Message::Synchronize(Synchronize{ - fork: 0, - can_upgrade: true, - downloading: true, - uploading: true, - length: 5, - remote_length: 0, - }), - Message::Request(Request { - id: 1, - fork: 1, - block: Some(RequestBlock { - index: 5, - nodes: 10, - }), - hash: Some(RequestBlock { - index: 20, - nodes: 0 - }), - seek: Some(RequestSeek { - bytes: 10 - }), - upgrade: Some(RequestUpgrade { - start: 0, - length: 10 - }) - }), - Message::Cancel(Cancel { - request: 1, - }), - Message::Data(Data{ - request: 1, - fork: 5, - block: Some(DataBlock { - index: 5, - nodes: vec![Node::new(1, vec![0x01; 32], 100)], - value: vec![0xFF; 10] - }), - hash: Some(DataHash { - index: 20, - nodes: vec![Node::new(2, vec![0x02; 32], 200)], - }), - seek: Some(DataSeek { - bytes: 10, - nodes: vec![Node::new(3, vec![0x03; 32], 300)], - }), - upgrade: Some(DataUpgrade { - start: 0, - length: 10, - nodes: vec![Node::new(4, vec![0x04; 32], 400)], - additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)], - signature: vec![0xAB; 32] - }) - }), - Message::NoData(NoData { - request: 2, - }), - Message::Want(Want { - start: 0, - length: 100, - }), - Message::Unwant(Unwant { - start: 10, - length: 2, - }), - Message::Bitfield(Bitfield { - start: 20, - bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF], - }), - Message::Range(Range { - drop: true, - start: 12345, - length: 100000 - }), - Message::Extension(Extension { - name: "custom_extension/v1/open".to_string(), - message: vec![0x44, 20] - }) - }; - } -} diff --git a/src/protocol/old.rs b/src/protocol/old.rs deleted file mode 100644 index 20c9064..0000000 --- a/src/protocol/old.rs +++ /dev/null @@ -1,706 +0,0 @@ -use async_channel::{Receiver, Sender}; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use futures_lite::stream::Stream; -use futures_timer::Delay; -use std::collections::VecDeque; -use std::convert::TryInto; -use std::fmt; -use std::future::Future; -use std::io::{self, Error, ErrorKind, Result}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; -use tracing::{instrument, trace}; - -use crate::channels::{Channel, ChannelMap}; -use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; -use crate::crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}; -use crate::message::{ChannelMessage, Frame, FrameType, Message}; -use crate::reader::ReadState; -use crate::schema::*; -use crate::util::{map_channel_err, pretty_hash}; -use crate::writer::WriteState; - -macro_rules! return_error { - ($msg:expr) => { - if let Err(e) = $msg { - return Poll::Ready(Err(e)); - } - }; -} - -const CHANNEL_CAP: usize = 1000; -const KEEPALIVE_DURATION: Duration = Duration::from_secs(DEFAULT_KEEPALIVE as u64); - -/// Options for a Protocol instance. -#[derive(Debug)] -pub(crate) struct Options { - /// Whether this peer initiated the IO connection for this protoccol - pub(crate) is_initiator: bool, - /// Enable or disable the handshake. - /// Disabling the handshake will also disable capabilitity verification. - /// Don't disable this if you're not 100% sure you want this. - pub(crate) noise: bool, - /// Enable or disable transport encryption. - pub(crate) encrypted: bool, -} - -impl Options { - /// Create with default options. - pub(crate) fn new(is_initiator: bool) -> Self { - Self { - is_initiator, - noise: true, - encrypted: true, - } - } -} - -/// Remote public key (32 bytes). -pub(crate) type RemotePublicKey = [u8; 32]; -/// Discovery key (32 bytes). -pub type DiscoveryKey = [u8; 32]; -/// Key (32 bytes). -pub type Key = [u8; 32]; - -/// A protocol event. -#[non_exhaustive] -#[derive(PartialEq)] -pub enum Event { - /// Emitted after the handshake with the remote peer is complete. - /// This is the first event (if the handshake is not disabled). - Handshake(RemotePublicKey), - /// Emitted when the remote peer opens a channel that we did not yet open. - DiscoveryKey(DiscoveryKey), - /// Emitted when a channel is established. - Channel(Channel), - /// Emitted when a channel is closed. - Close(DiscoveryKey), - /// Convenience event to make it possible to signal the protocol from a channel. - /// See channel.signal_local() and protocol.commands().signal_local(). - LocalSignal((String, Vec)), -} - -/// A protocol command. -#[derive(Debug)] -pub enum Command { - /// Open a channel - Open(Key), - /// Close a channel by discovery key - Close(DiscoveryKey), - /// Signal locally to protocol - SignalLocal((String, Vec)), -} - -impl fmt::Debug for Event { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Event::Handshake(remote_key) => { - write!(f, "Handshake(remote_key={})", &pretty_hash(remote_key)) - } - Event::DiscoveryKey(discovery_key) => { - write!(f, "DiscoveryKey({})", &pretty_hash(discovery_key)) - } - Event::Channel(channel) => { - write!(f, "Channel({})", &pretty_hash(channel.discovery_key())) - } - Event::Close(discovery_key) => write!(f, "Close({})", &pretty_hash(discovery_key)), - Event::LocalSignal((name, data)) => { - write!(f, "LocalSignal(name={},len={})", name, data.len()) - } - } - } -} - -/// Protocol state -#[allow(clippy::large_enum_variant)] -pub(crate) enum State { - NotInitialized, - // The Handshake struct sits behind an option only so that we can .take() - // it out, it's never actually empty when in State::Handshake. - Handshake(Option), - SecretStream(Option), - Established, -} - -impl fmt::Debug for State { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - State::NotInitialized => write!(f, "NotInitialized"), - State::Handshake(_) => write!(f, "Handshaking"), - State::SecretStream(_) => write!(f, "SecretStream"), - State::Established => write!(f, "Established"), - } - } -} - -/// A Protocol stream. -pub struct Protocol { - write_state: WriteState, - read_state: ReadState, - io: IO, - state: State, - options: Options, - handshake: Option, - channels: ChannelMap, - command_rx: Receiver, - command_tx: CommandTx, - outbound_rx: Receiver>, - outbound_tx: Sender>, - keepalive: Delay, - queued_events: VecDeque, -} - -impl std::fmt::Debug for Protocol { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Protocol") - .field("write_state", &self.write_state) - .field("read_state", &self.read_state) - //.field("io", &self.io) - .field("state", &self.state) - .field("options", &self.options) - .field("handshake", &self.handshake) - .field("channels", &self.channels) - .field("command_rx", &self.command_rx) - .field("command_tx", &self.command_tx) - .field("outbound_rx", &self.outbound_rx) - .field("outbound_tx", &self.outbound_tx) - .field("keepalive", &self.keepalive) - .field("queued_events", &self.queued_events) - .finish() - } -} - -impl Protocol -where - IO: AsyncWrite + AsyncRead + Send + Unpin + 'static, -{ - /// Create a new protocol instance. - pub(crate) fn new(io: IO, options: Options) -> Self { - let (command_tx, command_rx) = async_channel::bounded(CHANNEL_CAP); - let (outbound_tx, outbound_rx): ( - Sender>, - Receiver>, - ) = async_channel::bounded(1); - Protocol { - io, - read_state: ReadState::new(), - write_state: WriteState::new(), - options, - state: State::NotInitialized, - channels: ChannelMap::new(), - handshake: None, - command_rx, - command_tx: CommandTx(command_tx), - outbound_tx, - outbound_rx, - keepalive: Delay::new(Duration::from_secs(DEFAULT_KEEPALIVE as u64)), - queued_events: VecDeque::new(), - } - } - - /// Whether this protocol stream initiated the underlying IO connection. - pub fn is_initiator(&self) -> bool { - self.options.is_initiator - } - - /// Get your own Noise public key. - /// - /// Empty before the handshake completed. - pub fn public_key(&self) -> Option<&[u8]> { - match &self.handshake { - None => None, - Some(handshake) => Some(handshake.local_pubkey.as_slice()), - } - } - - /// Get the remote's Noise public key. - /// - /// Empty before the handshake completed. - pub fn remote_public_key(&self) -> Option<&[u8]> { - match &self.handshake { - None => None, - Some(handshake) => Some(handshake.remote_pubkey.as_slice()), - } - } - - /// Get a sender to send commands. - pub fn commands(&self) -> CommandTx { - self.command_tx.clone() - } - - /// Give a command to the protocol. - #[instrument(skip(self))] - pub async fn command(&mut self, command: Command) -> Result<()> { - self.command_tx.send(command).await - } - - /// Open a new protocol channel. - /// - /// Once the other side proofed that it also knows the `key`, the channel is emitted as - /// `Event::Channel` on the protocol event stream. - #[instrument(skip(self))] - pub async fn open(&mut self, key: Key) -> Result<()> { - self.command_tx.open(key).await - } - - /// Iterator of all currently opened channels. - pub fn channels(&self) -> impl Iterator { - self.channels.iter().map(|c| c.discovery_key()) - } - - #[instrument(skip_all, fields(initiator = ?self.is_initiator()))] - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - if let State::NotInitialized = this.state { - return_error!(this.init()); - } - - // Drain queued events first. - if let Some(event) = this.queued_events.pop_front() { - return Poll::Ready(Ok(event)); - } - - // Read and process incoming messages. - return_error!(this.poll_inbound_read(cx)); - - if let State::Established = this.state { - // Check for commands, but only once the connection is established. - return_error!(this.poll_commands(cx)); - } - - // Poll the keepalive timer. - this.poll_keepalive(cx); - - // Write everything we can write. - return_error!(this.poll_outbound_write(cx)); - - // Check if any events are enqueued. - if let Some(event) = this.queued_events.pop_front() { - Poll::Ready(Ok(event)) - } else { - Poll::Pending - } - } - - fn init(&mut self) -> Result<()> { - trace!( - "protocol Init, state {:?}, options {:?}", - self.state, - self.options - ); - match self.state { - State::NotInitialized => {} - _ => return Ok(()), - }; - - self.state = if self.options.noise { - let mut handshake = Handshake::new(self.options.is_initiator)?; - // If the handshake start returns a buffer, send it now. - if let Some(buf) = handshake.start()? { - // TODO what if this fails? or returns false - self.queue_frame_direct(buf.to_vec()).unwrap(); - } - self.read_state.set_frame_type(FrameType::Raw); - State::Handshake(Some(handshake)) - } else { - self.read_state.set_frame_type(FrameType::Message); - State::Established - }; - - Ok(()) - } - - /// Poll commands. - #[instrument(skip_all)] - fn poll_commands(&mut self, cx: &mut Context<'_>) -> Result<()> { - while let Poll::Ready(Some(command)) = Pin::new(&mut self.command_rx).poll_next(cx) { - self.on_command(command)?; - } - Ok(()) - } - - /// Poll the keepalive timer and queue a ping message if needed. - fn poll_keepalive(&mut self, cx: &mut Context<'_>) { - if Pin::new(&mut self.keepalive).poll(cx).is_ready() { - if let State::Established = self.state { - // 24 bit header for the empty message, hence the 3 - self.write_state - .queue_frame(Frame::RawBatch(vec![vec![0u8; 3]])); - } - self.keepalive.reset(KEEPALIVE_DURATION); - } - } - - fn on_outbound_message(&mut self, message: &ChannelMessage) -> bool { - // If message is close, close the local channel. - if let ChannelMessage { - channel, - message: Message::Close(_), - .. - } = message - { - self.close_local(*channel); - // If message is a LocalSignal, emit an event and return false to indicate - // this message should be filtered out. - } else if let ChannelMessage { - message: Message::LocalSignal((name, data)), - .. - } = message - { - self.queue_event(Event::LocalSignal((name.to_string(), data.to_vec()))); - return false; - } - true - } - - /// Poll for inbound messages and processs them. - #[instrument(skip_all)] - fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { - loop { - let msg = self.read_state.poll_reader(cx, &mut self.io); - match msg { - Poll::Ready(Ok(message)) => { - self.on_inbound_frame(message)?; - } - Poll::Ready(Err(e)) => return Err(e), - Poll::Pending => return Ok(()), - } - } - } - - /// Poll for outbound messages and write them. - #[instrument(skip_all)] - fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { - loop { - if let Poll::Ready(Err(e)) = self.write_state.poll_send(cx, &mut self.io) { - return Err(e); - } - // if no parking or setup in progress - if !self.write_state.can_park_frame() || !matches!(self.state, State::Established) { - return Ok(()); - } - - match Pin::new(&mut self.outbound_rx).poll_next(cx) { - Poll::Ready(Some(mut messages)) => { - if !messages.is_empty() { - messages.retain(|message| self.on_outbound_message(message)); - if !messages.is_empty() { - let frame = Frame::MessageBatch(messages); - // TODO try replacing this with queue_frame - self.write_state.park_frame(frame); - } - } - } - Poll::Ready(None) => unreachable!("Channel closed before end"), - Poll::Pending => return Ok(()), - } - } - } - - #[instrument(skip_all)] - fn on_inbound_frame(&mut self, frame: Frame) -> Result<()> { - match frame { - Frame::RawBatch(raw_batch) => { - let mut processed_state: Option = None; - for buf in raw_batch { - let state_name: String = format!("{:?}", self.state); - match self.state { - State::Handshake(_) => self.on_handshake_message(buf)?, - State::SecretStream(_) => self.on_secret_stream_message(buf)?, - State::Established => { - if let Some(processed_state) = processed_state.as_ref() { - // last state before established - let previous_state = if self.options.encrypted { - // was SecretStream if we're encrypted - State::SecretStream(None) - } else { - // or wa hasdshake if we're not encrypted - State::Handshake(None) - }; - - // if htis raw_batch included regular messages (not handshake) - // after handshake stuff - if processed_state == &format!("{previous_state:?}") { - // This is the unlucky case where the batch had two or more messages where - // the first one was correctly identified as Raw but everything - // after that should have been (decrypted and) a MessageBatch. Correct the mistake - // here post-hoc. - let buf = self.read_state.decrypt_buf(&buf)?; - let frame = Frame::decode(&buf, &FrameType::Message)?; - self.on_inbound_frame(frame)?; - continue; - } - } - unreachable!( - "May not receive raw frames in Established state" - ) - } - _ => unreachable!( - "May not receive raw frames outside of handshake or secretstream state, was {:?}", - self.state - ), - }; - if processed_state.is_none() { - processed_state = Some(state_name) - } - } - Ok(()) - } - Frame::MessageBatch(channel_messages) => match self.state { - State::Established => { - for channel_message in channel_messages { - self.on_inbound_message(channel_message)? - } - Ok(()) - } - _ => unreachable!("May not receive message batch frames when not established"), - }, - } - } - - fn on_handshake_message(&mut self, buf: Vec) -> Result<()> { - let mut handshake = match &mut self.state { - State::Handshake(handshake) => handshake.take().unwrap(), - _ => unreachable!("May not call on_handshake_message when not in Handshake state"), - }; - - if let Some(response_buf) = handshake.read(&buf)? { - self.queue_frame_direct(response_buf.to_vec()).unwrap(); - } - - if !handshake.complete() { - self.state = State::Handshake(Some(handshake)); - } else { - let handshake_result = handshake.into_result()?; - - if self.options.encrypted { - // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = EncryptCipher::from_handshake_tx(handshake_result)?; - self.state = State::SecretStream(Some(cipher)); - - // Send the secret stream init message header to the other side - self.queue_frame_direct(init_msg).unwrap(); - } else { - // Skip secret stream and go straight to Established, then notify about - // handshake - self.read_state.set_frame_type(FrameType::Message); - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - } - // Store handshake result - self.handshake = Some(handshake_result.clone()); - } - Ok(()) - } - - fn on_secret_stream_message(&mut self, buf: Vec) -> Result<()> { - let encrypt_cipher = match &mut self.state { - State::SecretStream(encrypt_cipher) => encrypt_cipher.take().unwrap(), - _ => { - unreachable!("May not call on_secret_stream_message when not in SecretStream state") - } - }; - let handshake_result = &self - .handshake - .as_ref() - .expect("Handshake result must be set before secret stream"); - let decrypt_cipher = DecryptCipher::from_handshake_rx_and_init_msg(handshake_result, &buf)?; - self.read_state.upgrade_with_decrypt_cipher(decrypt_cipher); - self.write_state.upgrade_with_encrypt_cipher(encrypt_cipher); - self.read_state.set_frame_type(FrameType::Message); - - // Lastly notify that handshake is ready and set state to established - let remote_public_key = parse_key(&handshake_result.remote_pubkey)?; - self.queue_event(Event::Handshake(remote_public_key)); - self.state = State::Established; - Ok(()) - } - #[instrument(skip_all)] - fn on_inbound_message(&mut self, channel_message: ChannelMessage) -> Result<()> { - // let channel_message = ChannelMessage::decode(buf)?; - let (remote_id, message) = channel_message.into_split(); - match message { - Message::Open(msg) => self.on_open(remote_id, msg)?, - Message::Close(msg) => self.on_close(remote_id, msg)?, - _ => self - .channels - .forward_inbound_message(remote_id as usize, message)?, - } - Ok(()) - } - - #[instrument(skip(self))] - fn on_command(&mut self, command: Command) -> Result<()> { - match command { - Command::Open(key) => self.command_open(key), - Command::Close(discovery_key) => self.command_close(discovery_key), - Command::SignalLocal((name, data)) => self.command_signal_local(name, data), - } - } - - /// Open a Channel with the given key. Adding it to our channel map - #[instrument(skip_all)] - fn command_open(&mut self, key: Key) -> Result<()> { - // Create a new channel. - let channel_handle = self.channels.attach_local(key); - // Safe because attach_local always puts Some(local_id) - let local_id = channel_handle.local_id().unwrap(); - let discovery_key = *channel_handle.discovery_key(); - - // If the channel was already opened from the remote end, verify, and if - // verification is ok, push a channel open event. - if channel_handle.is_connected() { - self.accept_channel(local_id)?; - } - - // Tell the remote end about the new channel. - let capability = self.capability(&key); - let channel = local_id as u64; - let message = Message::Open(Open { - channel, - protocol: PROTOCOL_NAME.to_string(), - discovery_key: discovery_key.to_vec(), - capability, - }); - let channel_message = ChannelMessage::new(channel, message); - self.write_state - .queue_frame(Frame::MessageBatch(vec![channel_message])); - Ok(()) - } - - fn command_close(&mut self, discovery_key: DiscoveryKey) -> Result<()> { - if self.channels.has_channel(&discovery_key) { - self.channels.remove(&discovery_key); - self.queue_event(Event::Close(discovery_key)); - } - Ok(()) - } - - fn command_signal_local(&mut self, name: String, data: Vec) -> Result<()> { - self.queue_event(Event::LocalSignal((name, data))); - Ok(()) - } - - #[instrument(skip(self))] - fn on_open(&mut self, ch: u64, msg: Open) -> Result<()> { - let discovery_key: DiscoveryKey = parse_key(&msg.discovery_key)?; - let channel_handle = - self.channels - .attach_remote(discovery_key, ch as usize, msg.capability); - - if channel_handle.is_connected() { - let local_id = channel_handle.local_id().unwrap(); - self.accept_channel(local_id)?; - } else { - self.queue_event(Event::DiscoveryKey(discovery_key)); - } - - Ok(()) - } - - #[instrument(skip(self))] - fn queue_event(&mut self, event: Event) { - self.queued_events.push_back(event); - } - - /// enequeu a buf to be sent - fn queue_frame_direct(&mut self, body: Vec) -> Result { - let mut frame = Frame::RawBatch(vec![body]); - self.write_state - .try_encode_and_enqueue_frame_for_tx(&mut frame) - } - - #[instrument(skip(self))] - fn accept_channel(&mut self, local_id: usize) -> Result<()> { - let (key, remote_capability) = self.channels.prepare_to_verify(local_id)?; - self.verify_remote_capability(remote_capability.cloned(), key)?; - let channel = self.channels.accept(local_id, self.outbound_tx.clone())?; - self.queue_event(Event::Channel(channel)); - Ok(()) - } - - fn close_local(&mut self, local_id: u64) { - if let Some(channel) = self.channels.get_local(local_id as usize) { - let discovery_key = *channel.discovery_key(); - self.channels.remove(&discovery_key); - self.queue_event(Event::Close(discovery_key)); - } - } - - fn on_close(&mut self, remote_id: u64, msg: Close) -> Result<()> { - if let Some(channel_handle) = self.channels.get_remote(remote_id as usize) { - let discovery_key = *channel_handle.discovery_key(); - // There is a possibility both sides will close at the same time, so - // the channel could be closed already, let's tolerate that. - self.channels - .forward_inbound_message_tolerate_closed(remote_id as usize, Message::Close(msg))?; - self.channels.remove(&discovery_key); - self.queue_event(Event::Close(discovery_key)); - } - Ok(()) - } - - #[instrument(skip_all)] - fn capability(&self, key: &[u8]) -> Option> { - match self.handshake.as_ref() { - Some(handshake) => handshake.capability(key), - None => None, - } - } - - fn verify_remote_capability(&self, capability: Option>, key: &[u8]) -> Result<()> { - match self.handshake.as_ref() { - Some(handshake) => handshake.verify_remote_capability(capability, key), - None => Err(Error::new( - ErrorKind::PermissionDenied, - "Missing handshake state for capability verification", - )), - } - } -} - -impl Stream for Protocol -where - IO: AsyncRead + AsyncWrite + Send + Unpin + 'static, -{ - type Item = Result; - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Protocol::poll_next(self, cx).map(Some) - } -} - -/// Send [Command](Command)s to the [Protocol](Protocol). -#[derive(Clone, Debug)] -pub struct CommandTx(Sender); - -impl CommandTx { - /// Send a protocol command - pub async fn send(&mut self, command: Command) -> Result<()> { - self.0.send(command).await.map_err(map_channel_err) - } - /// Open a protocol channel. - /// - /// The channel will be emitted on the main protocol. - pub async fn open(&mut self, key: Key) -> Result<()> { - self.send(Command::Open(key)).await - } - - /// Close a protocol channel. - pub async fn close(&mut self, discovery_key: DiscoveryKey) -> Result<()> { - self.send(Command::Close(discovery_key)).await - } - - /// Send a local signal event to the protocol. - pub async fn signal_local(&mut self, name: &str, data: Vec) -> Result<()> { - self.send(Command::SignalLocal((name.to_string(), data))) - .await - } -} - -fn parse_key(key: &[u8]) -> io::Result<[u8; 32]> { - key.try_into() - .map_err(|_e| io::Error::new(io::ErrorKind::InvalidInput, "Key must be 32 bytes long")) -} diff --git a/src/reader.rs b/src/reader.rs deleted file mode 100644 index cc80c5c..0000000 --- a/src/reader.rs +++ /dev/null @@ -1,246 +0,0 @@ -use crate::crypto::DecryptCipher; -use futures_lite::io::AsyncRead; -use futures_timer::Delay; -use std::future::Future; -use std::io::{Error, ErrorKind, Result}; -use std::pin::Pin; -use std::task::{Context, Poll}; - -use crate::constants::{DEFAULT_TIMEOUT, MAX_MESSAGE_SIZE}; -use crate::message::{Frame, FrameType}; -use crate::util::stat_uint24_le; -use std::time::Duration; - -const TIMEOUT: Duration = Duration::from_secs(DEFAULT_TIMEOUT as u64); -const READ_BUF_INITIAL_SIZE: usize = 1024 * 128; - -#[derive(Debug)] -pub(crate) struct ReadState { - /// The read buffer. - buf: Vec, - /// The start of the not-yet-processed byte range in the read buffer. - start: usize, - /// The end of the not-yet-processed byte range in the read buffer. - end: usize, - /// The logical state of the reading (either header or body). - step: Step, - /// The timeout after which the connection is closed. - timeout: Delay, - /// Optional decryption cipher. - cipher: Option, - /// The frame type to be passed to the decoder. - frame_type: FrameType, -} - -impl ReadState { - pub(crate) fn new() -> ReadState { - ReadState { - buf: vec![0u8; READ_BUF_INITIAL_SIZE], - start: 0, - end: 0, - step: Step::Header, - timeout: Delay::new(TIMEOUT), - cipher: None, - frame_type: FrameType::Raw, - } - } -} - -#[derive(Debug)] -enum Step { - Header, - Body { - header_len: usize, - body_len: usize, - }, - /// Multiple messages one after another - Batch, -} - -impl ReadState { - pub(crate) fn upgrade_with_decrypt_cipher(&mut self, decrypt_cipher: DecryptCipher) { - self.cipher = Some(decrypt_cipher); - } - - /// Decrypts a given buf with stored cipher, if present. Used to correct - /// the rare mistake that more than two messages came in where the first - /// one created the cipher, and the next one should have been decrypted - /// but wasn't. - pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> Result> { - if let Some(cipher) = self.cipher.as_mut() { - Ok(cipher.decrypt_buf(buf)?.0) - } else { - Ok(buf.to_vec()) - } - } - - pub(crate) fn set_frame_type(&mut self, frame_type: FrameType) { - self.frame_type = frame_type; - } - - pub(crate) fn poll_reader( - &mut self, - cx: &mut Context<'_>, - mut reader: &mut R, - ) -> Poll> - where - R: AsyncRead + Unpin, - { - let mut incomplete = true; - loop { - if !incomplete { - if let Some(result) = self.process() { - return Poll::Ready(result); - } - } else { - incomplete = false; - } - let n = match Pin::new(&mut reader).poll_read(cx, &mut self.buf[self.end..]) { - Poll::Ready(Ok(n)) if n > 0 => n, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - // If the reader is pending, poll the timeout. - Poll::Pending | Poll::Ready(Ok(_)) => { - // Return Pending if the timeout is pending, or an error if the - // timeout expired (i.e. returned Poll::Ready). - return Pin::new(&mut self.timeout) - .poll(cx) - .map(|()| Err(Error::new(ErrorKind::TimedOut, "Remote timed out"))); - } - }; - - let end = self.end + n; - let (success, segments) = create_segments(&self.buf[self.start..end])?; - if success { - if let Some(ref mut cipher) = self.cipher { - let mut dec_end = self.start; - // What happens if decrypt fails here? - // next call to this func would have same start, corret? - // so it'd fail repeatedly? - // Why not just decrypt to the end? - for (index, header_len, body_len) in segments { - let de = cipher.decrypt( - &mut self.buf[self.start + index..end], - header_len, - body_len, - )?; - dec_end = self.start + index + de; - } - self.end = dec_end; - } else { - self.end = end; - } - } else { - // Could not segment due to buffer being full, need to cycle the buffer - // and possibly resize it too if the message is too big. - self.cycle_buf_and_resize_if_needed(segments[segments.len() - 1]); - - // Set incomplete flag to skip processing and instead poll more data - incomplete = true; - } - self.timeout.reset(TIMEOUT); - } - } - - /// Moves start of unprocessed data to the start of the buffer. And resize if necessary. - fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) { - let (last_index, last_header_len, last_body_len) = last_segment; - let total_incoming_length = last_index + last_header_len + last_body_len; - - if self.buf.len() < total_incoming_length { - // The incoming segments will not fit into the buffer, need to resize it - self.buf.resize(total_incoming_length, 0u8); - } - - // to-read length - let temp = self.buf[self.start..].to_vec(); - let len = temp.len(); - self.buf[..len].copy_from_slice(&temp[..]); - self.end = len; - self.start = 0; - } - - fn process(&mut self) -> Option> { - loop { - match self.step { - Step::Header => { - let stat = stat_uint24_le(&self.buf[self.start..self.end]); - if let Some((header_len, body_len)) = stat { - if body_len == 0 { - // This is a keepalive message, just remain in Step::Header - self.start += header_len; - return None; - } else if (self.start + header_len + body_len as usize) < self.end { - // There are more than one message here, create a batch from all of - // then - self.step = Step::Batch; - } else { - let body_len = body_len as usize; - if body_len > MAX_MESSAGE_SIZE as usize { - return Some(Err(Error::new( - ErrorKind::InvalidData, - "Message length above max allowed size", - ))); - } - self.step = Step::Body { - header_len, - body_len, - }; - } - } else { - return Some(Err(Error::new(ErrorKind::InvalidData, "Invalid header"))); - } - } - - // one message within an encrypted frame - // encrypted frame [ u24 header + encoded_frame [ ]] - Step::Body { - header_len, - body_len, - } => { - let message_len = header_len + body_len; - let range = self.start + header_len..self.start + message_len; - // this includes a a frame header - let frame = Frame::decode(&self.buf[range], &self.frame_type); - self.start += message_len; - self.step = Step::Header; - return Some(frame); - } - // multiple message within an encrypted frame - Step::Batch => { - let frame = - Frame::decode_multiple(&self.buf[self.start..self.end], &self.frame_type); - self.start = self.end; - self.step = Step::Header; - return Some(frame); - } - } - } - } -} - -#[allow(clippy::type_complexity)] -/// Given a buff get all the segments (starting_index_in_buffer, header_len, buffer_len) -/// returns returns `(true, segments)` if we read all segments, but (false, ..) if there -/// are remaining segments -fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> { - let mut index: usize = 0; - let len = buf.len(); - let mut segments: Vec<(usize, usize, usize)> = vec![]; - while index < len { - if let Some((header_len, body_len)) = stat_uint24_le(&buf[index..]) { - let body_len = body_len as usize; - segments.push((index, header_len, body_len)); - if len < index + header_len + body_len { - // The segments will not fit, return false to indicate that more needs to be read - return Ok((false, segments)); - } - index += header_len + body_len; - } else { - return Err(Error::new( - ErrorKind::InvalidData, - "Could not read header while decrypting", - )); - } - } - Ok((true, segments)) -} diff --git a/src/writer.rs b/src/writer.rs deleted file mode 100644 index 9a1465b..0000000 --- a/src/writer.rs +++ /dev/null @@ -1,198 +0,0 @@ -use crate::crypto::EncryptCipher; -use crate::message::{Encoder, Frame}; -use tracing::instrument; - -use futures_lite::{ready, AsyncWrite}; -use std::collections::VecDeque; -use std::fmt; -use std::io::Result; -use std::pin::Pin; -use std::task::{Context, Poll}; - -const BUF_SIZE: usize = 1024 * 64; - -#[derive(Debug)] -pub(crate) enum Step { - Flushing, - Writing, - Processing, -} - -pub(crate) struct WriteState { - queue: VecDeque, - current_frame: Option, - cipher: Option, - buf: Vec, - written_up_to_idx: usize, - should_write_up_to_idx: usize, - step: Step, -} - -impl fmt::Debug for WriteState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("WriteState") - .field("queue (len)", &self.queue.len()) - .field("current_frame", &self.current_frame) - .field("cipher", &self.cipher.is_some()) - .field("buf (len)", &self.buf.len()) - .field("start", &self.written_up_to_idx) - .field("end", &self.should_write_up_to_idx) - .field("step", &self.step) - .finish() - } -} - -impl WriteState { - pub(crate) fn new() -> Self { - Self { - queue: VecDeque::new(), - buf: vec![0u8; BUF_SIZE], - current_frame: None, - written_up_to_idx: 0, - should_write_up_to_idx: 0, - cipher: None, - step: Step::Processing, - } - } - - pub(crate) fn queue_frame(&mut self, frame: F) - where - F: Into, - { - self.queue.push_back(frame.into()) - } - - #[instrument(skip(self))] - pub(crate) fn try_encode_and_enqueue_frame_for_tx( - &mut self, - frame: &mut T, - ) -> Result { - let promised_len = frame.encoded_len()?; - let padded_promised_len = self.safe_encrypted_len(promised_len); - // this handles when a message would be longer than the entire buffer - if self.buf.len() < padded_promised_len { - self.buf.resize(padded_promised_len, 0u8); - } - - // check we have enough room - if padded_promised_len > self.remaining() { - return Ok(false); - } - - // write frame starting at end. fram is from end to end + actual_end - let actual_len = frame.encode(&mut self.buf[self.should_write_up_to_idx..])?; - if actual_len != promised_len { - panic!( - "encoded_len() did not return that right size, expected={promised_len}, actual={actual_len}" - ); - } - // Instead of the above, write the buffer to a new vec `foo` of length `promised_length` - // encode frame.to this buff - // slice `foo[(header_len /* 3*/)..actual_len]` this is the fram data - // encrypt this in place - // replace header at start of foo - // write its len to self.buf and then write it to self.buf - // slice from - - self.encrypt_frame_contents_onto_buf(padded_promised_len)?; - Ok(true) - } - - pub(crate) fn can_park_frame(&self) -> bool { - self.current_frame.is_none() - } - - pub(crate) fn park_frame(&mut self, frame: F) - where - F: Into, - { - if self.current_frame.is_none() { - self.current_frame = Some(frame.into()) - } - } - - /// The frame should be written to `self.buf` before calling this. And - /// `self.should_write_up_to_idx` should mark the start of the message. - /// `max_message_size` is the maximum size the message could be when it is encrypted - /// We encrypt the message in-place on `self.buf`. - fn encrypt_frame_contents_onto_buf(&mut self, max_message_size: usize) -> Result<()> { - let end_of_message_index = self.should_write_up_to_idx + max_message_size; - - let encrypted_end = if let Some(ref mut cipher) = self.cipher { - self.should_write_up_to_idx - + cipher - .encrypt(&mut self.buf[self.should_write_up_to_idx..end_of_message_index])? - } else { - end_of_message_index - }; - - self.should_write_up_to_idx = encrypted_end; - Ok(()) - } - - pub(crate) fn upgrade_with_encrypt_cipher(&mut self, encrypt_cipher: EncryptCipher) { - self.cipher = Some(encrypt_cipher); - } - - fn remaining(&self) -> usize { - self.buf.len() - self.should_write_up_to_idx - } - - fn pending(&self) -> usize { - self.should_write_up_to_idx - self.written_up_to_idx - } - - pub(crate) fn poll_send( - &mut self, - cx: &mut Context<'_>, - mut writer: &mut W, - ) -> Poll> - where - W: AsyncWrite + Unpin, - { - loop { - self.step = match self.step { - Step::Processing => { - if self.current_frame.is_none() && !self.queue.is_empty() { - self.current_frame = self.queue.pop_front(); - } - - if let Some(mut frame) = self.current_frame.take() { - if !self.try_encode_and_enqueue_frame_for_tx(&mut frame)? { - self.current_frame = Some(frame); - } - } - - if self.pending() == 0 { - return Poll::Ready(Ok(())); - } - Step::Writing - } - Step::Writing => { - let n = ready!(Pin::new(&mut writer).poll_write( - cx, - &self.buf[self.written_up_to_idx..self.should_write_up_to_idx] - ))?; - self.written_up_to_idx += n; - if self.written_up_to_idx == self.should_write_up_to_idx { - self.written_up_to_idx = 0; - self.should_write_up_to_idx = 0; - } - Step::Flushing - } - Step::Flushing => { - ready!(Pin::new(&mut writer).poll_flush(cx))?; - Step::Processing - } - } - } - } - - fn safe_encrypted_len(&self, encoded_len: usize) -> usize { - if let Some(cipher) = &self.cipher { - cipher.safe_encrypted_len(encoded_len) - } else { - encoded_len - } - } -} From 6f1995de65a560003d04dbcebf96d5f3ed51e660 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:36:26 -0400 Subject: [PATCH 084/206] cargo clippy --fix --- benches/pipe.rs | 29 +++++++++++------------ benches/throughput.rs | 53 ++++++++++++++++++------------------------- src/message/modern.rs | 2 +- src/test_utils.rs | 2 +- 4 files changed, 37 insertions(+), 49 deletions(-) diff --git a/benches/pipe.rs b/benches/pipe.rs index 630146c..b726545 100644 --- a/benches/pipe.rs +++ b/benches/pipe.rs @@ -18,7 +18,7 @@ fn bench_throughput(c: &mut Criterion) { env_logger::from_env(env_logger::Env::default().default_filter_or("error")).init(); let mut group = c.benchmark_group("pipe"); group.sample_size(10); - group.throughput(Throughput::Bytes(SIZE * COUNT * CONNS as u64)); + group.throughput(Throughput::Bytes(SIZE * COUNT * CONNS)); group.bench_function("pipe_echo", |b| { b.iter(|| { task::block_on(async move { @@ -72,7 +72,7 @@ where debug!("[{}] EVENT {:?}", is_initiator, event); match event { Event::Handshake(_) => { - protocol.open(key.clone()).await?; + protocol.open(key).await?; } Event::DiscoveryKey(_dkey) => {} Event::Channel(channel) => { @@ -92,7 +92,7 @@ where } Some(Err(err)) => { error!("ERROR {:?}", err); - return Err(err.into()); + return Err(err); } None => return Ok(0), } @@ -127,20 +127,17 @@ async fn on_channel_init(i: u64, mut channel: Channel) -> Result { let start = std::time::Instant::now(); while let Some(message) = channel.next().await { - match message { - Message::Data(mut data) => { - len += value_len(&data); - debug!("[a] recv {}", index(&data)); - if index(&data) >= COUNT { - debug!("close at {}", index(&data)); - channel.close().await?; - break; - } else { - increment_index(&mut data); - channel.send(Message::Data(data)).await?; - } + if let Message::Data(mut data) = message { + len += value_len(&data); + debug!("[a] recv {}", index(&data)); + if index(&data) >= COUNT { + debug!("close at {}", index(&data)); + channel.close().await?; + break; + } else { + increment_index(&mut data); + channel.send(Message::Data(data)).await?; } - _ => {} } } // let bytes = (COUNT * SIZE) as f64; diff --git a/benches/throughput.rs b/benches/throughput.rs index 7f9890d..6b9d6af 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -71,15 +71,12 @@ async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { // let kill_rx = &mut kill_rx; loop { match futures::future::select(incoming.next(), &mut kill_rx).await { - Either::Left((next, _)) => match next { - Some(Ok(stream)) => { - let peer_addr = stream.peer_addr().unwrap(); - debug!("new connection from {}", peer_addr); - task::spawn(async move { - onconnection(stream.clone(), stream, false).await; - }); - } - _ => {} + Either::Left((next, _)) => if let Some(Ok(stream)) = next { + let peer_addr = stream.peer_addr().unwrap(); + debug!("new connection from {}", peer_addr); + task::spawn(async move { + onconnection(stream.clone(), stream, false).await; + }); }, Either::Right((_, _)) => return, } @@ -101,7 +98,7 @@ where // eprintln!("RECV EVENT [{}] {:?}", protocol.is_initiator(), event); match event { Event::Handshake(_) => { - protocol.open(key.clone()).await.unwrap(); + protocol.open(key).await.unwrap(); } Event::DiscoveryKey(_) => {} Event::Channel(channel) => { @@ -126,10 +123,7 @@ async fn onchannel(mut channel: Channel, is_initiator: bool) { async fn channel_server(channel: &mut Channel) { while let Some(message) = channel.next().await { - match message { - Message::Data(_) => channel.send(message).await.unwrap(), - _ => {} - } + if let Message::Data(_) = message { channel.send(message).await.unwrap() } } } @@ -139,24 +133,21 @@ async fn channel_client(channel: &mut Channel) { let message = msg_data(0, data.clone()); channel.send(message).await.unwrap(); while let Some(message) = channel.next().await { - match message { - Message::Data(ref msg) => { - if index(msg) < COUNT { - let message = msg_data(index(msg) + 1, data.clone()); - channel.send(message).await.unwrap(); - } else { - let time = start.elapsed(); - let bytes = COUNT * SIZE; - trace!( - "client completed. {} blocks, {} bytes, {:?}", - index(msg), - bytes, - time - ); - break; - } + if let Message::Data(ref msg) = message { + if index(msg) < COUNT { + let message = msg_data(index(msg) + 1, data.clone()); + channel.send(message).await.unwrap(); + } else { + let time = start.elapsed(); + let bytes = COUNT * SIZE; + trace!( + "client completed. {} blocks, {} bytes, {:?}", + index(msg), + bytes, + time + ); + break; } - _ => {} } } } diff --git a/src/message/modern.rs b/src/message/modern.rs index 23524b2..1b68e24 100644 --- a/src/message/modern.rs +++ b/src/message/modern.rs @@ -565,7 +565,7 @@ impl VecEncodable for ChannelMessage { where Self: Sized, { - let body_len = prencode_channel_messages(&vec)?; + let body_len = prencode_channel_messages(vec)?; let mut buffer = checked_write_uint24_le(body_len, buffer)?; match vec { [] => Ok(buffer), diff --git a/src/test_utils.rs b/src/test_utils.rs index 3f687ea..5309529 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -185,7 +185,7 @@ impl Moo { fn result_channel() -> (Sender>, impl Stream>>) { let (tx, rx) = unbounded::>(); - (tx, rx.map(|x| Ok(x))) + (tx, rx.map(Ok)) } pub(crate) fn create_result_connected() -> ( From ca329b4a9c59a1bdd9469f9d47f5007c27d325c4 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:45:14 -0400 Subject: [PATCH 085/206] Remove protocol feature --- Cargo.toml | 3 +- src/constants.rs | 9 ---- src/crypto/cipher.rs | 97 ----------------------------------------- src/crypto/handshake.rs | 9 ---- src/crypto/mod.rs | 4 -- src/lib.rs | 6 --- src/message/mod.rs | 9 ---- src/protocol/mod.rs | 10 ----- 8 files changed, 1 insertion(+), 146 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4ceb1f3..0862678 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,10 +66,9 @@ tracing-tree = "0.4.0" tokio-util = { version = "0.7.14", features = ["compat"] } [features] -default = ["tokio", "sparse", "protocol"] +default = ["tokio", "sparse"] #default = ["tokio", "sparse"] uint24 = [] -protocol = [] wasm-bindgen = [ "futures-timer/wasm-bindgen" ] diff --git a/src/constants.rs b/src/constants.rs index 73d0748..1efbbed 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -6,12 +6,3 @@ pub(crate) const DEFAULT_KEEPALIVE: u32 = 10; /// v10: Protocol name pub(crate) const PROTOCOL_NAME: &str = "hypercore/alpha"; - -// 16,78MB is the max encrypted wire message size (will be much smaller usually). -// This limitation stems from the 24bit header. -#[cfg(not(feature = "protocol"))] -pub(crate) const MAX_MESSAGE_SIZE: u64 = 0xFFFFFF; - -/// Default timeout (in seconds) -#[cfg(not(feature = "protocol"))] -pub(crate) const DEFAULT_TIMEOUT: u32 = 20; diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 53c291f..aa096c3 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -72,103 +72,6 @@ impl DecryptCipher { } } -#[cfg(not(feature = "protocol"))] -mod encrypt_cipher { - use super::*; - use crate::util::{stat_uint24_le, write_uint24_le, UINT_24_LENGTH}; - const HEADER_MSG_LEN: usize = UINT_24_LENGTH + STREAM_ID_LENGTH + Header::BYTES; - - pub(crate) struct EncryptCipher { - push_stream: PushStream, - } - - impl std::fmt::Debug for EncryptCipher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "EncryptCipher(crypto_secretstream)") - } - } - - impl EncryptCipher { - pub(crate) fn from_handshake_tx( - handshake_result: &HandshakeResult, - ) -> std::io::Result<(Self, Vec)> { - let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] - .try_into() - .expect("split_tx with incorrect length"); - let key = Key::from(key); - - let mut header_message: [u8; HEADER_MSG_LEN] = [0; HEADER_MSG_LEN]; - write_uint24_le(STREAM_ID_LENGTH + Header::BYTES, &mut header_message); - write_stream_id( - &handshake_result.handshake_hash, - handshake_result.is_initiator, - &mut header_message[UINT_24_LENGTH..UINT_24_LENGTH + STREAM_ID_LENGTH], - ); - - let (header, push_stream) = PushStream::init(OsRng, &key); - let header = header.as_ref(); - header_message[UINT_24_LENGTH + STREAM_ID_LENGTH..].copy_from_slice(header); - let msg = header_message.to_vec(); - Ok((Self { push_stream }, msg)) - } - - /// Get the length needed for encryption, that includes padding. - pub(crate) fn safe_encrypted_len(&self, plaintext_len: usize) -> usize { - // ChaCha20-Poly1305 uses padding in two places, use two 15 bytes as a safe - // extra room. - // https://mailarchive.ietf.org/arch/msg/cfrg/u734TEOSDDWyQgE0pmhxjdncwvw/ - plaintext_len + 2 * 15 - } - - /// Encrypts message in the given buffer to the same buffer, returns number of bytes - /// of total message. - /// NB: we expect the first 3 bytes of the buffer to a size header. - /// The encrypted buffer will also be written prepended with a size header, with it's new size. - pub(crate) fn encrypt(&mut self, buf: &mut [u8]) -> io::Result { - let stat = stat_uint24_le(buf); - if let Some((header_len, body_len)) = stat { - let mut to_encrypt = buf[header_len..header_len + body_len as usize].to_vec(); - self.push_stream - .push(&mut to_encrypt, &[], Tag::Message) - .map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) - })?; - let encrypted_len = to_encrypt.len(); - write_uint24_le(encrypted_len, buf); - buf[header_len..header_len + encrypted_len].copy_from_slice(to_encrypt.as_slice()); - Ok(header_len + encrypted_len) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("Could not encrypt invalid data, len: {}", buf.len()), - )) - } - } - } - - impl DecryptCipher { - pub(crate) fn decrypt( - &mut self, - buf: &mut [u8], - header_len: usize, - body_len: usize, - ) -> io::Result { - let (to_decrypt, _tag) = self.decrypt_buf(&buf[header_len..header_len + body_len])?; - let decrypted_len = to_decrypt.len(); - write_uint24_le(decrypted_len, buf); - let decrypted_end = header_len + to_decrypt.len(); - buf[header_len..decrypted_end].copy_from_slice(to_decrypt.as_slice()); - // Set extra bytes in the buffer to 0 - // Why? - let encrypted_end = header_len + body_len; - buf[decrypted_end..encrypted_end].fill(0x00); - Ok(decrypted_end) - } - } -} -#[cfg(not(feature = "protocol"))] -pub(crate) use encrypt_cipher::*; - // NB: These values come from Javascript-side // // const [NS_INITIATOR, NS_RESPONDER] = crypto.namespace('hyperswarm/secret-stream', 2) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 72c9da3..492a9a4 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -112,10 +112,6 @@ impl Handshake { Ok(None) } } - #[cfg(not(feature = "protocol"))] - pub(crate) fn start(&mut self) -> Result>> { - Ok(self.start_raw()?.map(|x| crate::util::wrap_uint24_le(&x))) - } pub(crate) fn complete(&self) -> bool { self.complete @@ -178,11 +174,6 @@ impl Handshake { self.complete = true; Ok(tx_buf) } - // reads in `msg` without framing bytes, but emits msg WITH framing bytes - #[cfg(not(feature = "protocol"))] - pub(crate) fn read(&mut self, msg: &[u8]) -> Result>> { - Ok(self.read_raw(msg)?.map(|x| crate::util::wrap_uint24_le(&x))) - } pub(crate) fn into_result(&self) -> Result<&HandshakeResult> { if !self.complete() { diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 3de592a..9e49c0a 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,9 +1,5 @@ mod cipher; mod curve; mod handshake; -#[cfg(not(feature = "protocol"))] -pub(crate) use cipher::{DecryptCipher, EncryptCipher, RawEncryptCipher}; - -#[cfg(feature = "protocol")] pub(crate) use cipher::{DecryptCipher, RawEncryptCipher}; pub(crate) use handshake::{Handshake, HandshakeResult}; diff --git a/src/lib.rs b/src/lib.rs index c13ccae..3602517 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,17 +124,12 @@ mod crypto; mod duplex; mod framing; mod message; -#[cfg(feature = "protocol")] mod mqueue; mod noise; mod protocol; -#[cfg(not(feature = "protocol"))] -mod reader; #[cfg(test)] mod test_utils; mod util; -#[cfg(not(feature = "protocol"))] -mod writer; /// The wire messages used by the protocol. pub mod schema; @@ -144,7 +139,6 @@ pub use channels::Channel; pub use framing::Uint24LELengthPrefixedFraming; pub use noise::{encrypted_framed_message_channel, Encrypted, Event as NoiseEvent}; // Export the needed types for Channel::take_receiver, and Channel::local_sender() -#[cfg(feature = "protocol")] pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, }; diff --git a/src/message/mod.rs b/src/message/mod.rs index 1526f3a..dc42d7a 100644 --- a/src/message/mod.rs +++ b/src/message/mod.rs @@ -1,11 +1,2 @@ -#[cfg(feature = "protocol")] mod modern; - -#[cfg(feature = "protocol")] pub use modern::*; - -#[cfg(not(feature = "protocol"))] -mod old; - -#[cfg(not(feature = "protocol"))] -pub use old::*; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 7382df8..d24738c 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,13 +1,3 @@ -#[cfg(feature = "protocol")] mod modern; -#[cfg(feature = "protocol")] pub(crate) use modern::Options; -#[cfg(feature = "protocol")] pub use modern::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; - -#[cfg(not(feature = "protocol"))] -mod old; -#[cfg(not(feature = "protocol"))] -pub(crate) use old::Options; -#[cfg(not(feature = "protocol"))] -pub use old::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; From 3548f43871b418506bef6a576e57400e7fda0dac Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:45:36 -0400 Subject: [PATCH 086/206] cargo fmt --- benches/throughput.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/benches/throughput.rs b/benches/throughput.rs index 6b9d6af..cc2c278 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -71,13 +71,15 @@ async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { // let kill_rx = &mut kill_rx; loop { match futures::future::select(incoming.next(), &mut kill_rx).await { - Either::Left((next, _)) => if let Some(Ok(stream)) = next { - let peer_addr = stream.peer_addr().unwrap(); - debug!("new connection from {}", peer_addr); - task::spawn(async move { - onconnection(stream.clone(), stream, false).await; - }); - }, + Either::Left((next, _)) => { + if let Some(Ok(stream)) = next { + let peer_addr = stream.peer_addr().unwrap(); + debug!("new connection from {}", peer_addr); + task::spawn(async move { + onconnection(stream.clone(), stream, false).await; + }); + } + } Either::Right((_, _)) => return, } } @@ -123,7 +125,9 @@ async fn onchannel(mut channel: Channel, is_initiator: bool) { async fn channel_server(channel: &mut Channel) { while let Some(message) = channel.next().await { - if let Message::Data(_) = message { channel.send(message).await.unwrap() } + if let Message::Data(_) = message { + channel.send(message).await.unwrap() + } } } From 72d1d9e0f3e4deff1bb88cf629d86005b9bd0c56 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:50:03 -0400 Subject: [PATCH 087/206] clippy fixes --- src/crypto/handshake.rs | 2 +- src/noise.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 492a9a4..10c111f 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -175,7 +175,7 @@ impl Handshake { Ok(tx_buf) } - pub(crate) fn into_result(&self) -> Result<&HandshakeResult> { + pub(crate) fn get_result(&self) -> Result<&HandshakeResult> { if !self.complete() { Err(Error::new(ErrorKind::Other, "Handshake is not complete")) } else { diff --git a/src/noise.rs b/src/noise.rs index 40ce6ac..5b5b867 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -267,6 +267,7 @@ impl>> + Sink> + Send + Unpin + 'static /// Handle all message throughput. Sends, encrypts and decrypts messages /// Returns `true` `step` is already [`Step::Established`]. +#[allow(clippy::too_many_arguments)] fn poll_message_throughput< IO: Stream>> + Sink> + Send + Unpin + 'static, >( @@ -554,7 +555,7 @@ fn handle_setup_message( if handshake.complete() { debug!(initiator = %is_initiator, "Handshake completed"); - let handshake_result = match handshake.into_result() { + let handshake_result = match handshake.get_result() { Ok(x) => x, Err(e) => { error!("into-result error {e:?}"); From 1e3edd127c8d103d56b1a815856b5ee1f99ec0a0 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 11:53:02 -0400 Subject: [PATCH 088/206] rm nested messaged & protocol modules --- src/{message/modern.rs => message.rs} | 0 src/message/mod.rs | 2 -- src/{protocol/modern.rs => protocol.rs} | 0 src/protocol/mod.rs | 3 --- 4 files changed, 5 deletions(-) rename src/{message/modern.rs => message.rs} (100%) delete mode 100644 src/message/mod.rs rename src/{protocol/modern.rs => protocol.rs} (100%) delete mode 100644 src/protocol/mod.rs diff --git a/src/message/modern.rs b/src/message.rs similarity index 100% rename from src/message/modern.rs rename to src/message.rs diff --git a/src/message/mod.rs b/src/message/mod.rs deleted file mode 100644 index dc42d7a..0000000 --- a/src/message/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod modern; -pub use modern::*; diff --git a/src/protocol/modern.rs b/src/protocol.rs similarity index 100% rename from src/protocol/modern.rs rename to src/protocol.rs diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs deleted file mode 100644 index d24738c..0000000 --- a/src/protocol/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod modern; -pub(crate) use modern::Options; -pub use modern::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; From 323503633604f6d03a7229808b321fcdb840c910 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 12:08:45 -0400 Subject: [PATCH 089/206] remove encoder trait --- src/message.rs | 63 -------------------------------------------------- src/mqueue.rs | 11 +++++---- 2 files changed, 6 insertions(+), 68 deletions(-) diff --git a/src/message.rs b/src/message.rs index 1b68e24..f222a1e 100644 --- a/src/message.rs +++ b/src/message.rs @@ -15,20 +15,6 @@ const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0, 3]; const MULTI_MESSAGE_PREFIX: [u8; 2] = [0, 0]; const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0]; -/// Encode data into a buffer. -/// -/// This trait is implemented on data frames and their components -/// (channel messages, messages, and individual message types through prost). -pub(crate) trait Encoder: Sized + fmt::Debug { - /// Calculates the length that the encoded message needs. - fn encoded_len(&self) -> Result; - - /// Encodes the message to a buffer. - /// - /// An error will be returned if the buffer does not have sufficient capacity. - fn encoder_encode(&self, buf: &mut [u8]) -> Result; -} - pub(crate) fn decode_framed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { @@ -174,55 +160,6 @@ fn decode_u24(buffer: &[u8]) -> Result<(usize, &[u8]), EncodingError> { Ok((out as usize, rest)) } -impl Encoder for Vec { - fn encoded_len(&self) -> Result { - Ok(prencode_channel_messages(self)? + UINT24_HEADER_LEN) - } - - #[instrument(skip_all)] - fn encoder_encode(&self, buf: &mut [u8]) -> Result { - let body_len = prencode_channel_messages(self)?; - let mut buf = checked_write_uint24_le(body_len, buf)?; - // skip the u24 we just wrote - match self.len().cmp(&1) { - std::cmp::Ordering::Less => {} - std::cmp::Ordering::Equal => { - trace!("Encoding single ChannelMessage {}", self[0]); - if let Message::Open(_) = &self[0].message { - // This is a special case with 0x00, 0x01 intro bytes - buf = write_array(&[0, 1], buf)?; - self[0].encode(buf)?; - } else if let Message::Close(_) = &self[0].message { - // This is a special case with 0x00, 0x03 intro bytes - buf = write_array(&[0, 3], buf)?; - self[0].encode(buf)?; - } else { - self[0].encode(buf)?; - } - } - std::cmp::Ordering::Greater => { - // Two intro bytes 0x00 0x00, then channel id, then lengths - buf = write_array(&[0, 0], buf)?; - let mut current_channel: u64 = self[0].channel; - buf = current_channel.encode(buf)?; - for message in self.iter() { - if message.channel != current_channel { - // Channel changed, need to add a 0x00 in between and then the new - // channel - buf = write_array(&[0], buf)?; - buf = message.channel.encode(buf)?; - current_channel = message.channel; - } - let message_length = message.message.encoded_size()?; - buf = (message_length as u32).encode(buf)?; - buf = message.encode(buf)?; - } - } - } - Ok(UINT24_HEADER_LEN + body_len) - } -} - /// A protocol message. #[derive(Debug, Clone, PartialEq)] #[allow(missing_docs)] diff --git a/src/mqueue.rs b/src/mqueue.rs index 9be4ab7..cd2237b 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -9,10 +9,11 @@ use std::{ }; use futures::{Sink, Stream}; +use hypercore::encoding::CompactEncoding as _; use tracing::{error, instrument}; use crate::{ - message::{decode_framed_channel_messages, ChannelMessage, Encoder as _}, + message::{decode_framed_channel_messages, ChannelMessage}, noise::EncryptionInfo, NoiseEvent, }; @@ -81,16 +82,16 @@ impl + Sink> + Send + Unpin + 'static> Mes messages.push(msg); } - let mut buf = vec![0; messages.encoded_len()?]; - match messages.encoder_encode(&mut buf) { - Ok(_) => {} + let buf = match messages.to_encoded_bytes() { + Ok(x) => x, Err(e) => { error!(error = ?e, "error encoding messages"); // TODO this would probably be a programming error. // if so, this sholud just be an unwrap/expect return Poll::Ready(Err(e.into())); } - } + }; + if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf) { todo!() } From b9e5b490ebece2165c3a28266d1f92899aa3da5c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 30 Apr 2025 12:34:11 -0400 Subject: [PATCH 090/206] clippy fixes --- src/framing.rs | 2 +- src/test_utils.rs | 1 + tests/js_interop.rs | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 7daef38..12d3c41 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -16,7 +16,7 @@ use crate::util::{stat_uint24_le, wrap_uint24_le}; const BUF_SIZE: usize = 1024 * 64; const _HEADER_LEN: usize = 3; -/// Turn a `AsyncWrite` of length prefixed messages and emit the messages with a Stream +/// take a `AsyncWrite` of length prefixed messages and emit them as a Stream pub struct Uint24LELengthPrefixedFraming { io: IO, /// Data from [`Self::io`]'s [`AsyncRead`] interface to be sent out via the [`Stream`] interface. diff --git a/src/test_utils.rs b/src/test_utils.rs index 5309529..24256cb 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -188,6 +188,7 @@ fn result_channel() -> (Sender>, impl Stream>> (tx, rx.map(Ok)) } +#[allow(clippy::type_complexity)] pub(crate) fn create_result_connected() -> ( Moo>>, impl Sink>>, Moo>>, impl Sink>>, diff --git a/tests/js_interop.rs b/tests/js_interop.rs index d703734..41bad94 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -767,7 +767,10 @@ async fn on_replication_message( for i in 0..new_info.contiguous_length { let value = String::from_utf8(hypercore.get(i).await?.unwrap()).unwrap(); let line = format!("{} {}\n", i, value); - writer.write(line.as_bytes()).await?; + let n_written = writer.write(line.as_bytes()).await?; + if line.len() != n_written { + panic!("Couldn't write all write all bytse"); + } } writer.flush().await?; true From 2f6f694cc029c10dcbab1a621bbd8fa28d5cb384 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 2 May 2025 18:41:34 -0400 Subject: [PATCH 091/206] add compact_encoding dependency --- Cargo.toml | 3 +++ src/message.rs | 2 +- src/mqueue.rs | 4 ++-- src/schema.rs | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0862678..fe7ccda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,9 @@ path = "../core" #version = "0.14.0" #default-features = false +[dependencies.compact-encoding] +path = "../compact-encoding" + [dev-dependencies] async-std = { version = "1.12.0", features = ["attributes", "unstable"] } diff --git a/src/message.rs b/src/message.rs index f222a1e..aa6ca24 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,6 +1,6 @@ use crate::schema::*; use crate::util::{stat_uint24_le, write_uint24_le}; -use hypercore::encoding::{ +use compact_encoding::{ decode_usize, take_array, take_array_mut, write_array, CompactEncoding, EncodingError, EncodingErrorKind, VecEncodable, }; diff --git a/src/mqueue.rs b/src/mqueue.rs index cd2237b..e5df5b8 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -8,8 +8,8 @@ use std::{ task::{Context, Poll}, }; +use compact_encoding::CompactEncoding as _; use futures::{Sink, Stream}; -use hypercore::encoding::CompactEncoding as _; use tracing::{error, instrument}; use crate::{ @@ -92,7 +92,7 @@ impl + Sink> + Send + Unpin + 'static> Mes } }; - if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf) { + if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), buf.to_vec()) { todo!() } diff --git a/src/schema.rs b/src/schema.rs index c58a40b..7d6fb58 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,4 +1,4 @@ -use hypercore::encoding::{ +use compact_encoding::{ map_encode, sum_encoded_size, take_array, take_array_mut, write_array, write_slice, CompactEncoding, EncodingError, }; From f1188c787cae57f6f879f28068618606180afedd Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 5 May 2025 13:15:16 -0400 Subject: [PATCH 092/206] remove decode macro this came from hypercore but has been removed --- src/schema.rs | 41 ++++++++++++++++++++--------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/src/schema.rs b/src/schema.rs index 7d6fb58..049a590 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,10 +1,9 @@ use compact_encoding::{ - map_encode, sum_encoded_size, take_array, take_array_mut, write_array, write_slice, + map_decode, map_encode, sum_encoded_size, take_array, take_array_mut, write_array, write_slice, CompactEncoding, EncodingError, }; use hypercore::{ - decode, DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, - RequestUpgrade, + DataBlock, DataHash, DataSeek, DataUpgrade, Proof, RequestBlock, RequestSeek, RequestUpgrade, }; use tracing::instrument; @@ -50,9 +49,8 @@ impl CompactEncoding for Open { where Self: Sized, { - let (channel, rest) = u64::decode(buffer)?; - let (protocol, rest) = String::decode(rest)?; - let (discovery_key, rest) = >::decode(rest)?; + let ((channel, protocol, discovery_key), rest) = + map_decode!(buffer, [u64, String, Vec]); // TODO this is a CLEAR bug it assumes nothing is encoded after this message let (capability, rest) = if !rest.is_empty() { let (_, rest) = take_array::<1>(rest)?; @@ -62,7 +60,7 @@ impl CompactEncoding for Open { (None, rest) }; Ok(( - Open { + Self { channel, protocol, discovery_key, @@ -93,7 +91,8 @@ impl CompactEncoding for Close { where Self: Sized, { - decode!(Close, buffer, {channel: u64}) + let (channel, rest) = u64::decode(buffer)?; + Ok((Self { channel }, rest)) } } @@ -138,10 +137,7 @@ impl CompactEncoding for Synchronize { Self: Sized, { let ([flags], rest) = take_array::<1>(buffer)?; - dbg!(flags); - let (fork, rest) = u64::decode(rest)?; - let (length, rest) = u64::decode(rest)?; - let (remote_length, rest) = u64::decode(rest)?; + let ((fork, length, remote_length), rest) = map_decode!(rest, [u64, u64, u64]); let can_upgrade = flags & 1 != 0; let uploading = flags & 2 != 0; let downloading = flags & 4 != 0; @@ -234,8 +230,7 @@ impl CompactEncoding for Request { Self: Sized, { let ([flags], rest) = take_array::<1>(buffer)?; - let (id, rest) = u64::decode(rest)?; - let (fork, rest) = u64::decode(rest)?; + let ((id, fork), rest) = map_decode!(rest, [u64, u64]); let (block, rest) = maybe_decode!(flags & 1 != 0, RequestBlock, rest); let (hash, rest) = maybe_decode!(flags & 2 != 0, RequestBlock, rest); @@ -345,8 +340,7 @@ impl CompactEncoding for Data { Self: Sized, { let ([flags], rest) = take_array::<1>(buffer)?; - let (request, rest) = u64::decode(rest)?; - let (fork, rest) = u64::decode(rest)?; + let ((request, fork), rest) = map_decode!(rest, [u64, u64]); let (block, rest) = maybe_decode!(flags & 1 != 0, DataBlock, rest); let (hash, rest) = maybe_decode!(flags & 2 != 0, DataHash, rest); let (seek, rest) = maybe_decode!(flags & 4 != 0, DataSeek, rest); @@ -398,7 +392,8 @@ impl CompactEncoding for NoData { where Self: Sized, { - decode!(NoData, buffer, { request: u64 }) + let (request, rest) = u64::decode(buffer)?; + Ok((Self { request }, rest)) } } @@ -424,7 +419,8 @@ impl CompactEncoding for Want { where Self: Sized, { - decode!(Self, buffer, { start: u64, length: u64 }) + let ((start, length), rest) = map_decode!(buffer, [u64, u64]); + Ok((Self { start, length }, rest)) } } @@ -450,7 +446,8 @@ impl CompactEncoding for Unwant { where Self: Sized, { - decode!(Self, buffer, { start: u64, length: u64 }) + let ((start, length), rest) = map_decode!(buffer, [u64, u64]); + Ok((Self { start, length }, rest)) } } @@ -475,7 +472,8 @@ impl CompactEncoding for Bitfield { where Self: Sized, { - decode!(Self, buffer, { start: u64, bitfield: Vec }) + let ((start, bitfield), rest) = map_decode!(buffer, [u64, Vec]); + Ok((Self { start, bitfield }, rest)) } } @@ -556,6 +554,7 @@ impl CompactEncoding for Extension { where Self: Sized, { - decode!(Self, buffer, { name: String, message: Vec }) + let ((name, message), rest) = map_decode!(buffer, [String, Vec]); + Ok((Self { name, message }, rest)) } } From c15aff9595fcf6a5f37367f37905fa7d0e93dce0 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 01:34:53 -0400 Subject: [PATCH 093/206] update compact-encoding version --- Cargo.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fe7ccda..3c31ea0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,16 +40,13 @@ sha2 = "0.10" curve25519-dalek = "4" crypto_secretstream = "0.2" futures = "0.3.31" +compact-encoding = "2" [dependencies.hypercore] path = "../core" #version = "0.14.0" #default-features = false -[dependencies.compact-encoding] -path = "../compact-encoding" - - [dev-dependencies] async-std = { version = "1.12.0", features = ["attributes", "unstable"] } async-compat = "0.2.1" From 6a44f819b8ff64503e348db8f716226cc84975bd Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 01:35:03 -0400 Subject: [PATCH 094/206] Remove unused features --- Cargo.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3c31ea0..d53f5f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,8 +67,6 @@ tokio-util = { version = "0.7.14", features = ["compat"] } [features] default = ["tokio", "sparse"] -#default = ["tokio", "sparse"] -uint24 = [] wasm-bindgen = [ "futures-timer/wasm-bindgen" ] From e7188075128efd2aae3d24f68a0c102a33b2208d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 01:35:55 -0400 Subject: [PATCH 095/206] remove unused uint24 feature --- src/util.rs | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/src/util.rs b/src/util.rs index 579a0fd..7e70336 100644 --- a/src/util.rs +++ b/src/util.rs @@ -29,33 +29,6 @@ pub(crate) fn map_channel_err(err: async_channel::SendError) -> Error { } pub(crate) const UINT_24_LENGTH: usize = 3; -#[cfg(feature = "uint24")] -mod uint24 { - use super::UINT_24_LENGTH; - pub struct Uint24LE([u8; UINT_24_LENGTH]); - impl Uint24LE { - pub const MAX_USIZE: usize = 16777215; - pub const SIZE: usize = UINT_24_LENGTH; - } - - impl AsRef<[u8; 3]> for Uint24LE { - fn as_ref(&self) -> &[u8; 3] { - &self.0 - } - } - - // TODO we are using std::io::Error everywhere so I won't add a new one but this isn't ideal - impl TryFrom for Uint24LE { - type Error = Error; - - fn try_from(n: usize) -> Result { - if n > Self::MAX_USIZE { - todo!() - } - Ok(Self([(n & 255) as u8, (n >> 8) as u8, (n >> 16) as u8])) - } - } -} #[inline] pub(crate) fn wrap_uint24_le(data: &[u8]) -> Vec { let mut buf: Vec = vec![0; 3]; From ec47b843ac69fb60e0eba54d21d9f4bf98ce9e28 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 01:49:34 -0400 Subject: [PATCH 096/206] remove use of test_log just use our own logger as needed --- tests/js_interop.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 41bad94..28f40ba 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -24,14 +24,13 @@ use async_std::{ task::{self, sleep}, test as async_test, }; -use test_log::test; #[cfg(feature = "tokio")] use tokio::{ fs::{metadata, File}, io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, net::{TcpListener, TcpStream}, sync::Mutex, - task, test as async_test, + task, time::sleep, }; @@ -59,28 +58,28 @@ const TEST_SET_SERVER_WRITER: &str = "sw"; const TEST_SET_CLIENT_WRITER: &str = "cw"; const TEST_SET_SIMPLE: &str = "simple"; -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_ncns_simple_server_writer() -> Result<()> { js_interop_ncns_simple(true, 8101).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_ncns_simple_client_writer() -> Result<()> { js_interop_ncns_simple(false, 8102).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_rcns_simple_server_writer() -> Result<()> { js_interop_rcns_simple(true, 8103).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] #[ignore] // FIXME this tests hangs sporadically async fn js_interop_rcns_simple_client_writer() -> Result<()> { @@ -88,28 +87,29 @@ async fn js_interop_rcns_simple_client_writer() -> Result<()> { Ok(()) } -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_ncrs_simple_server_writer() -> Result<()> { js_interop_ncrs_simple(true, 8105).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_ncrs_simple_client_writer() -> Result<()> { js_interop_ncrs_simple(false, 8106).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn js_interop_rcrs_simple_server_writer() -> Result<()> { + _util::log(); js_interop_rcrs_simple(true, 8107).await?; Ok(()) } -#[test(async_test)] +#[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] #[ignore] // FIXME this tests hangs sporadically async fn js_interop_rcrs_simple_client_writer() -> Result<()> { From 3357e6a9175a6ae3192c283e3e5b82418c74b8f0 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 01:50:38 -0400 Subject: [PATCH 097/206] remove test-log dep --- Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index d53f5f6..a651734 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,6 @@ duplexify = "1.1.0" sluice = "0.5.4" futures = "0.3.13" log = "0.4" -test-log = { version = "0.2.11", default-features = false, features = ["trace"] } tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } tracing-tree = "0.4.0" tokio-util = { version = "0.7.14", features = ["compat"] } From 778a192066c1ed272497db35e825711f0a24e815 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 02:07:38 -0400 Subject: [PATCH 098/206] Add instrument to some funcs --- src/message.rs | 2 ++ src/noise.rs | 4 ++-- src/protocol.rs | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/message.rs b/src/message.rs index aa6ca24..80208ff 100644 --- a/src/message.rs +++ b/src/message.rs @@ -54,6 +54,7 @@ pub(crate) fn decode_framed_channel_messages( } Ok((combined_messages, index)) } +#[instrument(skip_all err)] pub(crate) fn decode_unframed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { @@ -448,6 +449,7 @@ impl ChannelMessage { /// /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it + #[instrument(err, skip(buf))] pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> { if buf.len() <= 1 { return Err(io::Error::new( diff --git a/src/noise.rs b/src/noise.rs index 5b5b867..9f35393 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -228,7 +228,7 @@ impl>> + Sink> + Send + Unpin + 'static { type Item = Event; - #[instrument(skip(cx), fields(initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -502,7 +502,7 @@ fn reset_encrypted( } /// handle setup messages: if any are incorrect (cause an error) the state is reset -#[instrument(skip_all, fields(initiator = %is_initiator))] +#[instrument(err, skip_all, fields(initiator = %is_initiator))] fn handle_setup_message( step: &mut Step, msg: &[u8], diff --git a/src/protocol.rs b/src/protocol.rs index 42bd5d6..955ef7c 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -298,7 +298,7 @@ where } /// Poll for inbound messages and processs them. - #[instrument(skip_all)] + #[instrument(skip_all, err)] fn poll_inbound_read(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { match self.io.poll_inbound(cx) { From b5ed63ef9d917aad24f3476a4db52bb36ced77f8 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 6 May 2025 15:38:52 -0400 Subject: [PATCH 099/206] remove redundant 'simple' from every func name --- tests/js_interop.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 28f40ba..ebd03ac 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -60,50 +60,50 @@ const TEST_SET_SIMPLE: &str = "simple"; #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_simple_server_writer() -> Result<()> { +async fn js_interop_ncns_server_writer() -> Result<()> { js_interop_ncns_simple(true, 8101).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_simple_client_writer() -> Result<()> { +async fn js_interop_ncns_client_writer() -> Result<()> { js_interop_ncns_simple(false, 8102).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcns_simple_server_writer() -> Result<()> { - js_interop_rcns_simple(true, 8103).await?; +async fn js_interop_rcns_server_writer() -> Result<()> { + js_interop_rcns(true, 8103).await?; Ok(()) } #[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] #[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcns_simple_client_writer() -> Result<()> { - js_interop_rcns_simple(false, 8104).await?; +async fn js_interop_rcns_client_writer() -> Result<()> { + js_interop_rcns(false, 8104).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_simple_server_writer() -> Result<()> { +async fn js_interop_ncrs_server_writer() -> Result<()> { js_interop_ncrs_simple(true, 8105).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_simple_client_writer() -> Result<()> { +async fn js_interop_ncrs_client_writer() -> Result<()> { js_interop_ncrs_simple(false, 8106).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcrs_simple_server_writer() -> Result<()> { +async fn js_interop_rcrs_server_writer() -> Result<()> { _util::log(); js_interop_rcrs_simple(true, 8107).await?; Ok(()) @@ -112,7 +112,7 @@ async fn js_interop_rcrs_simple_server_writer() -> Result<()> { #[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] #[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcrs_simple_client_writer() -> Result<()> { +async fn js_interop_rcrs_client_writer() -> Result<()> { js_interop_rcrs_simple(false, 8108).await?; Ok(()) } @@ -156,7 +156,7 @@ async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_rcns_simple(server_writer: bool, port: u32) -> Result<()> { +async fn js_interop_rcns(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", From 4c0402fe7fe78433eeec80e3fcb5c76cd0fa3b39 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 7 May 2025 11:21:26 -0400 Subject: [PATCH 100/206] rm async_std test wrappers --- tests/js_interop.rs | 76 --------------------------------------------- 1 file changed, 76 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index ebd03ac..41d8160 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -15,15 +15,6 @@ use std::sync::Once; #[cfg(feature = "tokio")] use async_compat::CompatExt; -#[cfg(feature = "async-std")] -use async_std::{ - fs::{metadata, File}, - io::{prelude::BufReadExt, BufReader, BufWriter, WriteExt}, - net::{TcpListener, TcpStream}, - sync::Mutex, - task::{self, sleep}, - test as async_test, -}; #[cfg(feature = "tokio")] use tokio::{ fs::{metadata, File}, @@ -441,40 +432,6 @@ pub fn get_test_key_pair(include_secret: bool) -> PartialKeypair { PartialKeypair { public, secret } } -#[cfg(feature = "async-std")] -async fn on_replication_connection( - stream: TcpStream, - is_initiator: bool, - hypercore: Arc, -) -> Result<()> { - let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream); - while let Some(event) = protocol.next().await { - let event = event?; - match event { - Event::Handshake(_) => { - if is_initiator { - protocol.open(*hypercore.key()).await?; - } - } - Event::DiscoveryKey(dkey) => { - if hypercore.discovery_key == dkey { - protocol.open(*hypercore.key()).await?; - } else { - panic!("Invalid discovery key"); - } - } - Event::Channel(channel) => { - hypercore.on_replication_peer(channel); - } - Event::Close(_dkey) => { - break; - } - _ => {} - } - } - Ok(()) -} - #[cfg(feature = "tokio")] async fn on_replication_connection( stream: TcpStream, @@ -850,39 +807,6 @@ impl RustServer { } } -impl Drop for RustServer { - fn drop(&mut self) { - #[cfg(feature = "async-std")] - if let Some(handle) = self.handle.take() { - task::block_on(handle.cancel()); - } - } -} - -#[cfg(feature = "async-std")] -pub async fn tcp_server( - port: u32, - onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, - context: C, -) -> Result<()> -where - F: Future> + Send, - C: Clone + Send + 'static, -{ - let listener = TcpListener::bind(&format!("localhost:{}", port)).await?; - let mut incoming = listener.incoming(); - while let Some(Ok(stream)) = incoming.next().await { - let context = context.clone(); - let _peer_addr = stream.peer_addr().unwrap(); - task::spawn(async move { - onconnection(stream, false, context) - .await - .expect("Should return ok"); - }); - } - Ok(()) -} - #[cfg(feature = "tokio")] pub async fn tcp_server( port: u32, From 917b8374d9612fbcafedef73c284812ad0a00d44 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 8 May 2025 13:47:46 -0400 Subject: [PATCH 101/206] instrument and rename vec_encoded_size for cm --- src/message.rs | 51 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/src/message.rs b/src/message.rs index 80208ff..ca09465 100644 --- a/src/message.rs +++ b/src/message.rs @@ -7,7 +7,7 @@ use compact_encoding::{ use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; -use tracing::{instrument, trace, warn}; +use tracing::{debug, instrument, trace, warn}; const UINT24_HEADER_LEN: usize = 3; const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; @@ -41,9 +41,7 @@ pub(crate) fn decode_framed_channel_messages( body_len, length ); } - for message in msgs { - combined_messages.push(message); - } + combined_messages.extend(msgs); index += header_len + body_len as usize; } else { return Err(io::Error::new( @@ -122,7 +120,7 @@ pub(crate) fn decode_unframed_channel_messages( } } -fn prencode_channel_messages(messages: &[ChannelMessage]) -> Result { +fn vec_channel_messages_encoded_size(messages: &[ChannelMessage]) -> Result { Ok(match messages { [] => 0, [msg] => match msg.message { @@ -454,7 +452,7 @@ impl ChannelMessage { if buf.len() <= 1 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, - "received empty message", + format!("received empty message [{buf:?}]"), )); } let (message, buf) = ::decode(buf)?; @@ -493,13 +491,15 @@ impl CompactEncoding for ChannelMessage { } impl VecEncodable for ChannelMessage { + #[instrument(skip_all, ret)] fn vec_encoded_size(vec: &[Self]) -> Result where Self: Sized, { - Ok(prencode_channel_messages(vec)? + UINT24_HEADER_LEN) + Ok(vec_channel_messages_encoded_size(vec)? + UINT24_HEADER_LEN) } + #[instrument(skip_all)] fn vec_encode<'a>(vec: &[Self], buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> where Self: Sized, @@ -568,6 +568,8 @@ impl VecEncodable for ChannelMessage { #[cfg(test)] mod tests { + use crate::test_utils::log; + use super::*; use hypercore::{ DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, @@ -673,4 +675,39 @@ mod tests { }; Ok(()) } + + #[test] + fn extras() -> Result<(), EncodingError> { + let one = Message::Synchronize(Synchronize { + fork: 0, + length: 4, + remote_length: 0, + downloading: true, + uploading: true, + can_upgrade: true, + }); + let two = Message::Range(Range { + drop: false, + start: 0, + length: 4, + }); + let msgs = vec![ChannelMessage::new(1, one), ChannelMessage::new(1, two)]; + let buff = msgs.to_encoded_bytes()?; + + let res = as CompactEncoding>::decode(&buff); + assert!(res.is_err()); + log(); + + let buff = msgs.to_encoded_bytes()?; + let (res2, _size) = decode_framed_channel_messages(&buff).unwrap(); + assert_eq!(res2, msgs); + + // from js interop tests + // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] + // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] + // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] + // [23, 0, 0, 0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] + //assert!(res2.is_ok()); + Ok(()) + } } From d7bd06d29bad8aa396addbb4aa3be9419be96017 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 8 May 2025 16:02:33 -0400 Subject: [PATCH 102/206] fix Vec encoding --- src/message.rs | 62 ++++++++++++++++++++++++++++++++++---------------- src/schema.rs | 1 - 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/message.rs b/src/message.rs index ca09465..ff2ea2b 100644 --- a/src/message.rs +++ b/src/message.rs @@ -15,6 +15,7 @@ const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0, 3]; const MULTI_MESSAGE_PREFIX: [u8; 2] = [0, 0]; const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0]; +#[instrument(skip_all)] pub(crate) fn decode_framed_channel_messages( buf: &[u8], ) -> Result<(Vec, usize), io::Error> { @@ -75,14 +76,22 @@ pub(crate) fn decode_unframed_channel_messages( return Err(io::Error::new( io::ErrorKind::InvalidData, format!( - "received invalid message length: [{channel_message_length}] but we have [{}] remaining bytes. Initial buffer size [{og_len}]", + "received invalid message length: [{channel_message_length}] +\tbut we have [{}] remaining bytes. +\tInitial buffer size [{og_len}]", buf.len() ), )); } // Then the actual message let channel_message; - (channel_message, buf) = ChannelMessage::decode(buf, current_channel)?; + let bl = buf.len(); + (channel_message, buf) = ChannelMessage::decode_with_channel(buf, current_channel)?; + trace!( + "Decoded ChannelMessage::{:?} using [{} bytes]", + channel_message.message, + bl - buf.len() + ); messages.push(channel_message); // After that, if there is an extra 0x00, that means the channel // changed. This works because of LE encoding, and channels starting @@ -128,25 +137,25 @@ fn vec_channel_messages_encoded_size(messages: &[ChannelMessage]) -> Result msg.encoded_size()?, }, msgs => { - let mut out = 2; + let mut out = MULTI_MESSAGE_PREFIX.len(); let mut current_channel: u64 = messages[0].channel; out += current_channel.encoded_size()?; for message in msgs.iter() { if message.channel != current_channel { // Channel changed, need to add a 0x00 in between and then the new // channel - out += 1 + message.channel.encoded_size()?; + out += CHANNEL_CHANGE_SEPERATOR.len() + message.channel.encoded_size()?; current_channel = message.channel; } let message_length = message.message.encoded_size()?; - out += message.encoded_size()? + message_length; + out += message_length + (message_length as u64).encoded_size()?; } out } }) } -fn checked_write_uint24_le(n: usize, buf: &mut [u8]) -> Result<&mut [u8], EncodingError> { +fn encode_usize_as_u24(n: usize, buf: &mut [u8]) -> Result<&mut [u8], EncodingError> { let (header, rest) = take_array_mut::(buf)?; write_uint24_le(n, header); Ok(rest) @@ -239,6 +248,7 @@ impl CompactEncoding for Message { #[instrument(skip_all, fields(name = self.name()))] fn encode<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8], EncodingError> { + debug!("Encoding {self:?}"); let rest = if let Self::Open(_) | Self::Close(_) = &self { buffer } else { @@ -394,6 +404,7 @@ impl ChannelMessage { /// bytes in it #[instrument(skip_all, err)] pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> { + debug!("Decode ChannelMessage::Open"); let og_len = buf.len(); if og_len <= 5 { return Err(io::Error::new( @@ -417,6 +428,7 @@ impl ChannelMessage { /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> { + debug!("Decode ChannelMessage::Close"); let og_len = buf.len(); if buf.is_empty() { return Err(io::Error::new( @@ -441,6 +453,10 @@ impl ChannelMessage { //::decode(buf) let (channel, buf) = u64::decode(buf)?; let (message, buf) = ::decode(buf)?; + debug!( + "Decode ChannelMessage{{ channel: {channel}, message: {} }}", + message.name() + ); Ok((Self { channel, message }, buf)) } /// Decode a normal channel message from a buffer. @@ -448,7 +464,7 @@ impl ChannelMessage { /// Note: `buf` has to have a valid length, and without the 3 LE /// bytes in it #[instrument(err, skip(buf))] - pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> { + pub(crate) fn decode_with_channel(buf: &[u8], channel: u64) -> io::Result<(Self, &[u8])> { if buf.len() <= 1 { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, @@ -504,8 +520,13 @@ impl VecEncodable for ChannelMessage { where Self: Sized, { - let body_len = prencode_channel_messages(vec)?; - let mut buffer = checked_write_uint24_le(body_len, buffer)?; + let in_buf_len = buffer.len(); + trace!( + "Vec::encode to buf.len() = [{}]", + buffer.len() + ); + let body_len = vec_channel_messages_encoded_size(vec)?; + let mut buffer = encode_usize_as_u24(body_len, buffer)?; match vec { [] => Ok(buffer), [msg] => { @@ -527,9 +548,10 @@ impl VecEncodable for ChannelMessage { current_channel = msg.channel; } let msg_len = msg.message.encoded_size()?; - buffer = (msg_len as u32).encode(buffer)?; + buffer = (msg_len as u64).encode(buffer)?; buffer = msg.message.encode(buffer)?; } + trace!("wrote [{}] bytes to buffer", in_buf_len - buffer.len()); Ok(buffer) } } @@ -541,16 +563,19 @@ impl VecEncodable for ChannelMessage { { let mut index = 0; let mut combined_messages: Vec = vec![]; + let mut rest = buffer; while index < buffer.len() { // There might be zero bytes in between, and with LE, the next message will // start with a non-zero - if buffer[index] == 0 { + if rest[index] == 0 { index += 1; continue; } - let (frame_len, next_frame_start) = decode_u24(&buffer[index..])?; - let (msgs, length) = decode_unframed_channel_messages(&next_frame_start[..frame_len]) + let frame_len; + (frame_len, rest) = decode_u24(&rest[index..])?; + let (msgs, length) = decode_unframed_channel_messages(&rest[..frame_len]) .map_err(|e| EncodingError::external(&format!("{e}")))?; + rest = &rest[length..]; if length != frame_len { warn!( "Did not know what to do with all the bytes, got {frame_len} but decoded {length}. \ @@ -561,7 +586,7 @@ impl VecEncodable for ChannelMessage { combined_messages.extend(msgs); index += UINT24_HEADER_LEN + frame_len; } - todo!() + Ok((combined_messages, rest)) } } @@ -692,13 +717,12 @@ mod tests { length: 4, }); let msgs = vec![ChannelMessage::new(1, one), ChannelMessage::new(1, two)]; - let buff = msgs.to_encoded_bytes()?; - - let res = as CompactEncoding>::decode(&buff); - assert!(res.is_err()); log(); - let buff = msgs.to_encoded_bytes()?; + let (result, rest) = as CompactEncoding>::decode(&buff)?; + assert!(rest.is_empty()); + assert_eq!(result, msgs); + let (res2, _size) = decode_framed_channel_messages(&buff).unwrap(); assert_eq!(res2, msgs); diff --git a/src/schema.rs b/src/schema.rs index 049a590..d1bec1e 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -122,7 +122,6 @@ impl CompactEncoding for Synchronize { let mut flags: u8 = if self.can_upgrade { 1 } else { 0 }; flags |= if self.uploading { 2 } else { 0 }; flags |= if self.downloading { 4 } else { 0 }; - dbg!(flags); let rest = write_array(&[flags], buffer)?; Ok(map_encode!( rest, From 3da6c62f2659ac259ad41ac8e6382335e3c95fc2 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 9 May 2025 14:51:11 -0400 Subject: [PATCH 103/206] remove redundant names --- tests/js_interop.rs | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 41d8160..7764141 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -51,64 +51,64 @@ const TEST_SET_SIMPLE: &str = "simple"; #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_server_writer() -> Result<()> { - js_interop_ncns_simple(true, 8101).await?; +async fn ncns_server_writer() -> Result<()> { + ncns(true, 8101).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncns_client_writer() -> Result<()> { - js_interop_ncns_simple(false, 8102).await?; +async fn ncns_client_writer() -> Result<()> { + ncns(false, 8102).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcns_server_writer() -> Result<()> { - js_interop_rcns(true, 8103).await?; +async fn rcns_server_writer() -> Result<()> { + rcns(true, 8103).await?; Ok(()) } #[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] #[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcns_client_writer() -> Result<()> { - js_interop_rcns(false, 8104).await?; +async fn rcns_client_writer() -> Result<()> { + rcns(false, 8104).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_server_writer() -> Result<()> { - js_interop_ncrs_simple(true, 8105).await?; +async fn ncrs_server_writer() -> Result<()> { + ncrs(true, 8105).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_ncrs_client_writer() -> Result<()> { - js_interop_ncrs_simple(false, 8106).await?; +async fn ncrs_client_writer() -> Result<()> { + ncrs(false, 8106).await?; Ok(()) } #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] -async fn js_interop_rcrs_server_writer() -> Result<()> { +async fn rcrs_server_writer() -> Result<()> { _util::log(); - js_interop_rcrs_simple(true, 8107).await?; + rcrs(true, 8107).await?; Ok(()) } #[tokio::test] //#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -#[ignore] // FIXME this tests hangs sporadically -async fn js_interop_rcrs_client_writer() -> Result<()> { - js_interop_rcrs_simple(false, 8108).await?; +//#[ignore] // FIXME this tests hangs sporadically +async fn rcrs_client_writer() -> Result<()> { + rcrs(false, 8108).await?; Ok(()) } -async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { +async fn ncns(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -147,7 +147,7 @@ async fn js_interop_ncns_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_rcns(server_writer: bool, port: u32) -> Result<()> { +async fn rcns(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -192,7 +192,7 @@ async fn js_interop_rcns(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_ncrs_simple(server_writer: bool, port: u32) -> Result<()> { +async fn ncrs(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", @@ -238,7 +238,7 @@ async fn js_interop_ncrs_simple(server_writer: bool, port: u32) -> Result<()> { Ok(()) } -async fn js_interop_rcrs_simple(server_writer: bool, port: u32) -> Result<()> { +async fn rcrs(server_writer: bool, port: u32) -> Result<()> { init(); let test_set = format!( "{}_{}_{}", From 74033f1ece4e3a1906ae7f785912f87f7b80cf8b Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 10 May 2025 20:33:06 -0400 Subject: [PATCH 104/206] RMME --- tests/js_interop.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 7764141..e0dc216 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -1,4 +1,4 @@ -use _util::wait_for_localhost_port; +use _util::{log, wait_for_localhost_port}; use anyhow::Result; use futures::Future; use futures_lite::stream::StreamExt; @@ -66,6 +66,7 @@ async fn ncns_client_writer() -> Result<()> { #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn rcns_server_writer() -> Result<()> { + log(); rcns(true, 8103).await?; Ok(()) } @@ -159,7 +160,9 @@ async fn rcns(server_writer: bool, port: u32) -> Result<()> { }, TEST_SET_SIMPLE ); + dbg!(); let (result_path, writer_path, reader_path) = prepare_test_set(&test_set); + dbg!(); let item_count = 4; let item_size = 4; let data_char = '1'; From d87c498d801f88d46d0e37718fe8eec7a3d85d06 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 18 May 2025 01:39:06 -0400 Subject: [PATCH 105/206] More logging rm unused --- src/noise.rs | 20 ++++++++++++-------- src/protocol.rs | 1 + src/test_utils.rs | 10 ---------- 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 9f35393..8185162 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -170,7 +170,7 @@ impl< *is_initiator, flush, ); - poll_outgoing_encrypted_messages(io, cx, encrypted_tx, *is_initiator, flush); + poll_outgoing_encrypted_messages(io, cx, encrypted_tx, *is_initiator, flush, step); // check if we've done all possible work if did_as_much_as_possible( @@ -202,6 +202,7 @@ impl< } /// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. +#[instrument(skip_all, ret)] fn did_as_much_as_possible< IO: Stream>> + Sink> + Send + Unpin + 'static, >( @@ -214,7 +215,7 @@ fn did_as_much_as_possible< is_initiator: bool, ) -> bool { // No incoming encrypted messages available. - poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator).is_pending() + poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator, step).is_pending() // We're unable to send any anymore encrypted/setup messages either because we have none or the `Sink` is unavailable. && (encrypted_tx.is_empty() || Sink::poll_ready(Pin::new(io), cx).is_pending()) // No encrypted messages waiting to be decrypted. @@ -228,7 +229,7 @@ impl>> + Sink> + Send + Unpin + 'static { type Item = Event; - #[instrument(skip_all, fields(initiator = %self.is_initiator))] + #[instrument(skip_all, fields(initiator = %self.is_initiator, ret, err))] fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let Encrypted { io, @@ -268,6 +269,7 @@ impl>> + Sink> + Send + Unpin + 'static /// Handle all message throughput. Sends, encrypts and decrypts messages /// Returns `true` `step` is already [`Step::Established`]. #[allow(clippy::too_many_arguments)] +#[instrument(skip_all, ret)] fn poll_message_throughput< IO: Stream>> + Sink> + Send + Unpin + 'static, >( @@ -281,8 +283,8 @@ fn poll_message_throughput< is_initiator: bool, flush: &mut bool, ) -> bool { - poll_outgoing_encrypted_messages(io, cx, encrypted_tx, is_initiator, flush); - let _ = poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator); + poll_outgoing_encrypted_messages(io, cx, encrypted_tx, is_initiator, flush, step); + let _ = poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator, step); if let Step::Established((encryptor, decryptor, ..)) = step { // decrypt incomming msgs poll_decrypt(decryptor, encrypted_rx, plain_rx, is_initiator); @@ -369,11 +371,12 @@ fn poll_outgoing_encrypted_messages< encrypted_tx: &mut VecDeque>, is_initiator: bool, flush: &mut bool, + step: &Step ) { // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(initiator = %is_initiator, msg_len = encrypted_out.len(), "TX message"); + trace!(initiator = %is_initiator, msg_len = encrypted_out.len(), step = %step, "TX message"); if let Err(_e) = Sink::start_send(Pin::new(io), encrypted_out) { error!("Error polling encyrpted side io") } @@ -407,11 +410,12 @@ fn poll_incomming_encrypted_messages< cx: &mut Context<'_>, encrypted_rx: &mut VecDeque>>, is_initiator: bool, + step: &Step, ) -> Poll<()> { // pull in any incomming encrypted messages let mut got_some = false; while let Poll::Ready(Some(encrypted_msg)) = Stream::poll_next(Pin::new(io), cx) { - trace!(initiator = %is_initiator, "RX message"); + trace!(initiator = %is_initiator, step = %step, "RX message"); encrypted_rx.push_back(encrypted_msg); got_some = true; } @@ -437,8 +441,8 @@ fn poll_decrypt( trace!(initiator = %is_initiator, "encrypted_rx dequeue decrypt"); match decryptor.decrypt_buf(&incoming_msg) { Ok((plain_msg, _tag)) => { - trace!(initiator = %is_initiator, "plain rx queue"); plain_rx.push_back(Event::from(Ok(plain_msg))); + trace!(initiator = %is_initiator, n_plain_rx_msgs = plain_rx.len(), "plain_rx enqueue"); } Err(e) => { error!(initiator = %is_initiator,"RX message failed to decrypt: {e:?}") diff --git a/src/protocol.rs b/src/protocol.rs index 955ef7c..615cd53 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -348,6 +348,7 @@ where } } + #[instrument(skip_all)] fn on_inbound_channel_messages(&mut self, channel_messages: Vec) -> Result<()> { for channel_message in channel_messages { self.on_inbound_message(channel_message)? diff --git a/src/test_utils.rs b/src/test_utils.rs index 24256cb..8a4dd74 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -197,13 +197,3 @@ pub(crate) fn create_result_connected() -> ( let b = Moo::from(result_channel()); a.connect(b) } - -#[tokio::test] -async fn foo() -> Result<(), Box> { - let a = Moo::from(result_channel()); - let b = Moo::from(result_channel()); - let (mut left, mut right) = a.connect(b); - left.send(b"hello".to_vec()).await?; - assert_eq!(right.next().await.unwrap()?, b"hello".to_vec()); - Ok(()) -} From 539a0179fe93787f912a574ac37df4c06a53dfea Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 18 May 2025 01:40:51 -0400 Subject: [PATCH 106/206] Notes --- src/schema.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/schema.rs b/src/schema.rs index d1bec1e..bc8f141 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -299,6 +299,10 @@ macro_rules! opt_encoded_size { }; } +// TODO we could write a macro where it takes a $cond that returns an opt. +// if the option is Some(T) then do T::encode(buf) +// also if some add $flag. +// This would simplify some of these impls macro_rules! opt_encoded_bytes { ($opt:expr, $buf:ident) => { if let Some(thing) = $opt { From 39d0dbec60eba60cf5ffa5bbbf21b239044c0c84 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 18:18:02 -0400 Subject: [PATCH 107/206] use checked get --- src/message.rs | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/message.rs b/src/message.rs index ff2ea2b..beabb8d 100644 --- a/src/message.rs +++ b/src/message.rs @@ -7,7 +7,7 @@ use compact_encoding::{ use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; -use tracing::{debug, instrument, trace, warn}; +use tracing::{debug, error, instrument, trace, warn}; const UINT24_HEADER_LEN: usize = 3; const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; @@ -31,19 +31,38 @@ pub(crate) fn decode_framed_channel_messages( let stat = stat_uint24_le(&buf[index..]); if let Some((header_len, body_len)) = stat { - let (msgs, length) = decode_unframed_channel_messages( - &buf[index + header_len..index + header_len + body_len as usize], - )?; - if length != body_len as usize { - warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ + dbg!(&body_len); + if let Some(frame_body) = + buf.get(index + header_len..index + header_len + body_len as usize) + { + let (msgs, length) = decode_unframed_channel_messages(frame_body)?; + if length != body_len as usize { + warn!( + "Did not know what to do with all the bytes, got {} but decoded {}. \ This may be because the peer implements a newer protocol version \ that has extra fields.", - body_len, length + body_len, length + ); + } + combined_messages.extend(msgs); + index += header_len + body_len as usize; + } else { + error!( + "Could not get bytes for whole frame. +frame_header_length + frame_body_length \t= [{}] +remaining buffer (after current index) \t= [{}] +total_buffer_len \t= [{}] +current_index \t= [{}] +buffer_starts_with \t= [{:?}] +", + header_len + (body_len as usize), + buf.len() - index, + buf.len(), + index, + &buf, ); + return Ok((combined_messages, index)); } - combined_messages.extend(msgs); - index += header_len + body_len as usize; } else { return Err(io::Error::new( io::ErrorKind::InvalidData, From d97bc646e46bf86b21ac61ef3908e4e0b037cc24 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 18:18:31 -0400 Subject: [PATCH 108/206] RMME --- src/message.rs | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/message.rs b/src/message.rs index beabb8d..69b18b1 100644 --- a/src/message.rs +++ b/src/message.rs @@ -753,4 +753,19 @@ mod tests { //assert!(res2.is_ok()); Ok(()) } + + #[test] + fn foo() -> Result<(), EncodingError> { + log(); + let buf = vec![ + 0, 1, 1, 15, 104, 121, 112, 101, 114, 99, 111, 114, 101, 47, 97, 108, 112, 104, 97, 32, + 23, 228, 138, 218, 81, 18, 123, 111, 160, 195, 104, 154, 55, 116, 18, 132, 44, 229, 77, + 118, 217, 54, 41, 162, 97, 118, 95, 4, 213, 142, 79, 124, 1, 89, 165, 64, 201, 94, 50, + 58, 137, 153, 119, 156, 234, 18, 164, 157, 161, 49, 16, 28, 206, 84, 241, 0, 245, 14, + 143, 129, 9, 151, 247, 29, 10, + ]; + let res = decode_unframed_channel_messages(&buf).unwrap(); + dbg!(&res); + Ok(()) + } } From 592699981e25456c466ec01ddba9653d8520d4b8 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 18:21:29 -0400 Subject: [PATCH 109/206] un-ignore tests --- src/message.rs | 1 + tests/js_interop.rs | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/message.rs b/src/message.rs index 69b18b1..41cd267 100644 --- a/src/message.rs +++ b/src/message.rs @@ -754,6 +754,7 @@ mod tests { Ok(()) } + // TODO RMME #[test] fn foo() -> Result<(), EncodingError> { log(); diff --git a/tests/js_interop.rs b/tests/js_interop.rs index e0dc216..320436a 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -66,14 +66,12 @@ async fn ncns_client_writer() -> Result<()> { #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn rcns_server_writer() -> Result<()> { - log(); rcns(true, 8103).await?; Ok(()) } #[tokio::test] -//#[cfg_attr(not(feature = "js_interop_tests"), ignore)] -#[ignore] // FIXME this tests hangs sporadically +#[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn rcns_client_writer() -> Result<()> { rcns(false, 8104).await?; Ok(()) From 81e0dad195678093966ce8af3a51a1a6c369db99 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 18:21:43 -0400 Subject: [PATCH 110/206] rm debug --- src/message.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/message.rs b/src/message.rs index 41cd267..3805158 100644 --- a/src/message.rs +++ b/src/message.rs @@ -31,7 +31,6 @@ pub(crate) fn decode_framed_channel_messages( let stat = stat_uint24_le(&buf[index..]); if let Some((header_len, body_len)) = stat { - dbg!(&body_len); if let Some(frame_body) = buf.get(index + header_len..index + header_len + body_len as usize) { From fb14f920265d875c9bb353a42694ea69c494f215 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:40:46 -0400 Subject: [PATCH 111/206] rm framing stuff from messages --- src/message.rs | 156 ++++++++----------------------------------------- 1 file changed, 25 insertions(+), 131 deletions(-) diff --git a/src/message.rs b/src/message.rs index 3805158..637e99e 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,76 +1,18 @@ use crate::schema::*; -use crate::util::{stat_uint24_le, write_uint24_le}; use compact_encoding::{ - decode_usize, take_array, take_array_mut, write_array, CompactEncoding, EncodingError, - EncodingErrorKind, VecEncodable, + decode_usize, take_array, write_array, CompactEncoding, EncodingError, EncodingErrorKind, + VecEncodable, }; use pretty_hash::fmt as pretty_fmt; use std::fmt; use std::io; -use tracing::{debug, error, instrument, trace, warn}; +use tracing::{debug, instrument, trace, warn}; -const UINT24_HEADER_LEN: usize = 3; const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; const CLOSE_MESSAGE_PREFIX: [u8; 2] = [0, 3]; const MULTI_MESSAGE_PREFIX: [u8; 2] = [0, 0]; const CHANNEL_CHANGE_SEPERATOR: [u8; 1] = [0]; -#[instrument(skip_all)] -pub(crate) fn decode_framed_channel_messages( - buf: &[u8], -) -> Result<(Vec, usize), io::Error> { - let mut index = 0; - let mut combined_messages: Vec = vec![]; - while index < buf.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if buf[index] == 0 { - index += 1; - continue; - } - - let stat = stat_uint24_le(&buf[index..]); - if let Some((header_len, body_len)) = stat { - if let Some(frame_body) = - buf.get(index + header_len..index + header_len + body_len as usize) - { - let (msgs, length) = decode_unframed_channel_messages(frame_body)?; - if length != body_len as usize { - warn!( - "Did not know what to do with all the bytes, got {} but decoded {}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - body_len, length - ); - } - combined_messages.extend(msgs); - index += header_len + body_len as usize; - } else { - error!( - "Could not get bytes for whole frame. -frame_header_length + frame_body_length \t= [{}] -remaining buffer (after current index) \t= [{}] -total_buffer_len \t= [{}] -current_index \t= [{}] -buffer_starts_with \t= [{:?}] -", - header_len + (body_len as usize), - buf.len() - index, - buf.len(), - index, - &buf, - ); - return Ok((combined_messages, index)); - } - } else { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "received invalid data in multi-message chunk", - )); - } - } - Ok((combined_messages, index)) -} #[instrument(skip_all err)] pub(crate) fn decode_unframed_channel_messages( buf: &[u8], @@ -142,7 +84,7 @@ pub(crate) fn decode_unframed_channel_messages( } else { Err(io::Error::new( io::ErrorKind::InvalidData, - format!("received too short message, {buf:02X?}"), + format!("received too short message, {buf:?}"), )) } } @@ -173,19 +115,6 @@ fn vec_channel_messages_encoded_size(messages: &[ChannelMessage]) -> Result Result<&mut [u8], EncodingError> { - let (header, rest) = take_array_mut::(buf)?; - write_uint24_le(n, header); - Ok(rest) -} - -/// decode a u24 from `buffer` as a `usize` -fn decode_u24(buffer: &[u8]) -> Result<(usize, &[u8]), EncodingError> { - let (u24_bytes, rest) = take_array::(buffer)?; - let (_, out) = stat_uint24_le(&u24_bytes).expect("input garunteed to be long enough"); - Ok((out as usize, rest)) -} - /// A protocol message. #[derive(Debug, Clone, PartialEq)] #[allow(missing_docs)] @@ -530,7 +459,7 @@ impl VecEncodable for ChannelMessage { where Self: Sized, { - Ok(vec_channel_messages_encoded_size(vec)? + UINT24_HEADER_LEN) + Ok(vec_channel_messages_encoded_size(vec)?) } #[instrument(skip_all)] @@ -543,34 +472,33 @@ impl VecEncodable for ChannelMessage { "Vec::encode to buf.len() = [{}]", buffer.len() ); - let body_len = vec_channel_messages_encoded_size(vec)?; - let mut buffer = encode_usize_as_u24(body_len, buffer)?; + let mut rest = buffer; match vec { - [] => Ok(buffer), + [] => Ok(rest), [msg] => { - buffer = match msg.message { - Message::Open(_) => write_array(&OPEN_MESSAGE_PREFIX, buffer)?, - Message::Close(_) => write_array(&CLOSE_MESSAGE_PREFIX, buffer)?, - _ => msg.channel.encode(buffer)?, + rest = match msg.message { + Message::Open(_) => write_array(&OPEN_MESSAGE_PREFIX, rest)?, + Message::Close(_) => write_array(&CLOSE_MESSAGE_PREFIX, rest)?, + _ => msg.channel.encode(rest)?, }; - msg.message.encode(buffer) + msg.message.encode(rest) } msgs => { - buffer = write_array(&MULTI_MESSAGE_PREFIX, buffer)?; + rest = write_array(&MULTI_MESSAGE_PREFIX, rest)?; let mut current_channel: u64 = msgs[0].channel; - buffer = current_channel.encode(buffer)?; + rest = current_channel.encode(rest)?; for msg in msgs { if msg.channel != current_channel { - buffer = write_array(&CHANNEL_CHANGE_SEPERATOR, buffer)?; - buffer = msg.channel.encode(buffer)?; + rest = write_array(&CHANNEL_CHANGE_SEPERATOR, rest)?; + rest = msg.channel.encode(rest)?; current_channel = msg.channel; } let msg_len = msg.message.encoded_size()?; - buffer = (msg_len as u64).encode(buffer)?; - buffer = msg.message.encode(buffer)?; + rest = (msg_len as u64).encode(rest)?; + rest = msg.message.encode(rest)?; } - trace!("wrote [{}] bytes to buffer", in_buf_len - buffer.len()); - Ok(buffer) + trace!("wrote [{}] bytes to buffer", in_buf_len - rest.len()); + Ok(rest) } } } @@ -579,30 +507,13 @@ impl VecEncodable for ChannelMessage { where Self: Sized, { - let mut index = 0; let mut combined_messages: Vec = vec![]; let mut rest = buffer; - while index < buffer.len() { - // There might be zero bytes in between, and with LE, the next message will - // start with a non-zero - if rest[index] == 0 { - index += 1; - continue; - } - let frame_len; - (frame_len, rest) = decode_u24(&rest[index..])?; - let (msgs, length) = decode_unframed_channel_messages(&rest[..frame_len]) + while !rest.is_empty() { + let (msgs, length) = decode_unframed_channel_messages(rest) .map_err(|e| EncodingError::external(&format!("{e}")))?; rest = &rest[length..]; - if length != frame_len { - warn!( - "Did not know what to do with all the bytes, got {frame_len} but decoded {length}. \ - This may be because the peer implements a newer protocol version \ - that has extra fields.", - ); - } combined_messages.extend(msgs); - index += UINT24_HEADER_LEN + frame_len; } Ok((combined_messages, rest)) } @@ -662,7 +573,9 @@ mod tests { upgrade: Some(RequestUpgrade { start: 0, length: 10 - }) + }), + manifest: false, + priority: 0 }), Message::Cancel(Cancel { request: 1, @@ -741,9 +654,6 @@ mod tests { assert!(rest.is_empty()); assert_eq!(result, msgs); - let (res2, _size) = decode_framed_channel_messages(&buff).unwrap(); - assert_eq!(res2, msgs); - // from js interop tests // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] @@ -752,20 +662,4 @@ mod tests { //assert!(res2.is_ok()); Ok(()) } - - // TODO RMME - #[test] - fn foo() -> Result<(), EncodingError> { - log(); - let buf = vec![ - 0, 1, 1, 15, 104, 121, 112, 101, 114, 99, 111, 114, 101, 47, 97, 108, 112, 104, 97, 32, - 23, 228, 138, 218, 81, 18, 123, 111, 160, 195, 104, 154, 55, 116, 18, 132, 44, 229, 77, - 118, 217, 54, 41, 162, 97, 118, 95, 4, 213, 142, 79, 124, 1, 89, 165, 64, 201, 94, 50, - 58, 137, 153, 119, 156, 234, 18, 164, 157, 161, 49, 16, 28, 206, 84, 241, 0, 245, 14, - 143, 129, 9, 151, 247, 29, 10, - ]; - let res = decode_unframed_channel_messages(&buf).unwrap(); - dbg!(&res); - Ok(()) - } } From 3b9430b440856bc0eccc7a6ac8dbf069ce9d1aa3 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:54:28 -0400 Subject: [PATCH 112/206] Add manifest & priority to Request --- examples/replication.rs | 4 ++++ src/mqueue.rs | 6 +----- src/schema.rs | 22 ++++++++++++++++++++++ tests/js_interop.rs | 4 ++++ 4 files changed, 31 insertions(+), 5 deletions(-) diff --git a/examples/replication.rs b/examples/replication.rs index 459df9f..ac10df6 100644 --- a/examples/replication.rs +++ b/examples/replication.rs @@ -299,6 +299,8 @@ async fn onmessage( start: info.length, length: peer_state.remote_length - info.length, }), + manifest: false, + priority: 0, }; messages.push(Message::Request(msg)); } @@ -405,6 +407,8 @@ async fn onmessage( block: Some(request_block), seek: None, upgrade: None, + manifest: false, + priority: 0, })); } channel.send_batch(&messages).await.unwrap(); diff --git a/src/mqueue.rs b/src/mqueue.rs index e5df5b8..0997f36 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -12,11 +12,7 @@ use compact_encoding::CompactEncoding as _; use futures::{Sink, Stream}; use tracing::{error, instrument}; -use crate::{ - message::{decode_framed_channel_messages, ChannelMessage}, - noise::EncryptionInfo, - NoiseEvent, -}; +use crate::{message::ChannelMessage, noise::EncryptionInfo, NoiseEvent}; #[derive(Debug)] pub(crate) enum MqueueEvent { diff --git a/src/schema.rs b/src/schema.rs index bc8f141..b27f9e4 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -169,6 +169,13 @@ pub struct Request { pub seek: Option, /// Request upgrade pub upgrade: Option, + // TODO what is this + /// Request manifest + pub manifest: bool, + // TODO what is this + // this could prob be usize + /// Request priority + pub priority: u64, } macro_rules! maybe_decode { @@ -206,6 +213,8 @@ impl CompactEncoding for Request { flags |= if self.hash.is_some() { 2 } else { 0 }; flags |= if self.seek.is_some() { 4 } else { 0 }; flags |= if self.upgrade.is_some() { 8 } else { 0 }; + flags |= if self.manifest { 16 } else { 0 }; + flags |= if self.priority != 0 { 32 } else { 0 }; let mut rest = write_array(&[flags], buffer)?; rest = map_encode!(rest, self.id, self.fork); @@ -221,6 +230,11 @@ impl CompactEncoding for Request { if let Some(upgrade) = &self.upgrade { rest = upgrade.encode(rest)?; } + + if self.priority != 0 { + rest = self.priority.encode(rest)?; + } + Ok(rest) } @@ -235,6 +249,12 @@ impl CompactEncoding for Request { let (hash, rest) = maybe_decode!(flags & 2 != 0, RequestBlock, rest); let (seek, rest) = maybe_decode!(flags & 4 != 0, RequestSeek, rest); let (upgrade, rest) = maybe_decode!(flags & 8 != 0, RequestUpgrade, rest); + let manifest = flags & 16 != 0; + let (priority, rest) = if flags & 32 != 0 { + u64::decode(rest)? + } else { + (0, rest) + }; Ok(( Request { id, @@ -243,6 +263,8 @@ impl CompactEncoding for Request { hash, seek, upgrade, + manifest, + priority, }, rest, )) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 320436a..ee9d55f 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -605,6 +605,8 @@ async fn on_replication_message( start: info.length, length: peer_state.remote_length - info.length, }), + manifest: false, + priority: 0, }; messages.push(Message::Request(msg)); } @@ -716,6 +718,8 @@ async fn on_replication_message( block: Some(request_block), seek: None, upgrade: None, + manifest: false, + priority: 0, })); } let exit = if synced { From f535d3b4c1107d1527014278f6ad25e19c24bd5a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:54:48 -0400 Subject: [PATCH 113/206] rm unused framing stuff --- src/mqueue.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/mqueue.rs b/src/mqueue.rs index 0997f36..87f14d5 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -24,16 +24,13 @@ impl From for MqueueEvent { fn from(e: NoiseEvent) -> Self { match e { NoiseEvent::Meta(einf) => Self::Meta(einf), - NoiseEvent::Decrypted(dec_res) => { - match dec_res { - Ok(encoded) => match decode_framed_channel_messages(&encoded) { - //assert_eq!(_n_read, encoded.len()); } - Ok((messsages, _n_read)) => Self::Message(Ok(messsages)), - Err(e) => Self::Message(Err(e)), - }, - Err(e) => Self::Message(Err(e)), - } - } + NoiseEvent::Decrypted(dec_res) => Self::Message(match dec_res { + Ok(encoded) => match >::decode(&encoded) { + Ok((messages, _rest)) => Ok(messages), // _rest.len() == 0 + Err(e) => Err(e.into()), + }, + Err(e) => Err(e), + }), } } } From b59077eeac30de34befe2ee31d82ac96557711c4 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:58:01 -0400 Subject: [PATCH 114/206] rm test logging --- tests/js_interop.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index ee9d55f..a52900c 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -158,9 +158,7 @@ async fn rcns(server_writer: bool, port: u32) -> Result<()> { }, TEST_SET_SIMPLE ); - dbg!(); let (result_path, writer_path, reader_path) = prepare_test_set(&test_set); - dbg!(); let item_count = 4; let item_size = 4; let data_char = '1'; From f6b37e55ca2ff1f427831c010fa8099213f9b177 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:59:08 -0400 Subject: [PATCH 115/206] cargo fmt --- src/noise.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/noise.rs b/src/noise.rs index 8185162..f51b0b7 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -371,7 +371,7 @@ fn poll_outgoing_encrypted_messages< encrypted_tx: &mut VecDeque>, is_initiator: bool, flush: &mut bool, - step: &Step + step: &Step, ) { // send any pending outgoing messages while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { From 76b3a27803f63c6cef5647a78b3b61ffcb58b3c8 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 20:59:34 -0400 Subject: [PATCH 116/206] cargo clippy --fix --- src/message.rs | 2 +- tests/js_interop.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/message.rs b/src/message.rs index 637e99e..1f0fdb5 100644 --- a/src/message.rs +++ b/src/message.rs @@ -459,7 +459,7 @@ impl VecEncodable for ChannelMessage { where Self: Sized, { - Ok(vec_channel_messages_encoded_size(vec)?) + vec_channel_messages_encoded_size(vec) } #[instrument(skip_all)] diff --git a/tests/js_interop.rs b/tests/js_interop.rs index a52900c..5ae6acb 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -1,4 +1,4 @@ -use _util::{log, wait_for_localhost_port}; +use _util::wait_for_localhost_port; use anyhow::Result; use futures::Future; use futures_lite::stream::StreamExt; From 2d1ef1b743d94e436e5dca9f40d58faf2ca07bb8 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:05:10 -0400 Subject: [PATCH 117/206] RawEncCipher -> EncCipher --- src/crypto/cipher.rs | 6 +++--- src/crypto/mod.rs | 2 +- src/noise.rs | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index aa096c3..94e325e 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -102,17 +102,17 @@ fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { //NB "raw" here means UN-framed. No frame header. const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; -pub(crate) struct RawEncryptCipher { +pub(crate) struct EncryptCipher { push_stream: PushStream, } -impl std::fmt::Debug for RawEncryptCipher { +impl std::fmt::Debug for EncryptCipher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "RawEncryptCipher(crypto_secretstream)") } } -impl RawEncryptCipher { +impl EncryptCipher { pub(crate) fn from_handshake_tx( handshake_result: &HandshakeResult, ) -> std::io::Result<(Self, Vec)> { diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 9e49c0a..66bb62d 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,5 +1,5 @@ mod cipher; mod curve; mod handshake; -pub(crate) use cipher::{DecryptCipher, RawEncryptCipher}; +pub(crate) use cipher::{DecryptCipher, EncryptCipher}; pub(crate) use handshake::{Handshake, HandshakeResult}; diff --git a/src/noise.rs b/src/noise.rs index f51b0b7..7bac01a 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -10,7 +10,7 @@ use std::{ use tracing::{debug, error, instrument, trace, warn}; use crate::{ - crypto::{DecryptCipher, Handshake, HandshakeResult, RawEncryptCipher}, + crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeResult}, Uint24LELengthPrefixedFraming, }; @@ -27,8 +27,8 @@ pub fn encrypted_framed_message_channel), - SecretStream((RawEncryptCipher, HandshakeResult)), - Established((RawEncryptCipher, DecryptCipher, HandshakeResult)), + SecretStream((EncryptCipher, HandshakeResult)), + Established((EncryptCipher, DecryptCipher, HandshakeResult)), } impl Step { @@ -458,7 +458,7 @@ fn poll_decrypt( #[instrument(skip_all)] fn poll_encrypt( - encryptor: &mut RawEncryptCipher, + encryptor: &mut EncryptCipher, encrypted_tx: &mut VecDeque>, plain_tx: &mut VecDeque>, is_initiator: bool, @@ -568,7 +568,7 @@ fn handle_setup_message( }; // The cipher will be put to use to the writer only after the peer's answer has come let (cipher, init_msg) = - match RawEncryptCipher::from_handshake_tx(handshake_result) { + match EncryptCipher::from_handshake_tx(handshake_result) { Ok(x) => x, Err(e) => { error!("from_handshake_tx error {e:?}"); From fe2eed1f073adf687b948ef86773444d15a334f6 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:10:02 -0400 Subject: [PATCH 118/206] Remove old notes --- src/crypto/cipher.rs | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 94e325e..ea11920 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -136,24 +136,9 @@ impl EncryptCipher { Ok((Self { push_stream }, msg)) } - // Possible API's: - // encrypted message is (tag + encrypted + mac ) - // to have *zero* alocations we could - // * take a buffer of the expected final length, plantext starts at 1 to 1 + planetext.len() - // * final length is 1 + plaintext.len() + mac.len() - // * we write tag to 0 - // * encrypt plain text part in place - // * write mac to end - // - // it would be akward to take an array like this. We could infer the plaintext via the buffer - // it's range would be (1..(buf.len() - mac.len())) - // encypt-in-place the palintext, - // For now... let's just return the encrypted buffer - // + // TODO make this work in-place /// Encrypts `msg` and returns the encrypted bytes pub(crate) fn encrypt(&mut self, msg: &[u8]) -> io::Result> { - // NB: the result is written in place to the provided, however the buffer must be able to - // grow, since the encrypted message is bigger. So here we convert the slice to a vec. let mut out = msg.to_vec(); self.push_stream .push(&mut out, &[], Tag::Message) From f92b5b2a83868551bd8685f3f1d6921cedd511e2 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:30:53 -0400 Subject: [PATCH 119/206] lint --- src/framing.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 12d3c41..02730b1 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -1,3 +1,7 @@ +//! Wrap bytes in length prefixed framing. +use crate::util::{stat_uint24_le, wrap_uint24_le}; +use futures::{Sink, Stream}; +use futures_lite::io::{AsyncRead, AsyncWrite}; use std::{ collections::VecDeque, fmt::Debug, @@ -5,14 +9,8 @@ use std::{ pin::Pin, task::{Context, Poll}, }; - -use futures::{Sink, Stream}; - -use futures_lite::io::{AsyncRead, AsyncWrite}; use tracing::{debug, error, info, instrument, trace, warn}; -use crate::util::{stat_uint24_le, wrap_uint24_le}; - const BUF_SIZE: usize = 1024 * 64; const _HEADER_LEN: usize = 3; From f4fb37180d865aa9b5fe8c8114eac09779183e28 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:31:03 -0400 Subject: [PATCH 120/206] rename tests --- src/message.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/message.rs b/src/message.rs index 1f0fdb5..09bd294 100644 --- a/src/message.rs +++ b/src/message.rs @@ -633,7 +633,7 @@ mod tests { } #[test] - fn extras() -> Result<(), EncodingError> { + fn enc_dec_vec_chan_message() -> Result<(), EncodingError> { let one = Message::Synchronize(Synchronize { fork: 0, length: 4, @@ -653,13 +653,6 @@ mod tests { let (result, rest) = as CompactEncoding>::decode(&buff)?; assert!(rest.is_empty()); assert_eq!(result, msgs); - - // from js interop tests - // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] - // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] - // [0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] - // [23, 0, 0, 0, 0, 1, 5, 0, 7, 0, 4, 0, 4, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] - //assert!(res2.is_ok()); Ok(()) } } From b247a36e0872458ca8c7c33d36fae6f0a24904c7 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:31:14 -0400 Subject: [PATCH 121/206] rm old notse --- src/protocol.rs | 1 - src/schema.rs | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/protocol.rs b/src/protocol.rs index 615cd53..4c10a1c 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -323,7 +323,6 @@ where } /// Poll for outbound messages and write them. - /// Reads messages from Self::outbound and sends them over io #[instrument(skip_all)] fn poll_outbound_write(&mut self, cx: &mut Context<'_>) -> Result<()> { loop { diff --git a/src/schema.rs b/src/schema.rs index b27f9e4..49a0ac5 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -51,7 +51,8 @@ impl CompactEncoding for Open { { let ((channel, protocol, discovery_key), rest) = map_decode!(buffer, [u64, String, Vec]); - // TODO this is a CLEAR bug it assumes nothing is encoded after this message + // NB: Open/Close are only sent alone in their own Frame. So we're done when there is no + // more data let (capability, rest) = if !rest.is_empty() { let (_, rest) = take_array::<1>(rest)?; let (capability, rest) = take_array::<32>(rest)?; From cefc744c554f2f8079e7e92514627337457866e3 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:32:40 -0400 Subject: [PATCH 122/206] group imports ran: cargo +nightly fmt with: imports_granularity = "crate" in .rustformat.toml --- benches/pipe.rs | 12 ++++++------ benches/throughput.rs | 17 ++++++++++------- examples/replication.rs | 19 ++++++++----------- src/builder.rs | 3 +-- src/channels.rs | 33 +++++++++++++++++++-------------- src/crypto/cipher.rs | 3 +-- src/crypto/handshake.rs | 6 ++++-- src/duplex.rs | 8 +++++--- src/message.rs | 3 +-- src/protocol.rs | 40 +++++++++++++++++++++++----------------- src/util.rs | 9 +++++---- tests/_util.rs | 9 +++++---- tests/basic.rs | 3 +-- tests/js/mod.rs | 8 +++++--- tests/js_interop.rs | 15 +++++++-------- 15 files changed, 101 insertions(+), 87 deletions(-) diff --git a/benches/pipe.rs b/benches/pipe.rs index b726545..6f2a4b8 100644 --- a/benches/pipe.rs +++ b/benches/pipe.rs @@ -1,14 +1,14 @@ use async_std::task; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use futures::io::{AsyncRead, AsyncWrite}; -use futures::stream::StreamExt; -use hypercore_protocol::{schema::*, Duplex}; -use hypercore_protocol::{Channel, Event, Message, Protocol, ProtocolBuilder}; +use futures::{ + io::{AsyncRead, AsyncWrite}, + stream::StreamExt, +}; +use hypercore_protocol::{schema::*, Channel, Duplex, Event, Message, Protocol, ProtocolBuilder}; use log::*; use pretty_bytes::converter::convert as pretty_bytes; use sluice::pipe::pipe; -use std::io::Result; -use std::time::Instant; +use std::{io::Result, time::Instant}; const COUNT: u64 = 1000; const SIZE: u64 = 100; diff --git a/benches/throughput.rs b/benches/throughput.rs index cc2c278..1d9c4c0 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -1,11 +1,14 @@ -use async_std::net::{Shutdown, TcpListener, TcpStream}; -use async_std::task; +use async_std::{ + net::{Shutdown, TcpListener, TcpStream}, + task, +}; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; -use futures::future::Either; -use futures::io::{AsyncRead, AsyncWrite}; -use futures::stream::{FuturesUnordered, StreamExt}; -use hypercore_protocol::schema::*; -use hypercore_protocol::{Channel, Event, Message, ProtocolBuilder}; +use futures::{ + future::Either, + io::{AsyncRead, AsyncWrite}, + stream::{FuturesUnordered, StreamExt}, +}; +use hypercore_protocol::{schema::*, Channel, Event, Message, ProtocolBuilder}; use log::*; use std::time::Instant; diff --git a/examples/replication.rs b/examples/replication.rs index ac10df6..85d0d11 100644 --- a/examples/replication.rs +++ b/examples/replication.rs @@ -1,22 +1,19 @@ use anyhow::Result; -use async_std::net::{TcpListener, TcpStream}; -use async_std::prelude::*; -use async_std::sync::{Arc, Mutex}; -use async_std::task; +use async_std::{ + net::{TcpListener, TcpStream}, + prelude::*, + sync::{Arc, Mutex}, + task, +}; use futures_lite::stream::StreamExt; use hypercore::{ Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, VerifyingKey, }; -use std::collections::HashMap; -use std::convert::TryInto; -use std::env; -use std::fmt::Debug; -use std::sync::OnceLock; +use std::{collections::HashMap, convert::TryInto, env, fmt::Debug, sync::OnceLock}; use tracing::{error, info}; -use hypercore_protocol::schema::*; -use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; +use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; fn main() { log(); diff --git a/src/builder.rs b/src/builder.rs index d797654..0b9127e 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1,5 +1,4 @@ -use crate::Protocol; -use crate::{duplex::Duplex, protocol::Options}; +use crate::{duplex::Duplex, protocol::Options, Protocol}; use futures_lite::io::{AsyncRead, AsyncWrite}; /// Build a Protocol instance with options. diff --git a/src/channels.rs b/src/channels.rs index 1b94ece..f16ac7f 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -1,18 +1,23 @@ -use crate::message::ChannelMessage; -use crate::schema::*; -use crate::util::{map_channel_err, pretty_hash}; -use crate::Message; -use crate::{discovery_key, DiscoveryKey, Key}; +use crate::{ + discovery_key, + message::ChannelMessage, + schema::*, + util::{map_channel_err, pretty_hash}, + DiscoveryKey, Key, Message, +}; use async_channel::{Receiver, Sender, TrySendError}; -use futures_lite::ready; -use futures_lite::stream::Stream; -use std::collections::HashMap; -use std::fmt; -use std::io::{Error, ErrorKind, Result}; -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::task::Poll; +use futures_lite::{ready, stream::Stream}; +use std::{ + collections::HashMap, + fmt, + io::{Error, ErrorKind, Result}, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::Poll, +}; use tracing::instrument; /// A protocol channel. diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index ea11920..20cb734 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -5,8 +5,7 @@ use blake2::{ }; use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; use rand::rngs::OsRng; -use std::convert::TryInto; -use std::io; +use std::{convert::TryInto, io}; const STREAM_ID_LENGTH: usize = 32; const KEY_LENGTH: usize = 32; diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 10c111f..53f3889 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -3,8 +3,10 @@ use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, }; -use snow::resolvers::{DefaultResolver, FallbackResolver}; -use snow::{Builder, Error as SnowError, HandshakeState}; +use snow::{ + resolvers::{DefaultResolver, FallbackResolver}, + Builder, Error as SnowError, HandshakeState, +}; use std::io::{Error, ErrorKind, Result}; use tracing::instrument; diff --git a/src/duplex.rs b/src/duplex.rs index fe79c1b..7b0f1e5 100644 --- a/src/duplex.rs +++ b/src/duplex.rs @@ -1,7 +1,9 @@ use futures_lite::{AsyncRead, AsyncWrite}; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + io, + pin::Pin, + task::{Context, Poll}, +}; #[derive(Clone, Debug, PartialEq)] /// Duplex IO stream from reader and writer halves. diff --git a/src/message.rs b/src/message.rs index 09bd294..7665df4 100644 --- a/src/message.rs +++ b/src/message.rs @@ -4,8 +4,7 @@ use compact_encoding::{ VecEncodable, }; use pretty_hash::fmt as pretty_fmt; -use std::fmt; -use std::io; +use std::{fmt, io}; use tracing::{debug, instrument, trace, warn}; const OPEN_MESSAGE_PREFIX: [u8; 2] = [0, 1]; diff --git a/src/protocol.rs b/src/protocol.rs index 4c10a1c..e188baf 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,25 +1,31 @@ use async_channel::{Receiver, Sender}; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use futures_lite::stream::Stream; +use futures_lite::{ + io::{AsyncRead, AsyncWrite}, + stream::Stream, +}; use futures_timer::Delay; -use std::collections::VecDeque; -use std::convert::TryInto; -use std::fmt; -use std::io::{self, Error, ErrorKind, Result}; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; +use std::{ + collections::VecDeque, + convert::TryInto, + fmt, + io::{self, Error, ErrorKind, Result}, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; use tracing::{debug, error, instrument, warn}; -use crate::channels::{Channel, ChannelMap}; -use crate::constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}; -use crate::crypto::HandshakeResult; -use crate::message::{ChannelMessage, Message}; -use crate::mqueue::{MessageIo, MqueueEvent}; -use crate::noise::EncryptionInfo; -use crate::util::{map_channel_err, pretty_hash}; use crate::{ - encrypted_framed_message_channel, schema::*, Encrypted, Uint24LELengthPrefixedFraming, + channels::{Channel, ChannelMap}, + constants::{DEFAULT_KEEPALIVE, PROTOCOL_NAME}, + crypto::HandshakeResult, + encrypted_framed_message_channel, + message::{ChannelMessage, Message}, + mqueue::{MessageIo, MqueueEvent}, + noise::EncryptionInfo, + schema::*, + util::{map_channel_err, pretty_hash}, + Encrypted, Uint24LELengthPrefixedFraming, }; macro_rules! return_error { diff --git a/src/util.rs b/src/util.rs index 7e70336..5f243f2 100644 --- a/src/util.rs +++ b/src/util.rs @@ -2,11 +2,12 @@ use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, }; -use std::convert::TryInto; -use std::io::{Error, ErrorKind}; +use std::{ + convert::TryInto, + io::{Error, ErrorKind}, +}; -use crate::constants::DISCOVERY_NS_BUF; -use crate::DiscoveryKey; +use crate::{constants::DISCOVERY_NS_BUF, DiscoveryKey}; /// Calculate the discovery key of a key. /// diff --git a/tests/_util.rs b/tests/_util.rs index d15be38..fc299ca 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -1,11 +1,12 @@ use async_std::net::TcpStream; -use futures_lite::io::{AsyncRead, AsyncWrite}; -use futures_lite::StreamExt; +use futures_lite::{ + io::{AsyncRead, AsyncWrite}, + StreamExt, +}; use hypercore_protocol::{Channel, DiscoveryKey, Duplex, Event, Protocol, ProtocolBuilder}; use instant::Duration; use std::io; -use tokio::io::DuplexStream; -use tokio::task::JoinHandle; +use tokio::{io::DuplexStream, task::JoinHandle}; #[allow(unused)] pub(crate) fn log() { diff --git a/tests/basic.rs b/tests/basic.rs index 280e5be..f0d2b77 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -3,8 +3,7 @@ use _util::{ event_discovery_key, next_event, }; use futures_lite::StreamExt; -use hypercore_protocol::{discovery_key, Event, Message}; -use hypercore_protocol::{schema::*, DiscoveryKey}; +use hypercore_protocol::{discovery_key, schema::*, DiscoveryKey, Event, Message}; use std::io; use tokio::task; diff --git a/tests/js/mod.rs b/tests/js/mod.rs index 8894b3d..b8cd6ec 100644 --- a/tests/js/mod.rs +++ b/tests/js/mod.rs @@ -1,8 +1,10 @@ use anyhow::Result; use instant::Duration; -use std::fs::{create_dir_all, remove_dir_all, remove_file}; -use std::path::Path; -use std::process::Command; +use std::{ + fs::{create_dir_all, remove_dir_all, remove_file}, + path::Path, + process::Command, +}; #[cfg(feature = "async-std")] use async_std::{ diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 5ae6acb..3c74112 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -2,16 +2,16 @@ use _util::wait_for_localhost_port; use anyhow::Result; use futures::Future; use futures_lite::stream::StreamExt; -use hypercore::SigningKey; use hypercore::{ - Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, + Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, SigningKey, Storage, VerifyingKey, PUBLIC_KEY_LENGTH, SECRET_KEY_LENGTH, }; use instant::Duration; -use std::fmt::Debug; -use std::path::Path; -use std::sync::Arc; -use std::sync::Once; +use std::{ + fmt::Debug, + path::Path, + sync::{Arc, Once}, +}; #[cfg(feature = "tokio")] use async_compat::CompatExt; @@ -25,8 +25,7 @@ use tokio::{ time::sleep, }; -use hypercore_protocol::schema::*; -use hypercore_protocol::{discovery_key, Channel, Event, Message, ProtocolBuilder}; +use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; pub mod _util; mod js; From f841a2bb846540ea6ff28424d748e3631344eae9 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:36:05 -0400 Subject: [PATCH 123/206] format code in docs --- src/lib.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3602517..7857bd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,17 +50,14 @@ //! //! ```no_run //! # async_std::task::block_on(async { -//! use hypercore_protocol::{ProtocolBuilder, Event, Message}; -//! use hypercore_protocol::schema::*; //! use async_std::prelude::*; +//! use hypercore_protocol::{schema::*, Event, Message, ProtocolBuilder}; //! // Start a tcp server. //! let listener = async_std::net::TcpListener::bind("localhost:8000").await.unwrap(); //! async_std::task::spawn(async move { //! let mut incoming = listener.incoming(); //! while let Some(Ok(stream)) = incoming.next().await { -//! async_std::task::spawn(async move { -//! onconnection(stream, false).await -//! }); +//! async_std::task::spawn(async move { onconnection(stream, false).await }); //! } //! }); //! @@ -69,7 +66,7 @@ //! onconnection(stream, true).await; //! //! /// Start Hypercore protocol on a TcpStream. -//! async fn onconnection (stream: async_std::net::TcpStream, is_initiator: bool) { +//! async fn onconnection(stream: async_std::net::TcpStream, is_initiator: bool) { //! // A peer either is the initiator or a connection or is being connected to. //! let name = if is_initiator { "dialer" } else { "listener" }; //! // A key for the channel we want to open. Usually, this is a pre-shared key that both peers @@ -86,7 +83,7 @@ //! // The handshake event is emitted after the protocol is fully established. //! Event::Handshake(_remote_key) => { //! protocol.open(key.clone()).await; -//! }, +//! } //! // A Channel event is emitted for each established channel. //! Event::Channel(mut channel) => { //! // A Channel can be sent to other tasks. @@ -97,7 +94,7 @@ //! eprintln!("{} received message: {:?}", name, message); //! } //! }); -//! }, +//! } //! _ => {} //! } //! } From 0ee4be6464e22a56150f491552188946ec4b3a0e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:38:43 -0400 Subject: [PATCH 124/206] rm unused --- tests/_util.rs | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/tests/_util.rs b/tests/_util.rs index fc299ca..b6f1d22 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -8,31 +8,6 @@ use instant::Duration; use std::io; use tokio::{io::DuplexStream, task::JoinHandle}; -#[allow(unused)] -pub(crate) fn log() { - static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); - START_LOGS.get_or_init(|| { - use tracing_subscriber::{ - layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter, - }; - let env_filter = EnvFilter::from_default_env(); // Reads `RUST_LOG` environment variable - - // Create the hierarchical layer from tracing_tree - let tree_layer = tracing_tree::HierarchicalLayer::new(2) // 2 spaces per indent level - .with_targets(true) - .with_bracketed_fields(true) - .with_indent_lines(true) - .with_span_modes(true) - .with_thread_ids(false) - .with_thread_names(false); - - tracing_subscriber::registry() - .with(env_filter) - .with(tree_layer) - .init(); - }); -} - type TokioDuplex = tokio_util::compat::Compat; pub(crate) fn duplex(channel_size: usize) -> (TokioDuplex, TokioDuplex) { @@ -111,7 +86,6 @@ where }) } -#[allow(unused)] pub async fn wait_for_localhost_port(port: u32) { const RETRY_TIMEOUT: u64 = 100_u64; const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; From 7f38fdab9d1eba4d368230873d7726715b250d83 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 19 May 2025 21:51:55 -0400 Subject: [PATCH 125/206] rm unwraps --- src/framing.rs | 16 ++++++++-------- tests/basic.rs | 1 - tests/js_interop.rs | 1 - 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 02730b1..5bc8297 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -331,7 +331,7 @@ pub(crate) mod test { // NB this sluice pipe // for d in data { - rightlp.feed(d.to_vec()).await.unwrap(); + rightlp.feed(d.to_vec()).await?; } let rflush = spawn(async move { rightlp.flush().await.unwrap(); @@ -340,14 +340,14 @@ pub(crate) mod test { let mut result1 = vec![]; for _ in data { - result1.push(leftlp.next().await.unwrap().unwrap()); + result1.push(leftlp.next().await.unwrap()?); } let mut rightlp = rflush.await?; assert_eq!(result1, data); for d in data { - leftlp.feed(d.to_vec()).await.unwrap(); + leftlp.feed(d.to_vec()).await?; } let lflush = spawn(async move { leftlp.flush().await.unwrap(); @@ -356,7 +356,7 @@ pub(crate) mod test { let mut result2 = vec![]; for _ in data { - result2.push(rightlp.next().await.unwrap().unwrap()); + result2.push(rightlp.next().await.unwrap()?); } let mut leftlp = lflush.await?; assert_eq!(result2, data); @@ -365,13 +365,13 @@ pub(crate) mod test { let mut r4 = vec![]; for d in data { - rightlp.send(d.to_vec()).await.unwrap(); - leftlp.send(d.to_vec()).await.unwrap(); + rightlp.send(d.to_vec()).await?; + leftlp.send(d.to_vec()).await?; } for _ in data { - r3.push(rightlp.next().await.unwrap().unwrap()); - r4.push(leftlp.next().await.unwrap().unwrap()); + r3.push(rightlp.next().await.unwrap()?); + r4.push(leftlp.next().await.unwrap()?); } assert_eq!(r3, data); diff --git a/tests/basic.rs b/tests/basic.rs index f0d2b77..d713937 100644 --- a/tests/basic.rs +++ b/tests/basic.rs @@ -11,7 +11,6 @@ mod _util; #[tokio::test] async fn basic_protocol() -> anyhow::Result<()> { - _util::log(); let (proto_a, proto_b) = create_pair_memory2().await?; let next_a = next_event(proto_a); diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 3c74112..619841b 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -93,7 +93,6 @@ async fn ncrs_client_writer() -> Result<()> { #[tokio::test] #[cfg_attr(not(feature = "js_interop_tests"), ignore)] async fn rcrs_server_writer() -> Result<()> { - _util::log(); rcrs(true, 8107).await?; Ok(()) } From 2904f01d90566ee88bc437eaa793a11a7ed0883e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 00:51:39 -0400 Subject: [PATCH 126/206] clean up noise module --- src/noise.rs | 605 ++++++++++++++++++++------------------------------- 1 file changed, 240 insertions(+), 365 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 7bac01a..4115cbe 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -39,16 +39,13 @@ impl Step { impl std::fmt::Display for Step { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Step::NotInitialized => "NotInitialized", - Step::Handshake(_) => "Handshake", - Step::SecretStream(_) => "SecretStream", - Step::Established(_) => "Established", - } - ) + let x = match self { + Step::NotInitialized => "NotInitialized", + Step::Handshake(_) => "Handshake", + Step::SecretStream(_) => "SecretStream", + Step::Established(_) => "Established", + }; + write!(f, "{}", x) } } @@ -111,6 +108,230 @@ where pub fn encryption_established(&self) -> bool { self.step.established() } + + /// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. + #[instrument(skip_all, ret)] + fn did_as_much_as_possible(&mut self, cx: &mut Context<'_>) -> bool { + // No incoming encrypted messages available. + self.poll_incomming_encrypted_messages(cx).is_pending() + // We're unable to send any anymore encrypted/setup messages either because we have none or the `Sink` is unavailable. + && (self.encrypted_tx.is_empty() || Sink::poll_ready(Pin::new(&mut self.io), cx).is_pending()) + // No encrypted messages waiting to be decrypted. + && self.encrypted_rx.is_empty() + // No plaint text messages waiting to be enccrypted or we're still setting up + && (self.plain_tx.is_empty() || !self.step.established()) + } + + /// Handle all message throughput. Sends, encrypts and decrypts messages + /// Returns `true` `step` is already [`Step::Established`]. + #[allow(clippy::too_many_arguments)] + #[instrument(skip_all, ret)] + fn poll_message_throughput(&mut self, cx: &mut Context<'_>) -> bool { + self.poll_outgoing_encrypted_messages(cx); + let _ = self.poll_incomming_encrypted_messages(cx); + if let Step::Established((encryptor, decryptor, ..)) = &mut self.step { + // decrypt incomming msgs + poll_decrypt( + decryptor, + &mut self.encrypted_rx, + &mut self.plain_rx, + self.is_initiator, + ); + // encrypt any pending plaintext outgoinng messages + poll_encrypt( + encryptor, + &mut self.encrypted_tx, + &mut self.plain_tx, + self.is_initiator, + &mut self.flush, + ); + true + } else { + self.poll_setup(); + false + } + } + #[instrument(skip_all, fields(initiator = %self.is_initiator))] + fn poll_setup(&mut self) { + // if we get an error, it could be because the other side reset, and is sending a new + // initialization message. + // If this is the case, we should retry this message after the error. + // But to avoid repeatedly retrying the first message, we should only retry if it is *not* the first msg. + // Still setting up + if let Ok(Some(msg)) = maybe_init(&mut self.step, self.is_initiator) { + // queue the init message to send first + trace!(initiator = %self.is_initiator,"queue initial msg"); + self.encrypted_tx.push_front(msg); + } + // TODO handle error + while let Some(enc_res) = self.encrypted_rx.pop_front() { + match enc_res { + Err(e) => { + error!("Recieved an error during setup encryption setup: {e:?}"); + break; + } + Ok(incoming_msg) => { + trace!(initiator = %self.is_initiator, "encrypted_rx dequeue recieved setup msg"); + if let Ok(msgs) = match self.handle_setup_message(&incoming_msg) { + Ok(x) => Ok(x), + Err(e) => { + error!("handle_setup_message error: {e:?}"); + Err(e) + } + } { + for msg in msgs.into_iter().rev() { + trace!(initiator = %self.is_initiator,"queue more setup msg"); + self.encrypted_tx.push_front(msg); + } + } + } + } + + if self.step.established() { + return; + } + } + } + #[instrument(skip_all, fields(initiator = %self.is_initiator))] + /// Fills `encrypted_rx` and drains `encrypted_tx`. + fn poll_outgoing_encrypted_messages(&mut self, cx: &mut Context<'_>) { + // send any pending outgoing messages + while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(&mut self.io), cx) { + if let Some(encrypted_out) = self.encrypted_tx.pop_front() { + trace!(initiator = %self.is_initiator, msg_len = encrypted_out.len(), step = %self.step, "TX message"); + if let Err(_e) = Sink::start_send(Pin::new(&mut self.io), encrypted_out) { + error!("Error polling encyrpted side io") + } + + self.flush = true; + } else { + break; + } + } + if self.flush { + match Sink::poll_flush(Pin::new(&mut self.io), cx) { + Poll::Ready(Ok(())) => { + self.flush = false; + trace!(initiator = %self.is_initiator, "all flushed"); + } + Poll::Ready(Err(_e)) => { + error!(initiator = %self.is_initiator, "Error sending encrypted msg") + } + Poll::Pending => { + // flush not complete try again later + self.flush = true; + } + } + } + } + + fn poll_incomming_encrypted_messages(&mut self, cx: &mut Context<'_>) -> Poll<()> { + // pull in any incomming encrypted messages + let mut got_some = false; + while let Poll::Ready(Some(encrypted_msg)) = Stream::poll_next(Pin::new(&mut self.io), cx) { + trace!(initiator = %self.is_initiator, step = %self.step, "RX message"); + self.encrypted_rx.push_back(encrypted_msg); + got_some = true; + } + if got_some { + Poll::Ready(()) + } else { + Poll::Pending + } + } + /// handle setup messages: if any are incorrect (cause an error) the state is reset + #[instrument(err, skip_all, fields(initiator = %self.is_initiator))] + fn handle_setup_message(&mut self, msg: &[u8]) -> Result>> { + // this would only happen after reset with a bad message. + let mut first_message = false; + if let Step::NotInitialized = self.step { + first_message = true; + assert!(!self.is_initiator); + warn!(initiator = %self.is_initiator, "Encrypted state was reset"); + let mut handshake = Handshake::new(self.is_initiator)?; + let _ = handshake.start_raw()?; + self.step = Step::Handshake(Box::new(handshake)); + } + match &self.step { + Step::NotInitialized => { + unreachable!("should not happen") + } + Step::Handshake(_) => { + let mut out = vec![]; + if let Step::Handshake(mut handshake) = + replace(&mut self.step, Step::NotInitialized) + { + trace!("RX handshake msg"); + if let Some(response) = match handshake.read_raw(msg) { + Ok(x) => x, + Err(e) => { + let maybe_init_message = + (!first_message && !self.is_initiator).then_some(msg.to_vec()); + + self.reset_encrypted(maybe_init_message); + return Err(e); + } + } { + trace!( + initiator = %self.is_initiator, + "read message and emitting response", + ); + out.push(response); + } + + if handshake.complete() { + debug!(initiator = %self.is_initiator, "Handshake completed"); + let handshake_result = match handshake.get_result() { + Ok(x) => x, + Err(e) => { + error!("into-result error {e:?}"); + return Err(e); + } + }; + // The cipher will be put to use to the writer only after the peer's answer has come + let (cipher, init_msg) = + match EncryptCipher::from_handshake_tx(handshake_result) { + Ok(x) => x, + Err(e) => { + error!("from_handshake_tx error {e:?}"); + return Err(e); + } + }; + out.push(init_msg); + self.step = Step::SecretStream((cipher, handshake_result.clone())); + debug!(initiator = %self.is_initiator, "Step changed to {}", self.step); + } else { + self.step = Step::Handshake(handshake); + } + } + Ok(out) + } + Step::SecretStream(_) => { + if let Step::SecretStream((enc_cipher, hs_result)) = + replace(&mut self.step, Step::NotInitialized) + { + let dec_cipher = + DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; + self.plain_rx.push_back(Event::from(hs_result.clone())); + self.step = Step::Established((enc_cipher, dec_cipher, hs_result)); + debug!(initiator = %self.is_initiator, "Step changed to {}", self.step); + } + Ok(vec![]) + } + Step::Established((..)) => todo!(), + } + } + #[instrument(skip_all)] + fn reset_encrypted(&mut self, maybe_init_message: Option>) { + error!("Encrypted RESET"); + self.step = Step::NotInitialized; + self.encrypted_tx.clear(); + self.encrypted_rx.clear(); + if let Some(msg) = maybe_init_message { + self.encrypted_rx.push_front(Ok(msg)); + } + self.flush = false; + } } impl< @@ -139,51 +360,21 @@ impl< #[instrument(skip_all, fields(initiator = %self.is_initiator))] fn poll_flush( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { // The flow here can be understood as reading from the encrypted side moving those messages // through to the plaintext side, then reading new plaintext messages and moving them to // the encrypted side. // We do this repeatedly until there's nothing else to do - let Encrypted { - io, - step, - is_initiator, - encrypted_tx, - encrypted_rx, - plain_tx, - plain_rx, - flush, - .. - } = self.get_mut(); - loop { - poll_message_throughput( - io, - cx, - step, - encrypted_tx, - encrypted_rx, - plain_rx, - plain_tx, - *is_initiator, - flush, - ); - poll_outgoing_encrypted_messages(io, cx, encrypted_tx, *is_initiator, flush, step); + self.poll_message_throughput(cx); + self.poll_outgoing_encrypted_messages(cx); // check if we've done all possible work - if did_as_much_as_possible( - io, - cx, - step, - encrypted_tx, - encrypted_rx, - plain_tx, - *is_initiator, - ) { - if !step.established() || !encrypted_tx.is_empty() || *flush { - trace!(not_established = !step.established(), tx_msgs_waiting = !encrypted_tx.is_empty(), flush = ?flush, "not done flushing"); + if self.did_as_much_as_possible(cx) { + if !self.step.established() || !self.encrypted_tx.is_empty() || self.flush { + trace!(not_established = !self.step.established(), tx_msgs_waiting = !self.encrypted_tx.is_empty(), flush = ?self.flush, "not done flushing"); cx.waker().wake_by_ref(); return Poll::Pending; } @@ -201,60 +392,15 @@ impl< } } -/// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. -#[instrument(skip_all, ret)] -fn did_as_much_as_possible< - IO: Stream>> + Sink> + Send + Unpin + 'static, ->( - io: &mut IO, - cx: &mut Context<'_>, - step: &mut Step, - encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, - plain_tx: &mut VecDeque>, - is_initiator: bool, -) -> bool { - // No incoming encrypted messages available. - poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator, step).is_pending() - // We're unable to send any anymore encrypted/setup messages either because we have none or the `Sink` is unavailable. - && (encrypted_tx.is_empty() || Sink::poll_ready(Pin::new(io), cx).is_pending()) - // No encrypted messages waiting to be decrypted. - && encrypted_rx.is_empty() - // No plaint text messages waiting to be enccrypted or we're still setting up - && (plain_tx.is_empty() || !step.established()) -} - impl>> + Sink> + Send + Unpin + 'static> Stream for Encrypted { type Item = Event; #[instrument(skip_all, fields(initiator = %self.is_initiator, ret, err))] - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let Encrypted { - io, - step, - is_initiator, - encrypted_tx, - encrypted_rx, - plain_tx, - plain_rx, - flush, - .. - } = self.get_mut(); - - if poll_message_throughput( - io, - cx, - step, - encrypted_tx, - encrypted_rx, - plain_rx, - plain_tx, - *is_initiator, - flush, - ) { - if let Some(msg) = plain_rx.pop_front() { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.poll_message_throughput(cx) { + if let Some(msg) = self.plain_rx.pop_front() { Poll::Ready(Some(msg)) } else { Poll::Pending @@ -266,166 +412,6 @@ impl>> + Sink> + Send + Unpin + 'static } } -/// Handle all message throughput. Sends, encrypts and decrypts messages -/// Returns `true` `step` is already [`Step::Established`]. -#[allow(clippy::too_many_arguments)] -#[instrument(skip_all, ret)] -fn poll_message_throughput< - IO: Stream>> + Sink> + Send + Unpin + 'static, ->( - io: &mut IO, - cx: &mut Context<'_>, - step: &mut Step, - encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, - plain_rx: &mut VecDeque, - plain_tx: &mut VecDeque>, - is_initiator: bool, - flush: &mut bool, -) -> bool { - poll_outgoing_encrypted_messages(io, cx, encrypted_tx, is_initiator, flush, step); - let _ = poll_incomming_encrypted_messages(io, cx, encrypted_rx, is_initiator, step); - if let Step::Established((encryptor, decryptor, ..)) = step { - // decrypt incomming msgs - poll_decrypt(decryptor, encrypted_rx, plain_rx, is_initiator); - // encrypt any pending plaintext outgoinng messages - poll_encrypt(encryptor, encrypted_tx, plain_tx, is_initiator, flush); - true - } else { - poll_setup( - step, - encrypted_tx, - encrypted_rx, - plain_rx, - is_initiator, - flush, - ); - false - } -} - -#[instrument(skip_all, fields(initiator = %is_initiator))] -fn poll_setup( - step: &mut Step, - encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, - plain_rx: &mut VecDeque, - is_initiator: bool, - flush: &mut bool, -) { - // if we get an error, it could be because the other side reset, and is sending a new - // initialization message. - // If this is the case, we should retry this message after the error. - // But to avoid repeatedly retrying the first message, we should only retry if it is *not* the first msg. - // Still setting up - if let Ok(Some(msg)) = maybe_init(step, is_initiator) { - // queue the init message to send first - trace!(initiator = %is_initiator,"queue initial msg"); - encrypted_tx.push_front(msg); - } - // TODO handle error - while let Some(enc_res) = encrypted_rx.pop_front() { - match enc_res { - Err(e) => { - error!("Recieved an error during setup encryption setup: {e:?}"); - break; - } - Ok(incoming_msg) => { - trace!(initiator = %is_initiator, "encrypted_rx dequeue recieved setup msg"); - if let Ok(msgs) = match handle_setup_message( - step, - &incoming_msg, - is_initiator, - encrypted_tx, - encrypted_rx, - plain_rx, - flush, - ) { - Ok(x) => Ok(x), - Err(e) => { - error!("handle_setup_message error: {e:?}"); - Err(e) - } - } { - for msg in msgs.into_iter().rev() { - trace!(initiator = %is_initiator,"queue more setup msg"); - encrypted_tx.push_front(msg); - } - } - } - } - - if step.established() { - return; - } - } -} - -#[instrument(skip_all, fields(initiator = %is_initiator))] -/// Fills `encrypted_rx` and drains `encrypted_tx`. -fn poll_outgoing_encrypted_messages< - IO: Stream>> + Sink> + Send + Unpin + 'static, ->( - io: &mut IO, - cx: &mut Context<'_>, - encrypted_tx: &mut VecDeque>, - is_initiator: bool, - flush: &mut bool, - step: &Step, -) { - // send any pending outgoing messages - while let Poll::Ready(Ok(())) = Sink::poll_ready(Pin::new(io), cx) { - if let Some(encrypted_out) = encrypted_tx.pop_front() { - trace!(initiator = %is_initiator, msg_len = encrypted_out.len(), step = %step, "TX message"); - if let Err(_e) = Sink::start_send(Pin::new(io), encrypted_out) { - error!("Error polling encyrpted side io") - } - - *flush = true; - } else { - break; - } - } - if *flush { - match Sink::poll_flush(Pin::new(io), cx) { - Poll::Ready(Ok(())) => { - *flush = false; - trace!(initiator = %is_initiator, "all flushed"); - } - Poll::Ready(Err(_e)) => { - error!(initiator = %is_initiator, "Error sending encrypted msg") - } - Poll::Pending => { - // flush not complete try again later - *flush = true; - } - } - } -} - -fn poll_incomming_encrypted_messages< - IO: Stream>> + Sink> + Send + Unpin + 'static, ->( - io: &mut IO, - cx: &mut Context<'_>, - encrypted_rx: &mut VecDeque>>, - is_initiator: bool, - step: &Step, -) -> Poll<()> { - // pull in any incomming encrypted messages - let mut got_some = false; - while let Poll::Ready(Some(encrypted_msg)) = Stream::poll_next(Pin::new(io), cx) { - trace!(initiator = %is_initiator, step = %step, "RX message"); - encrypted_rx.push_back(encrypted_msg); - got_some = true; - } - if got_some { - Poll::Ready(()) - } else { - Poll::Pending - } -} - #[instrument(skip_all)] fn poll_decrypt( decryptor: &mut DecryptCipher, @@ -487,117 +473,6 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { Ok(out) } -#[instrument(skip_all)] -fn reset_encrypted( - step: &mut Step, - maybe_init_message: Option>, - encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, - flush: &mut bool, -) { - error!("Encrypted RESET"); - *step = Step::NotInitialized; - encrypted_tx.clear(); - encrypted_rx.clear(); - if let Some(msg) = maybe_init_message { - encrypted_rx.push_front(Ok(msg)); - } - *flush = false; -} - -/// handle setup messages: if any are incorrect (cause an error) the state is reset -#[instrument(err, skip_all, fields(initiator = %is_initiator))] -fn handle_setup_message( - step: &mut Step, - msg: &[u8], - is_initiator: bool, - encrypted_tx: &mut VecDeque>, - encrypted_rx: &mut VecDeque>>, - plain_rx: &mut VecDeque, - flush: &mut bool, -) -> Result>> { - // this would only happen after reset with a bad message. - let mut first_message = false; - if let Step::NotInitialized = step { - first_message = true; - assert!(!is_initiator); - warn!(initiator = %is_initiator, "Encrypted state was reset"); - let mut handshake = Handshake::new(is_initiator)?; - let _ = handshake.start_raw()?; - *step = Step::Handshake(Box::new(handshake)); - } - match &step { - Step::NotInitialized => { - unreachable!("should not happen") - } - Step::Handshake(_) => { - let mut out = vec![]; - if let Step::Handshake(mut handshake) = replace(step, Step::NotInitialized) { - trace!("RX handshake msg"); - if let Some(response) = match handshake.read_raw(msg) { - Ok(x) => x, - Err(e) => { - let maybe_init_message = - (!first_message && !is_initiator).then_some(msg.to_vec()); - - reset_encrypted( - step, - maybe_init_message, - encrypted_tx, - encrypted_rx, - flush, - ); - return Err(e); - } - } { - trace!( - initiator = %is_initiator, - "read message and emitting response", - ); - out.push(response); - } - - if handshake.complete() { - debug!(initiator = %is_initiator, "Handshake completed"); - let handshake_result = match handshake.get_result() { - Ok(x) => x, - Err(e) => { - error!("into-result error {e:?}"); - return Err(e); - } - }; - // The cipher will be put to use to the writer only after the peer's answer has come - let (cipher, init_msg) = - match EncryptCipher::from_handshake_tx(handshake_result) { - Ok(x) => x, - Err(e) => { - error!("from_handshake_tx error {e:?}"); - return Err(e); - } - }; - out.push(init_msg); - *step = Step::SecretStream((cipher, handshake_result.clone())); - debug!(initiator = %is_initiator, "Step changed to {step}"); - } else { - *step = Step::Handshake(handshake); - } - } - Ok(out) - } - Step::SecretStream(_) => { - if let Step::SecretStream((enc_cipher, hs_result)) = replace(step, Step::NotInitialized) - { - let dec_cipher = DecryptCipher::from_handshake_rx_and_init_msg(&hs_result, msg)?; - plain_rx.push_back(Event::from(hs_result.clone())); - *step = Step::Established((enc_cipher, dec_cipher, hs_result)); - debug!(initiator = %is_initiator, "Step changed to {step}"); - } - Ok(vec![]) - } - Step::Established((..)) => todo!(), - } -} - impl std::fmt::Debug for Encrypted { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Encrypted") From 0b17c7aa6eb538c72ef753023c2e5ece77d00202 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 13:13:20 -0400 Subject: [PATCH 127/206] lint --- src/mqueue.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mqueue.rs b/src/mqueue.rs index 87f14d5..bb92824 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -90,9 +90,7 @@ impl + Sink> + Send + Unpin + 'static> Mes } match Sink::poll_flush(Pin::new(&mut self.io), cx) { - Poll::Ready(Err(_e)) => { - todo!() - } + Poll::Ready(Err(_e)) => todo!(), Poll::Pending => { cx.waker().wake_by_ref(); return Poll::Pending; From 47a690044af077946e3bd2eb7292cea4f5159b06 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 13:35:25 -0400 Subject: [PATCH 128/206] Remove log and env_log depndencies --- Cargo.toml | 2 -- benches/pipe.rs | 6 ++-- benches/throughput.rs | 8 +++-- examples/replication.rs | 65 ++++++++++++++--------------------------- src/test_utils.rs | 6 ++-- tests/js_interop.rs | 2 ++ 6 files changed, 37 insertions(+), 52 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a651734..173c907 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,7 +51,6 @@ path = "../core" async-std = { version = "1.12.0", features = ["attributes", "unstable"] } async-compat = "0.2.1" tokio = { version = "1.27.0", features = ["macros", "net", "process", "rt", "rt-multi-thread", "sync", "time"] } -env_logger = "0.7.1" anyhow = "1.0.28" instant = "0.1" criterion = { version = "0.4", features = ["async_std"] } @@ -59,7 +58,6 @@ pretty-bytes = "0.2.2" duplexify = "1.1.0" sluice = "0.5.4" futures = "0.3.13" -log = "0.4" tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } tracing-tree = "0.4.0" tokio-util = { version = "0.7.14", features = ["compat"] } diff --git a/benches/pipe.rs b/benches/pipe.rs index 6f2a4b8..9f87d84 100644 --- a/benches/pipe.rs +++ b/benches/pipe.rs @@ -1,3 +1,5 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use async_std::task; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use futures::{ @@ -5,17 +7,17 @@ use futures::{ stream::StreamExt, }; use hypercore_protocol::{schema::*, Channel, Duplex, Event, Message, Protocol, ProtocolBuilder}; -use log::*; use pretty_bytes::converter::convert as pretty_bytes; use sluice::pipe::pipe; use std::{io::Result, time::Instant}; +use tracing::{debug, error}; const COUNT: u64 = 1000; const SIZE: u64 = 100; const CONNS: u64 = 10; fn bench_throughput(c: &mut Criterion) { - env_logger::from_env(env_logger::Env::default().default_filter_or("error")).init(); + test_utils::log(); let mut group = c.benchmark_group("pipe"); group.sample_size(10); group.throughput(Throughput::Bytes(SIZE * COUNT * CONNS)); diff --git a/benches/throughput.rs b/benches/throughput.rs index 1d9c4c0..b19167e 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -1,3 +1,5 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use async_std::{ net::{Shutdown, TcpListener, TcpStream}, task, @@ -9,8 +11,8 @@ use futures::{ stream::{FuturesUnordered, StreamExt}, }; use hypercore_protocol::{schema::*, Channel, Event, Message, ProtocolBuilder}; -use log::*; use std::time::Instant; +use tracing::{debug, info, trace}; const PORT: usize = 11011; const SIZE: u64 = 1000; @@ -18,7 +20,7 @@ const COUNT: u64 = 200; const CLIENTS: usize = 1; fn bench_throughput(c: &mut Criterion) { - env_logger::from_env(env_logger::Env::default().default_filter_or("error")).init(); + test_utils::log(); let address = format!("localhost:{}", PORT); let mut group = c.benchmark_group("throughput"); @@ -67,7 +69,7 @@ criterion_main!(server_benches); async fn start_server(address: &str) -> futures::channel::oneshot::Sender<()> { let listener = TcpListener::bind(&address).await.unwrap(); - log::info!("listening on {}", listener.local_addr().unwrap()); + info!("listening on {}", listener.local_addr().unwrap()); let (kill_tx, mut kill_rx) = futures::channel::oneshot::channel(); task::spawn(async move { let mut incoming = listener.incoming(); diff --git a/examples/replication.rs b/examples/replication.rs index 85d0d11..35e2908 100644 --- a/examples/replication.rs +++ b/examples/replication.rs @@ -1,3 +1,5 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use anyhow::Result; use async_std::{ net::{TcpListener, TcpStream}, @@ -10,13 +12,13 @@ use hypercore::{ Hypercore, HypercoreBuilder, PartialKeypair, RequestBlock, RequestUpgrade, Storage, VerifyingKey, }; -use std::{collections::HashMap, convert::TryInto, env, fmt::Debug, sync::OnceLock}; -use tracing::{error, info}; +use std::{collections::HashMap, convert::TryInto, env, fmt::Debug}; +use tracing::{error, info, instrument}; use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; fn main() { - log(); + test_utils::log(); if env::args().count() < 3 { usage(); } @@ -62,12 +64,11 @@ fn main() { hypercore_store.add(hypercore_wrapper); let hypercore_store = Arc::new(hypercore_store); - let result = match mode.as_ref() { + let _ = match mode.as_ref() { "server" => tcp_server(address, onconnection, hypercore_store).await, "client" => tcp_client(address, onconnection, hypercore_store).await, _ => panic!("{:?}", usage()), }; - log_if_error(&result); }); } @@ -81,6 +82,7 @@ fn usage() { // or once when connected (if client). // Unfortunately, everything that touches the hypercore_store or a hypercore has to be generic // at the moment. +#[instrument(skip_all, ret)] async fn onconnection( stream: TcpStream, is_initiator: bool, @@ -123,17 +125,17 @@ struct HypercoreStore { hypercores: HashMap>, } impl HypercoreStore { - pub fn new() -> Self { + fn new() -> Self { let hypercores = HashMap::new(); Self { hypercores } } - pub fn add(&mut self, hypercore: HypercoreWrapper) { + fn add(&mut self, hypercore: HypercoreWrapper) { let hdkey = hex::encode(hypercore.discovery_key); self.hypercores.insert(hdkey, Arc::new(hypercore)); } - pub fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc> { + fn get(&self, discovery_key: &[u8; 32]) -> Option<&Arc> { let hdkey = hex::encode(discovery_key); self.hypercores.get(&hdkey) } @@ -148,7 +150,7 @@ struct HypercoreWrapper { } impl HypercoreWrapper { - pub fn from_memory_hypercore(hypercore: Hypercore) -> Self { + fn from_memory_hypercore(hypercore: Hypercore) -> Self { let key = hypercore.key_pair().public.to_bytes(); HypercoreWrapper { key, @@ -157,11 +159,11 @@ impl HypercoreWrapper { } } - pub fn key(&self) -> &[u8; 32] { + fn key(&self) -> &[u8; 32] { &self.key } - pub fn onpeer(&self, mut channel: Channel) { + fn onpeer(&self, mut channel: Channel) { let mut peer_state = PeerState::default(); let mut hypercore = self.hypercore.clone(); task::spawn(async move { @@ -415,32 +417,9 @@ async fn onmessage( Ok(()) } -#[allow(unused)] -pub fn log() { - use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; - static START_LOGS: OnceLock<()> = OnceLock::new(); - START_LOGS.get_or_init(|| { - tracing_subscriber::fmt() - .with_target(true) - .with_line_number(true) - // print when instrumented funtion enters - .with_span_events(FmtSpan::ENTER | FmtSpan::EXIT) - .with_file(true) - .with_env_filter(EnvFilter::from_default_env()) // Reads `RUST_LOG` environment variable - .without_time() - .init(); - }); -} - -/// Log a result if it's an error. -pub fn log_if_error(result: &Result<()>) { - if let Err(err) = result.as_ref() { - log::error!("error: {}", err); - } -} - /// A simple async TCP server that calls an async function for each incoming connection. -pub async fn tcp_server( +#[instrument(skip_all, ret)] +async fn tcp_server( address: String, onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, context: C, @@ -450,22 +429,22 @@ where C: Clone + Send + 'static, { let listener = TcpListener::bind(&address).await?; - log::info!("listening on {}", listener.local_addr()?); + tracing::info!("listening on {}", listener.local_addr()?); let mut incoming = listener.incoming(); while let Some(Ok(stream)) = incoming.next().await { let context = context.clone(); let peer_addr = stream.peer_addr().unwrap(); - log::info!("new connection from {}", peer_addr); + tracing::info!("new connection from {}", peer_addr); task::spawn(async move { - let result = onconnection(stream, false, context).await; - log_if_error(&result); - log::info!("connection closed from {}", peer_addr); + let _ = onconnection(stream, false, context).await; + tracing::info!("connection closed from {}", peer_addr); }); } Ok(()) } /// A simple async TCP client that calls an async function when connected. +#[instrument(skip_all, ret)] pub async fn tcp_client( address: String, onconnection: impl Fn(TcpStream, bool, C) -> F + Send + Sync + Copy + 'static, @@ -475,8 +454,8 @@ where F: Future> + Send, C: Clone + Send + 'static, { - log::info!("attempting connection to {address}"); + tracing::info!("attempting connection to {address}"); let stream = TcpStream::connect(&address).await?; - log::info!("connected to {address}"); + tracing::info!("connected to {address}"); onconnection(stream, true, context).await } diff --git a/src/test_utils.rs b/src/test_utils.rs index 8a4dd74..b1fc32d 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,13 +1,13 @@ +#![allow(dead_code)] use std::{ io::{self, ErrorKind}, pin::Pin, task::{Context, Poll}, }; -//use async_channel::{unbounded, Receiver, io::Error, Sender}; use futures::{ channel::mpsc::{unbounded, UnboundedReceiver as Receiver, UnboundedSender as Sender}, - Sink, SinkExt, Stream, StreamExt, + Sink, Stream, StreamExt, }; #[derive(Debug)] @@ -99,6 +99,7 @@ pub(crate) fn log() { #[tokio::test] async fn way_one() { + use futures::SinkExt; let mut a = Io::default(); let _ = a.send(b"hello".into()).await; let Some(res) = a.next().await else { panic!() }; @@ -107,6 +108,7 @@ async fn way_one() { #[tokio::test] async fn split() { + use futures::SinkExt; let (mut left, mut right) = (TwoWay::default()).split_sides(); left.send(b"hello".to_vec()).await.unwrap(); diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 619841b..e115288 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -1,3 +1,5 @@ +#[path = "../src/test_utils.rs"] +mod test_utils; use _util::wait_for_localhost_port; use anyhow::Result; use futures::Future; From 13eeee77eb4c4fd52cd437a4fbd81de764a28e66 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 13:43:49 -0400 Subject: [PATCH 129/206] clean up test_utils --- src/noise.rs | 4 +++- src/test_utils.rs | 44 +++++++++++++++++++++++--------------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 4115cbe..e18f697 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -318,7 +318,9 @@ where } Ok(vec![]) } - Step::Established((..)) => todo!(), + Step::Established(_) => { + unreachable!("`handle_setup_message` should never be called when Step::Established") + } } } #[instrument(skip_all)] diff --git a/src/test_utils.rs b/src/test_utils.rs index b1fc32d..d8be13a 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -97,27 +97,6 @@ pub(crate) fn log() { }); } -#[tokio::test] -async fn way_one() { - use futures::SinkExt; - let mut a = Io::default(); - let _ = a.send(b"hello".into()).await; - let Some(res) = a.next().await else { panic!() }; - assert_eq!(res, b"hello"); -} - -#[tokio::test] -async fn split() { - use futures::SinkExt; - let (mut left, mut right) = (TwoWay::default()).split_sides(); - - left.send(b"hello".to_vec()).await.unwrap(); - let Some(res) = right.next().await else { - panic!(); - }; - assert_eq!(res, b"hello"); -} - pub(crate) struct Moo { receiver: Rx, sender: Tx, @@ -199,3 +178,26 @@ pub(crate) fn create_result_connected() -> ( let b = Moo::from(result_channel()); a.connect(b) } + +#[cfg(test)] +mod test_test_utils { + use super::*; + use futures::SinkExt; + #[tokio::test] + async fn way_one() { + let mut a = Io::default(); + let _ = a.send(b"hello".into()).await; + let Some(res) = a.next().await else { panic!() }; + assert_eq!(res, b"hello"); + } + + #[tokio::test] + async fn split() { + let (mut left, mut right) = (TwoWay::default()).split_sides(); + left.send(b"hello".to_vec()).await.unwrap(); + let Some(res) = right.next().await else { + panic!(); + }; + assert_eq!(res, b"hello"); + } +} From a40c4ff78139613a01fde5c996701acda434034a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 13:55:13 -0400 Subject: [PATCH 130/206] rename js_interop_tests to js_tests It was too much to type --- .github/workflows/ci.yml | 6 +++--- Cargo.toml | 4 ++-- README.md | 4 ++-- tests/js_interop.rs | 16 ++++++++-------- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5b58d59..3842f94 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,9 +33,9 @@ jobs: cargo check --all-targets cargo check --all-targets --no-default-features --features tokio cargo check --all-targets --no-default-features --features async-std - cargo test --features js_interop_tests - cargo test --no-default-features --features js_interop_tests,tokio - cargo test --no-default-features --features js_interop_tests,async-std + cargo test --features js_tests + cargo test --no-default-features --features js_tests,tokio + cargo test --no-default-features --features js_tests,async-std cargo test --benches build-extra: diff --git a/Cargo.toml b/Cargo.toml index 173c907..df292c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,8 +73,8 @@ tokio = ["hypercore/tokio"] async-std = ["hypercore/async-std"] # Used only in interoperability tests under tests/js-interop which use the javascript version of hypercore # to verify that this crate works. To run them, use: -# cargo test --features js_interop_tests -js_interop_tests = [] +# cargo test --features js_tests +js_tests = [] [profile.bench] # debug = true diff --git a/README.md b/README.md index b8ed180..fada9df 100644 --- a/README.md +++ b/README.md @@ -72,10 +72,10 @@ node examples-nodejs/run.js node ## Development -To test interoperability with Javascript, enable the `js_interop_tests` feature: +To test interoperability with Javascript, enable the `js_tests` feature: ```bash -cargo test --features js_interop_tests +cargo test --features js_tests ``` Run benches with: diff --git a/tests/js_interop.rs b/tests/js_interop.rs index e115288..935382c 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -51,56 +51,56 @@ const TEST_SET_CLIENT_WRITER: &str = "cw"; const TEST_SET_SIMPLE: &str = "simple"; #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn ncns_server_writer() -> Result<()> { ncns(true, 8101).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn ncns_client_writer() -> Result<()> { ncns(false, 8102).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn rcns_server_writer() -> Result<()> { rcns(true, 8103).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn rcns_client_writer() -> Result<()> { rcns(false, 8104).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn ncrs_server_writer() -> Result<()> { ncrs(true, 8105).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn ncrs_client_writer() -> Result<()> { ncrs(false, 8106).await?; Ok(()) } #[tokio::test] -#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +#[cfg_attr(not(feature = "js_tests"), ignore)] async fn rcrs_server_writer() -> Result<()> { rcrs(true, 8107).await?; Ok(()) } #[tokio::test] -//#[cfg_attr(not(feature = "js_interop_tests"), ignore)] +//#[cfg_attr(not(feature = "js_tests"), ignore)] //#[ignore] // FIXME this tests hangs sporadically async fn rcrs_client_writer() -> Result<()> { rcrs(false, 8108).await?; From a3f3e6e2cb144014c9ad7d59d8d05bf930f7f9bf Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 14:12:03 -0400 Subject: [PATCH 131/206] refactor tsets --- src/test_utils.rs | 10 +++++----- tests/js_interop.rs | 9 +++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index d8be13a..c2bba98 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -180,12 +180,12 @@ pub(crate) fn create_result_connected() -> ( } #[cfg(test)] -mod test_test_utils { - use super::*; - use futures::SinkExt; +mod test { + #![allow(unused_imports)] // test's within tests confused clippy + use futures::{SinkExt, StreamExt}; #[tokio::test] async fn way_one() { - let mut a = Io::default(); + let mut a = super::Io::default(); let _ = a.send(b"hello".into()).await; let Some(res) = a.next().await else { panic!() }; assert_eq!(res, b"hello"); @@ -193,7 +193,7 @@ mod test_test_utils { #[tokio::test] async fn split() { - let (mut left, mut right) = (TwoWay::default()).split_sides(); + let (mut left, mut right) = (super::TwoWay::default()).split_sides(); left.send(b"hello".to_vec()).await.unwrap(); let Some(res) = right.next().await else { panic!(); diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 935382c..05174f6 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -1,7 +1,11 @@ +pub mod _util; #[path = "../src/test_utils.rs"] mod test_utils; + use _util::wait_for_localhost_port; use anyhow::Result; +#[cfg(feature = "tokio")] +use async_compat::CompatExt; use futures::Future; use futures_lite::stream::StreamExt; use hypercore::{ @@ -14,9 +18,6 @@ use std::{ path::Path, sync::{Arc, Once}, }; - -#[cfg(feature = "tokio")] -use async_compat::CompatExt; #[cfg(feature = "tokio")] use tokio::{ fs::{metadata, File}, @@ -29,7 +30,6 @@ use tokio::{ use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; -pub mod _util; mod js; use js::{cleanup, install, js_run_client, js_start_server, prepare_test_set}; @@ -40,6 +40,7 @@ fn init() { cleanup(); install(); }); + test_utils::log(); } const TEST_SET_NODE_CLIENT_NODE_SERVER: &str = "ncns"; From 5002b60c8bfe8c3a04445d6db0b435dd30ed3e9c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 20 May 2025 16:07:48 -0400 Subject: [PATCH 132/206] RMME reset --- src/mqueue.rs | 13 +++++++++++-- src/noise.rs | 5 ++--- src/test_utils.rs | 5 +++-- tests/js_interop.rs | 14 ++++++++++++++ 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/src/mqueue.rs b/src/mqueue.rs index bb92824..9a2d91a 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -10,7 +10,7 @@ use std::{ use compact_encoding::CompactEncoding as _; use futures::{Sink, Stream}; -use tracing::{error, instrument}; +use tracing::{error, info, instrument}; use crate::{message::ChannelMessage, noise::EncryptionInfo, NoiseEvent}; @@ -119,9 +119,18 @@ impl + Sink> + Send + Unpin + 'static> Str { type Item = MqueueEvent; + #[instrument(skip_all, ret)] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let _ = self.poll_outbound(cx); - self.poll_inbound(cx) + match self.poll_inbound(cx) { + Poll::Ready(Some(MqueueEvent::Message(Ok(x)))) => { + for m in x.iter() { + info!("RX ChannelMessage::{m}"); + } + Poll::Ready(Some(MqueueEvent::Message(Ok(x)))) + } + x => x, + } } } diff --git a/src/noise.rs b/src/noise.rs index e18f697..2c6e001 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -110,7 +110,7 @@ where } /// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. - #[instrument(skip_all, ret)] + #[instrument(name = "did_as_much_as_possible", skip_all, ret)] fn did_as_much_as_possible(&mut self, cx: &mut Context<'_>) -> bool { // No incoming encrypted messages available. self.poll_incomming_encrypted_messages(cx).is_pending() @@ -124,8 +124,7 @@ where /// Handle all message throughput. Sends, encrypts and decrypts messages /// Returns `true` `step` is already [`Step::Established`]. - #[allow(clippy::too_many_arguments)] - #[instrument(skip_all, ret)] + #[instrument(name = "poll_message_throughput", skip_all, ret)] fn poll_message_throughput(&mut self, cx: &mut Context<'_>) -> bool { self.poll_outgoing_encrypted_messages(cx); let _ = self.poll_incomming_encrypted_messages(cx); diff --git a/src/test_utils.rs b/src/test_utils.rs index c2bba98..2e5e994 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -86,9 +86,10 @@ pub(crate) fn log() { .with_targets(true) .with_bracketed_fields(true) .with_indent_lines(true) - .with_span_modes(true) .with_thread_ids(false) - .with_thread_names(false); + .with_thread_names(true) + //.with_span_modes(true) + ; tracing_subscriber::registry() .with(env_filter) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 05174f6..2b2ee38 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -27,6 +27,7 @@ use tokio::{ task, time::sleep, }; +use tracing::instrument; use hypercore_protocol::{discovery_key, schema::*, Channel, Event, Message, ProtocolBuilder}; @@ -186,6 +187,7 @@ async fn rcns(server_writer: bool, port: u32) -> Result<()> { &result_path, ) .await?; + dbg!(); assert_result(result_path, item_count, item_size, data_char).await?; drop(server); @@ -330,11 +332,15 @@ async fn run_client( data_path: &str, result_path: &str, ) -> Result<()> { + dbg!(); let hypercore = if is_writer { + dbg!(); create_writer_hypercore(data_count, data_size, data_char, data_path).await? } else { + dbg!(); create_reader_hypercore(data_path).await? }; + dbg!(); let hypercore_wrapper = HypercoreWrapper::from_disk_hypercore( hypercore, if is_writer { @@ -343,7 +349,9 @@ async fn run_client( Some(result_path.to_string()) }, ); + dbg!(); tcp_client(port, on_replication_connection, Arc::new(hypercore_wrapper)).await?; + dbg!(); Ok(()) } @@ -433,21 +441,26 @@ pub fn get_test_key_pair(include_secret: bool) -> PartialKeypair { } #[cfg(feature = "tokio")] +#[instrument(skip_all)] async fn on_replication_connection( stream: TcpStream, is_initiator: bool, hypercore: Arc, ) -> Result<()> { + use tracing::info; + let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream.compat()); while let Some(event) = protocol.next().await { let event = event?; match event { Event::Handshake(_) => { + info!("Event::Handshake"); if is_initiator { protocol.open(*hypercore.key()).await?; } } Event::DiscoveryKey(dkey) => { + info!("Event::DiscoveryKey"); if hypercore.discovery_key == dkey { protocol.open(*hypercore.key()).await?; } else { @@ -455,6 +468,7 @@ async fn on_replication_connection( } } Event::Channel(channel) => { + info!("Event::Channel"); hypercore.on_replication_peer(channel); } Event::Close(_dkey) => { From 083dfa9604dafc3044f1ddc0ba24d043add58b12 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 22 May 2025 19:41:16 -0400 Subject: [PATCH 133/206] docs & logging --- src/framing.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index 5bc8297..760b7c6 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -9,7 +9,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tracing::{debug, error, info, instrument, trace, warn}; +use tracing::{error, info, instrument, trace, warn}; const BUF_SIZE: usize = 1024 * 64; const _HEADER_LEN: usize = 3; @@ -21,7 +21,7 @@ pub struct Uint24LELengthPrefixedFraming { to_stream: Vec, /// Data from the `Sink` interface to be written out to [`Self::io`]'s [`AsyncWrite`] interface. from_sink: VecDeque>, - /// The index in [`Self::to_stream`] of the last byte that was to the [`Stream`]. + /// The index in [`Self::to_stream`] of the last byte that was sent to the [`Stream`]. last_out_idx: usize, /// The index in [`Self::to_stream`] of the last byte that was read from [`Self::io`]'s /// [`AsyncRead`] @@ -80,7 +80,7 @@ where step, .. } = self.get_mut(); - debug!( + trace!( "Try to AsyncRead up to (buff_size[{}] - last_data_idx[{}]) = [{}]", to_stream.len(), *last_data_idx, @@ -92,7 +92,7 @@ where Poll::Pending => 0, }; // TODO handle if to_stream is full - debug!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); + trace!("adding #=[{n_bytes_read}] bytes to end=[{}]", last_data_idx); *last_data_idx += n_bytes_read; // grow buffer if it's full if *last_data_idx == to_stream.len() - 1 { @@ -121,7 +121,7 @@ where if let Step::Body { start, end } = step { let end = *end as usize; if end <= *last_data_idx { - debug!(frame_size = end - *start, "Frame ready"); + trace!(frame_size = end - *start, "Frame ready"); let out = to_stream[*start..end].to_vec(); *step = Step::Header; @@ -173,7 +173,7 @@ where from_sink.push_front(msg[n..].to_vec()); warn!("only wrote [{n} / {}] bytes of message", msg.len()); } - debug!("flushed whole message of N=[{n}] bytes"); + trace!("flushed whole message of N=[{n}] bytes"); } Poll::Ready(Err(e)) => { error!("Error flushing data"); @@ -181,7 +181,7 @@ where } } } else { - debug!("No more messages to flush"); + trace!("No more messages to flush"); return Poll::Ready(Ok(())); } } From f1e0eb37dc61ed01c4e11a16e8fa9b81f063a3bf Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 22 May 2025 19:47:27 -0400 Subject: [PATCH 134/206] Allow unused for test utils --- tests/_util.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/_util.rs b/tests/_util.rs index b6f1d22..78c89e4 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -86,6 +86,7 @@ where }) } +#[allow(unused)] pub async fn wait_for_localhost_port(port: u32) { const RETRY_TIMEOUT: u64 = 100_u64; const NO_RESPONSE_TIMEOUT: u64 = 1000_u64; From 8ef2e2f8b716c74baf1f846a20be6558bc5120b1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 22 May 2025 20:05:12 -0400 Subject: [PATCH 135/206] Don't open 2 channels with same peer in tests --- tests/js_interop.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index 2b2ee38..d81a812 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -450,25 +450,28 @@ async fn on_replication_connection( use tracing::info; let mut protocol = ProtocolBuilder::new(is_initiator).connect(stream.compat()); + let mut channel_opened = false; while let Some(event) = protocol.next().await { let event = event?; match event { Event::Handshake(_) => { info!("Event::Handshake"); - if is_initiator { + if is_initiator && !channel_opened { protocol.open(*hypercore.key()).await?; + channel_opened = true; } } Event::DiscoveryKey(dkey) => { info!("Event::DiscoveryKey"); - if hypercore.discovery_key == dkey { + if hypercore.discovery_key == dkey && !channel_opened { protocol.open(*hypercore.key()).await?; + channel_opened = true; } else { panic!("Invalid discovery key"); } } Event::Channel(channel) => { - info!("Event::Channel"); + info!("Event::Channel is_initiator = {is_initiator}"); hypercore.on_replication_peer(channel); } Event::Close(_dkey) => { From 43c79c4dc3ffaa98a8923546ac3b033059843efa Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 6 Jun 2025 13:08:42 -0400 Subject: [PATCH 136/206] RMME use local deps --- Cargo.toml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index df292c7..e4eeddc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,13 @@ tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt"] } tracing-tree = "0.4.0" tokio-util = { version = "0.7.14", features = ["compat"] } +[dev-dependencies.async-udx] +path = "../async-udx/" + +[dev-dependencies.rusty_nodejs_repl] +path = "../js-repl-rs/" +features = ["serde"] + [features] default = ["tokio", "sparse"] wasm-bindgen = [ From 72d3299cd9affd64cef5d855500b19b645729e60 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 6 Jun 2025 13:09:35 -0400 Subject: [PATCH 137/206] Add Encrypted.establish_encryption() method Poll until encryption established --- src/noise.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/noise.rs b/src/noise.rs index 2c6e001..82ae0c8 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -109,6 +109,18 @@ where self.step.established() } + /// Wait for the encrypted connection to be established + pub async fn establish_encryption(&mut self) -> Result<()> { + if self.encryption_established() { + return Ok(()); + } + match self.next().await { + Some(Event::Meta(EncryptionInfo::Handshake(_))) => Ok(()), + None => todo!("Return some error about stream closing"), + Some(Event::Decrypted(_)) => panic!("We should garuntee this would never happen after checking self.encryption_established() == false"), + } + } + /// Check that we've done as much work as possible. Sending, receiving, encrypting and decrypting. #[instrument(name = "did_as_much_as_possible", skip_all, ret)] fn did_as_much_as_possible(&mut self, cx: &mut Context<'_>) -> bool { From 20fea71da44d1286d1d3dbabdd742dd50893ead5 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 6 Jun 2025 13:10:13 -0400 Subject: [PATCH 138/206] Use encryption established method --- src/noise.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/noise.rs b/src/noise.rs index 82ae0c8..46701a1 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -1,4 +1,4 @@ -use futures::{AsyncRead, AsyncWrite, Sink, Stream}; +use futures::{AsyncRead, AsyncWrite, Sink, Stream, StreamExt as _}; use std::{ collections::VecDeque, fmt::Debug, @@ -131,7 +131,7 @@ where // No encrypted messages waiting to be decrypted. && self.encrypted_rx.is_empty() // No plaint text messages waiting to be enccrypted or we're still setting up - && (self.plain_tx.is_empty() || !self.step.established()) + && (self.plain_tx.is_empty() || !self.encryption_established()) } /// Handle all message throughput. Sends, encrypts and decrypts messages @@ -198,11 +198,12 @@ where } } - if self.step.established() { + if self.encryption_established() { return; } } } + #[instrument(skip_all, fields(initiator = %self.is_initiator))] /// Fills `encrypted_rx` and drains `encrypted_tx`. fn poll_outgoing_encrypted_messages(&mut self, cx: &mut Context<'_>) { @@ -386,8 +387,8 @@ impl< // check if we've done all possible work if self.did_as_much_as_possible(cx) { - if !self.step.established() || !self.encrypted_tx.is_empty() || self.flush { - trace!(not_established = !self.step.established(), tx_msgs_waiting = !self.encrypted_tx.is_empty(), flush = ?self.flush, "not done flushing"); + if !self.encryption_established() || !self.encrypted_tx.is_empty() || self.flush { + trace!(not_established = !self.encryption_established(), tx_msgs_waiting = !self.encrypted_tx.is_empty(), flush = ?self.flush, "not done flushing"); cx.waker().wake_by_ref(); return Poll::Pending; } From e9fb5cec372b7f2328f7099cc8e45b0b57b51de4 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 6 Jun 2025 13:10:33 -0400 Subject: [PATCH 139/206] rm dbg!() --- tests/js_interop.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/js_interop.rs b/tests/js_interop.rs index d81a812..fe19db3 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -187,7 +187,6 @@ async fn rcns(server_writer: bool, port: u32) -> Result<()> { &result_path, ) .await?; - dbg!(); assert_result(result_path, item_count, item_size, data_char).await?; drop(server); @@ -332,15 +331,11 @@ async fn run_client( data_path: &str, result_path: &str, ) -> Result<()> { - dbg!(); let hypercore = if is_writer { - dbg!(); create_writer_hypercore(data_count, data_size, data_char, data_path).await? } else { - dbg!(); create_reader_hypercore(data_path).await? }; - dbg!(); let hypercore_wrapper = HypercoreWrapper::from_disk_hypercore( hypercore, if is_writer { @@ -349,9 +344,7 @@ async fn run_client( Some(result_path.to_string()) }, ); - dbg!(); tcp_client(port, on_replication_connection, Arc::new(hypercore_wrapper)).await?; - dbg!(); Ok(()) } From 2a3e50972560c297c31307cd494241aec6b1e42d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 10 Jun 2025 14:18:57 -0400 Subject: [PATCH 140/206] const & vis cleanup. Confugarable Handshake --- src/crypto/handshake.rs | 70 +++++++++++++++++++++++++++++++++-------- src/crypto/mod.rs | 3 +- src/lib.rs | 2 ++ 3 files changed, 61 insertions(+), 14 deletions(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 53f3889..0d07c2d 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -4,6 +4,7 @@ use blake2::{ Blake2bMac, }; use snow::{ + params::HandshakePattern, resolvers::{DefaultResolver, FallbackResolver}, Builder, Error as SnowError, HandshakeState, }; @@ -11,7 +12,19 @@ use std::io::{Error, ErrorKind, Result}; use tracing::instrument; const CIPHERKEYLEN: usize = 32; -const HANDSHAKE_PATTERN: &str = "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b"; + +/// The [`HandshakePattern`]s we support for connections +pub mod handshake_patterns { + use snow::params::HandshakePattern; + /// [`HandshakePattern`] used by the hyperdht crate. + pub const DHT: HandshakePattern = HandshakePattern::IK; + /// Noise protocol name used in hyperdht crate. + pub const DHT_NAME: &str = "Noise_IK_Ed25519_ChaChaPoly_BLAKE2b"; + /// [`HandshakePattern`] used by the hypercore-protocol crate. + pub const PROTOCOL: HandshakePattern = HandshakePattern::XX; + /// Noise protocol name used in hypercore-protocol crate. + pub const PROTOCOL_NAME: &str = "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b"; +} // These the output of, see `hash_namespace` test below for how they are produced // https://github.com/hypercore-protocol/hypercore/blob/70b271643c4e4b1e5ecae5bb579966dfe6361ff3/lib/caps.js#L9 @@ -73,8 +86,9 @@ impl HandshakeResult { } } +/// Noise handshake for establishing secure connections #[derive(Debug)] -pub(crate) struct Handshake { +pub struct Handshake { result: HandshakeResult, state: HandshakeState, payload: Vec, @@ -119,7 +133,7 @@ impl Handshake { self.complete } - pub(crate) fn is_initiator(&self) -> bool { + fn is_initiator(&self) -> bool { self.result.is_initiator } @@ -128,7 +142,7 @@ impl Handshake { .read_message(msg, &mut self.rx_buf) .map_err(map_err) } - pub(crate) fn send(&mut self) -> Result { + fn send(&mut self) -> Result { self.state .write_message(&self.payload, &mut self.tx_buf) .map_err(map_err) @@ -188,24 +202,37 @@ impl Handshake { fn build_handshake_state( is_initiator: bool, +) -> std::result::Result<(HandshakeState, Vec), SnowError> { + build_handshake_state_with_config(is_initiator, &HandshakeConfig::default()) +} + +fn build_handshake_state_with_config( + is_initiator: bool, + config: &HandshakeConfig, ) -> std::result::Result<(HandshakeState, Vec), SnowError> { use snow::params::{ - BaseChoice, CipherChoice, DHChoice, HandshakeChoice, HandshakeModifierList, - HandshakePattern, HashChoice, NoiseParams, + BaseChoice, CipherChoice, DHChoice, HandshakeChoice, HandshakeModifierList, HashChoice, + NoiseParams, + }; + + let pattern_str = match config.pattern { + HandshakePattern::XX => "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b", + HandshakePattern::IK => "Noise_IK_Ed25519_ChaChaPoly_BLAKE2b", + _ => return Err(SnowError::Input), }; - // NB: HANDSHAKE_PATTERN.parse() doesn't work because the pattern has "Ed25519" - // instead of "25519". + let noise_params = NoiseParams::new( - HANDSHAKE_PATTERN.to_string(), + pattern_str.to_string(), BaseChoice::Noise, HandshakeChoice { - pattern: HandshakePattern::XX, + pattern: config.pattern, modifiers: HandshakeModifierList { list: vec![] }, }, DHChoice::Curve25519, CipherChoice::ChaChaPoly, HashChoice::Blake2b, ); + let builder: Builder<'_> = Builder::with_resolver( noise_params, Box::new(FallbackResolver::new( @@ -213,15 +240,32 @@ fn build_handshake_state( Box::::default(), )), ); + let key_pair = builder.generate_keypair().unwrap(); - let builder = builder.local_private_key(&key_pair.private); + let mut builder = builder.local_private_key(&key_pair.private); + + // Set prologue if provided + if let Some(ref prologue) = config.prologue { + builder = builder.prologue(prologue); + } + + // Set remote public key for IK pattern initiator + if is_initiator && config.pattern == handshake_patterns::PROTOCOL { + if let Some(ref remote_key) = config.remote_public_key { + builder = builder.remote_public_key(remote_key); + } else { + return Err(SnowError::Input); + } + } + let handshake_state = if is_initiator { - tracing::debug!("building initiator"); + tracing::debug!("building initiator with pattern {:?}", config.pattern); builder.build_initiator()? } else { - tracing::debug!("building responder"); + tracing::debug!("building responder with pattern {:?}", config.pattern); builder.build_responder()? }; + Ok((handshake_state, key_pair.public)) } diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 66bb62d..48af488 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -2,4 +2,5 @@ mod cipher; mod curve; mod handshake; pub(crate) use cipher::{DecryptCipher, EncryptCipher}; -pub(crate) use handshake::{Handshake, HandshakeResult}; +pub(crate) use handshake::HandshakeResult; +pub use handshake::{handshake_patterns, Handshake, HandshakeConfig}; diff --git a/src/lib.rs b/src/lib.rs index 7857bd1..97b3629 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -144,3 +144,5 @@ pub use hypercore; // Re-export hypercore pub use message::Message; pub use protocol::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; pub use util::discovery_key; +// Export DHT-related crypto functionality +pub use crypto::{handshake_patterns, Handshake, HandshakeConfig}; From c6cd7c873cb39a18f91fd7abcca1a2652730dfb6 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 10 Jun 2025 14:26:02 -0400 Subject: [PATCH 141/206] Add HandshakeConfig. Clean up hs pattern/name stuff --- src/crypto/handshake.rs | 54 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 0d07c2d..1245434 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -3,6 +3,7 @@ use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, }; +use handshake_patterns::name_from_pattern; use snow::{ params::HandshakePattern, resolvers::{DefaultResolver, FallbackResolver}, @@ -16,6 +17,7 @@ const CIPHERKEYLEN: usize = 32; /// The [`HandshakePattern`]s we support for connections pub mod handshake_patterns { use snow::params::HandshakePattern; + use tracing::error; /// [`HandshakePattern`] used by the hyperdht crate. pub const DHT: HandshakePattern = HandshakePattern::IK; /// Noise protocol name used in hyperdht crate. @@ -24,6 +26,18 @@ pub mod handshake_patterns { pub const PROTOCOL: HandshakePattern = HandshakePattern::XX; /// Noise protocol name used in hypercore-protocol crate. pub const PROTOCOL_NAME: &str = "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b"; + + /// Get a Noise protocol name from a handshake pattern + pub fn name_from_pattern(pattern: &HandshakePattern) -> Result<&'static str, snow::Error> { + Ok(match *pattern { + DHT => DHT_NAME, + PROTOCOL => PROTOCOL_NAME, + unsupported_pattern => { + error!("Got an unsupported handshake pattern {unsupported_pattern:?}"); + return Err(snow::Error::Input); + } + }) + } } // These the output of, see `hash_namespace` test below for how they are produced @@ -200,6 +214,38 @@ impl Handshake { } } +/// Configuration for creating a handshake with specific parameters +#[derive(Debug, Clone)] +pub struct HandshakeConfig { + /// The noise handshake pattern to use (XX, IK, etc.) + pub pattern: HandshakePattern, + /// Optional prologue data to include in the handshake + pub prologue: Option>, + /// Remote public key (required for IK pattern initiator) + // TODO replace with an actual key type + pub remote_public_key: Option<[u8; 32]>, +} + +impl HandshakeConfig { + fn new( + pattern: HandshakePattern, + prologue: Option>, + remote_public_key: Option<[u8; 32]>, + ) -> Self { + Self { + pattern, + prologue, + remote_public_key, + } + } +} + +impl Default for HandshakeConfig { + fn default() -> Self { + Self::new(HandshakePattern::XX, None, None) + } +} + fn build_handshake_state( is_initiator: bool, ) -> std::result::Result<(HandshakeState, Vec), SnowError> { @@ -215,14 +261,10 @@ fn build_handshake_state_with_config( NoiseParams, }; - let pattern_str = match config.pattern { - HandshakePattern::XX => "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b", - HandshakePattern::IK => "Noise_IK_Ed25519_ChaChaPoly_BLAKE2b", - _ => return Err(SnowError::Input), - }; + let hs_name = name_from_pattern(&config.pattern)?; let noise_params = NoiseParams::new( - pattern_str.to_string(), + hs_name.to_string(), BaseChoice::Noise, HandshakeChoice { pattern: config.pattern, From debcc3b8a65735024c9edaf3d5d4a08ae1311ac6 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 10 Jun 2025 14:32:43 -0400 Subject: [PATCH 142/206] const renames --- src/crypto/handshake.rs | 14 +++++++------- src/crypto/mod.rs | 2 +- src/lib.rs | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 1245434..cac1704 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -3,7 +3,7 @@ use blake2::{ digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, }; -use handshake_patterns::name_from_pattern; +use handshake_constants::name_from_pattern; use snow::{ params::HandshakePattern, resolvers::{DefaultResolver, FallbackResolver}, @@ -15,23 +15,23 @@ use tracing::instrument; const CIPHERKEYLEN: usize = 32; /// The [`HandshakePattern`]s we support for connections -pub mod handshake_patterns { +pub mod handshake_constants { use snow::params::HandshakePattern; use tracing::error; /// [`HandshakePattern`] used by the hyperdht crate. - pub const DHT: HandshakePattern = HandshakePattern::IK; + pub const DHT_PATTERN: HandshakePattern = HandshakePattern::IK; /// Noise protocol name used in hyperdht crate. pub const DHT_NAME: &str = "Noise_IK_Ed25519_ChaChaPoly_BLAKE2b"; /// [`HandshakePattern`] used by the hypercore-protocol crate. - pub const PROTOCOL: HandshakePattern = HandshakePattern::XX; + pub const PROTOCOL_PATTERN: HandshakePattern = HandshakePattern::XX; /// Noise protocol name used in hypercore-protocol crate. pub const PROTOCOL_NAME: &str = "Noise_XX_Ed25519_ChaChaPoly_BLAKE2b"; /// Get a Noise protocol name from a handshake pattern pub fn name_from_pattern(pattern: &HandshakePattern) -> Result<&'static str, snow::Error> { Ok(match *pattern { - DHT => DHT_NAME, - PROTOCOL => PROTOCOL_NAME, + DHT_PATTERN => DHT_NAME, + PROTOCOL_PATTERN => PROTOCOL_NAME, unsupported_pattern => { error!("Got an unsupported handshake pattern {unsupported_pattern:?}"); return Err(snow::Error::Input); @@ -292,7 +292,7 @@ fn build_handshake_state_with_config( } // Set remote public key for IK pattern initiator - if is_initiator && config.pattern == handshake_patterns::PROTOCOL { + if is_initiator && config.pattern == handshake_constants::PROTOCOL_PATTERN { if let Some(ref remote_key) = config.remote_public_key { builder = builder.remote_public_key(remote_key); } else { diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 48af488..15900f5 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -3,4 +3,4 @@ mod curve; mod handshake; pub(crate) use cipher::{DecryptCipher, EncryptCipher}; pub(crate) use handshake::HandshakeResult; -pub use handshake::{handshake_patterns, Handshake, HandshakeConfig}; +pub use handshake::{handshake_constants, Handshake, HandshakeConfig}; diff --git a/src/lib.rs b/src/lib.rs index 97b3629..a767931 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -145,4 +145,4 @@ pub use message::Message; pub use protocol::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; pub use util::discovery_key; // Export DHT-related crypto functionality -pub use crypto::{handshake_patterns, Handshake, HandshakeConfig}; +pub use crypto::{handshake_constants, Handshake, HandshakeConfig}; From b1b1f5bd71e0ad84e4eb5effe67560fc2e41b178 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 6 Jul 2025 16:33:44 -0400 Subject: [PATCH 143/206] tests need udx-native --- tests/js/package.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/js/package.json b/tests/js/package.json index c3a57ff..56fe846 100644 --- a/tests/js/package.json +++ b/tests/js/package.json @@ -2,6 +2,7 @@ "name": "hypercore-protocol-rs-js-interop-tests", "version": "0.0.1", "dependencies": { - "hypercore": "10.31.12" + "hypercore": "10.31.12", + "udx-native": "^1.0.0" } } From d8cb5c0f6b67dbfc9f7c072c747b1142430e916f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 26 Jul 2025 00:02:54 -0400 Subject: [PATCH 144/206] make stuff pub to use in hdht --- src/crypto/cipher.rs | 12 ++++++--- src/crypto/handshake.rs | 57 +++++++++++++++++++++++++++-------------- src/crypto/mod.rs | 2 +- src/lib.rs | 2 +- src/noise.rs | 15 ++++------- 5 files changed, 53 insertions(+), 35 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 20cb734..98fde28 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -10,7 +10,8 @@ use std::{convert::TryInto, io}; const STREAM_ID_LENGTH: usize = 32; const KEY_LENGTH: usize = 32; -pub(crate) struct DecryptCipher { +/// Convert ciphertext to plaintext +pub struct DecryptCipher { pull_stream: PullStream, } @@ -21,7 +22,8 @@ impl std::fmt::Debug for DecryptCipher { } impl DecryptCipher { - pub(crate) fn from_handshake_rx_and_init_msg( + /// Create from `HandshakeResult` and init message + pub fn from_handshake_rx_and_init_msg( handshake_result: &HandshakeResult, init_msg: &[u8], ) -> io::Result { @@ -101,7 +103,8 @@ fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { //NB "raw" here means UN-framed. No frame header. const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; -pub(crate) struct EncryptCipher { +/// Convert plaintext to ciphertext +pub struct EncryptCipher { push_stream: PushStream, } @@ -112,7 +115,8 @@ impl std::fmt::Debug for EncryptCipher { } impl EncryptCipher { - pub(crate) fn from_handshake_tx( + /// Create from HandshakeResult + pub fn from_handshake_tx( handshake_result: &HandshakeResult, ) -> std::io::Result<(Self, Vec)> { let key: [u8; KEY_LENGTH] = handshake_result.split_tx[..KEY_LENGTH] diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index cac1704..363448a 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -52,6 +52,8 @@ const REPLICATE_RESPONDER: [u8; 32] = [ ]; #[derive(Debug, Clone, Default)] +/// The thing created by [`Handshake`] which is used to encrypt and decrypt. NB while it is created +/// at the beginning of the handshake, it is not "ready" until the handshake completes. pub struct HandshakeResult { pub(crate) is_initiator: bool, pub(crate) local_pubkey: Vec, @@ -100,22 +102,27 @@ impl HandshakeResult { } } -/// Noise handshake for establishing secure connections +/// Object for holding the data needed Noise handshake protocol. The resulting object can be used +/// for encryption. #[derive(Debug)] pub struct Handshake { result: HandshakeResult, - state: HandshakeState, + /// Internal state of the handshake + pub state: HandshakeState, payload: Vec, tx_buf: Vec, rx_buf: Vec, complete: bool, did_receive: bool, + #[allow(dead_code)] + pattern: HandshakePattern, } impl Handshake { #[instrument] - pub(crate) fn new(is_initiator: bool) -> Result { - let (state, local_pubkey) = build_handshake_state(is_initiator).map_err(map_err)?; + /// Build a [`Handshake`] + pub fn new(is_initiator: bool, config: &HandshakeConfig) -> Result { + let (state, local_pubkey) = new_handshake_state(is_initiator, config).map_err(map_err)?; let payload = vec![]; let result = HandshakeResult { @@ -131,10 +138,17 @@ impl Handshake { rx_buf: vec![0u8; 512], complete: false, did_receive: false, + pattern: config.pattern, }) } - pub(crate) fn start_raw(&mut self) -> Result>> { + /// Set the payload for the next handshake message + pub fn set_payload(&mut self, payload: Vec) { + self.payload = payload; + } + + /// Start the handshake and return the initial message (for initiators) + pub fn start_raw(&mut self) -> Result>> { if self.is_initiator() { let tx_len = self.send()?; Ok(Some(self.tx_buf[..tx_len].to_vec())) @@ -143,7 +157,8 @@ impl Handshake { } } - pub(crate) fn complete(&self) -> bool { + /// Check if the handshake is completed + pub fn complete(&self) -> bool { self.complete } @@ -151,6 +166,7 @@ impl Handshake { self.result.is_initiator } + #[instrument(skip_all, err)] fn recv(&mut self, msg: &[u8]) -> Result { self.state .read_message(msg, &mut self.rx_buf) @@ -162,14 +178,20 @@ impl Handshake { .map_err(map_err) } - #[instrument(skip_all, fields(is_initiator = %self.result.is_initiator))] - pub(crate) fn read_raw(&mut self, msg: &[u8]) -> Result>> { + /// Read in a handshake message + #[instrument(skip_all, fields(is_initiator = %self.result.is_initiator), err)] + pub fn read_raw(&mut self, msg: &[u8]) -> Result>> { // eprintln!("hs read len {}", msg.len()); if self.complete() { return Err(Error::new(ErrorKind::Other, "Handshake read after finish")); } - let _rx_len = self.recv(msg)?; + let rx_len = self.recv(msg)?; + if self.state.is_handshake_finished() && self.is_initiator() { + let recieved = self.rx_buf[..rx_len].to_vec(); + self.complete = true; + return Ok(Some(recieved)); + } // first non-init if !self.is_initiator() && !self.did_receive { @@ -179,7 +201,8 @@ impl Handshake { return Ok(Some(wrapped)); } - let tx_buf = if self.is_initiator() { + /* when not IK pattern we need to send another message */ + let tx_buf = if self.is_initiator() && !self.state.is_handshake_finished() { let tx_len = self.send()?; let wrapped = self.tx_buf[..tx_len].to_vec(); Some(wrapped) @@ -205,7 +228,8 @@ impl Handshake { Ok(tx_buf) } - pub(crate) fn get_result(&self) -> Result<&HandshakeResult> { + /// get the handshake result when it is completed + pub fn get_result(&self) -> Result<&HandshakeResult> { if !self.complete() { Err(Error::new(ErrorKind::Other, "Handshake is not complete")) } else { @@ -246,13 +270,8 @@ impl Default for HandshakeConfig { } } -fn build_handshake_state( - is_initiator: bool, -) -> std::result::Result<(HandshakeState, Vec), SnowError> { - build_handshake_state_with_config(is_initiator, &HandshakeConfig::default()) -} - -fn build_handshake_state_with_config( +// TODO should this be infallible +fn new_handshake_state( is_initiator: bool, config: &HandshakeConfig, ) -> std::result::Result<(HandshakeState, Vec), SnowError> { @@ -292,7 +311,7 @@ fn build_handshake_state_with_config( } // Set remote public key for IK pattern initiator - if is_initiator && config.pattern == handshake_constants::PROTOCOL_PATTERN { + if (is_initiator) && (config.pattern) == handshake_constants::DHT_PATTERN { if let Some(ref remote_key) = config.remote_public_key { builder = builder.remote_public_key(remote_key); } else { diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 15900f5..f7c9c65 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,6 +1,6 @@ mod cipher; mod curve; mod handshake; -pub(crate) use cipher::{DecryptCipher, EncryptCipher}; +pub use cipher::{DecryptCipher, EncryptCipher}; pub(crate) use handshake::HandshakeResult; pub use handshake::{handshake_constants, Handshake, HandshakeConfig}; diff --git a/src/lib.rs b/src/lib.rs index a767931..5dbd28d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -145,4 +145,4 @@ pub use message::Message; pub use protocol::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; pub use util::discovery_key; // Export DHT-related crypto functionality -pub use crypto::{handshake_constants, Handshake, HandshakeConfig}; +pub use crypto::{handshake_constants, DecryptCipher, EncryptCipher, Handshake, HandshakeConfig}; diff --git a/src/noise.rs b/src/noise.rs index 46701a1..a98108c 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -74,7 +74,8 @@ impl From for Event { } } -/// Wrap a stream with encryption +/// Wrap a Stream/Sink interface with encryption. +/// It sets up the handshake and encryption automatically. pub struct Encrypted { io: IO, step: Step, @@ -260,7 +261,7 @@ where first_message = true; assert!(!self.is_initiator); warn!(initiator = %self.is_initiator, "Encrypted state was reset"); - let mut handshake = Handshake::new(self.is_initiator)?; + let mut handshake = Handshake::new(self.is_initiator, &Default::default())?; let _ = handshake.start_raw()?; self.step = Step::Handshake(Box::new(handshake)); } @@ -302,13 +303,7 @@ where }; // The cipher will be put to use to the writer only after the peer's answer has come let (cipher, init_msg) = - match EncryptCipher::from_handshake_tx(handshake_result) { - Ok(x) => x, - Err(e) => { - error!("from_handshake_tx error {e:?}"); - return Err(e); - } - }; + EncryptCipher::from_handshake_tx(handshake_result)?; out.push(init_msg); self.step = Step::SecretStream((cipher, handshake_result.clone())); debug!(initiator = %self.is_initiator, "Step changed to {}", self.step); @@ -481,7 +476,7 @@ fn maybe_init(step: &mut Step, is_initiator: bool) -> Result>> { return Ok(None); } trace!(initiator = %is_initiator, "Init, state {step:?}"); - let mut handshake = Handshake::new(is_initiator)?; + let mut handshake = Handshake::new(is_initiator, &Default::default())?; let out = handshake.start_raw()?; *step = Step::Handshake(Box::new(handshake)); Ok(out) From ea7a1ab78b0c1bab14c6bec8acdf244be8ae83e9 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 26 Jul 2025 16:31:02 -0400 Subject: [PATCH 145/206] working on ik stream --- src/sstream.rs | 132 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 src/sstream.rs diff --git a/src/sstream.rs b/src/sstream.rs new file mode 100644 index 0000000..6984edc --- /dev/null +++ b/src/sstream.rs @@ -0,0 +1,132 @@ +use crate::Error; +use snow::{params::NoiseParams, Builder, HandshakeState}; +const PARAMS: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2b"; + +struct InitiatorConfig { + remote_public_key: [u8; 32], +} +impl InitiatorConfig { + fn new(remote_public_key: [u8; 32]) -> Self { + Self { remote_public_key } + } +} + +struct IkSecretStream { + is_initiator: bool, + state: HandshakeState, + msg_buf: [u8; 1024], +} + +macro_rules! is_initiator { + ($self:expr) => { + if $self.is_initiator { + panic!(); + } + }; +} +macro_rules! is_not_initiator { + ($self:expr) => { + if !$self.is_initiator { + panic!(); + } + }; +} + +impl IkSecretStream { + fn new_initiator(config: InitiatorConfig) -> Result { + let params: NoiseParams = PARAMS.parse().expect("known to work"); + let kp = Builder::new(params.clone()).generate_keypair()?; + let state = Builder::new(params.clone()) + .local_private_key(&kp.private)? + .remote_public_key(&config.remote_public_key)? + .build_initiator()?; + + Ok(Self { + is_initiator: false, + state, + msg_buf: [0; 1024], + }) + } + + fn make_first_msg(&mut self, prologue: &[u8]) -> Result, Error> { + is_initiator!(&self); + let len = self.state.write_message(prologue, &mut self.msg_buf)?; + Ok(self.msg_buf[..len].to_vec()) + } + + fn read_first_msg(&mut self, msg: &[u8]) -> Result, Error> { + is_not_initiator!(&self); + let len = self.state.read_message(msg, &mut self.msg_buf)?; + Ok(self.msg_buf[..len].to_vec()) + } + + fn new_responder() -> Result { + let params: NoiseParams = PARAMS.parse().expect("known to work"); + let kp = Builder::new(params.clone()).generate_keypair()?; + let state = Builder::new(params.clone()) + .local_private_key(&kp.private)? + .build_responder()?; + Ok(Self { + is_initiator: false, + state, + msg_buf: [0; 1024], + }) + } +} +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn sstream() -> std::result::Result<(), Box> { + let params: NoiseParams = PARAMS.parse().expect("known to work"); + let kp = Builder::new(params.clone()).generate_keypair()?; + let config = InitiatorConfig::new(kp.public.try_into().unwrap()); + let mut init = IkSecretStream::new_initiator(config)?; + let mut resp = IkSecretStream::new_responder()?; + let msg = init.make_first_msg(&[])?; + let msg = resp.read_first_msg(&msg)?; + Ok(()) + } + + #[tokio::test] + async fn hs() -> std::result::Result<(), Box> { + let params: NoiseParams = "Noise_IK_25519_ChaChaPoly_BLAKE2b".parse()?; + + let initiator_kp = Builder::new(params.clone()).generate_keypair()?; + let responder_kp = Builder::new(params.clone()).generate_keypair()?; + + let mut initiator = Builder::new(params.clone()) + .local_private_key(&initiator_kp.private)? + .remote_public_key(&responder_kp.public)? + .build_initiator()?; + + let mut responder = Builder::new(params.clone()) + .local_private_key(&responder_kp.private)? + .build_responder()?; + let (mut read_buf, mut first_msg, mut second_msg, mut enc_buf) = + ([0u8; 1024], [0u8; 1024], [0u8; 1024], [0u8; 1024]); + + // -> e, es, s, ss + let first_len = initiator.write_message(&[], &mut first_msg)?; + // responder processes the first message... + let read_len = responder.read_message(&first_msg[..first_len], &mut read_buf)?; + dbg!(&first_len); + dbg!(&read_buf[..read_len]); + + // <- e, ee, se + let second_len = responder.write_message(&[], &mut second_msg)?; + let _read_len = initiator.read_message(&second_msg[..second_len], &mut read_buf)?; + + let mut resp_transport = responder.into_transport_mode()?; + let mut init_transport = initiator.into_transport_mode()?; + + let msg = b"my message"; + let elen = resp_transport.write_message(msg, &mut enc_buf)?; + println!("{:?}", &enc_buf[..elen]); + let rlen = init_transport.read_message(&enc_buf[..elen], &mut read_buf)?; + println!("{}", String::from_utf8_lossy(&read_buf[..rlen])); + + Ok(()) + } +} From ab840946dffbeda077db4c47553fdddc384e340f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 26 Jul 2025 16:48:20 -0400 Subject: [PATCH 146/206] bump to snow 0.10.0 --- Cargo.toml | 2 +- src/crypto/curve.rs | 3 ++- src/crypto/handshake.rs | 6 +++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e4eeddc..351ba0d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ bench = false [dependencies] async-channel = "1" -snow = { version = "0.9", features = ["risky-raw-split"] } +snow = { version = "0.10", features = ["risky-raw-split"] } bytes = "1" rand = "0.8" blake2 = "0.10" diff --git a/src/crypto/curve.rs b/src/crypto/curve.rs index 48ed841..e71d582 100644 --- a/src/crypto/curve.rs +++ b/src/crypto/curve.rs @@ -37,7 +37,7 @@ impl Dh for Ed25519 { self.pubkey[..public_key_bytes.len()].copy_from_slice(public_key_bytes); } - fn generate(&mut self, _: &mut dyn Random) { + fn generate(&mut self, _: &mut dyn Random) -> Result<(), snow::Error> { // NB: Given Random can't be used with ed25519_dalek's SigningKey::generate(), // use OS's random here from hypercore. let signing_key = generate_signing_key(); @@ -46,6 +46,7 @@ impl Dh for Ed25519 { let verifying_key = signing_key.verifying_key(); let public_key_bytes = verifying_key.as_bytes(); self.pubkey[..public_key_bytes.len()].copy_from_slice(public_key_bytes); + Ok(()) } fn pubkey(&self) -> &[u8] { diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 363448a..9c130c4 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -303,17 +303,17 @@ fn new_handshake_state( ); let key_pair = builder.generate_keypair().unwrap(); - let mut builder = builder.local_private_key(&key_pair.private); + let mut builder = builder.local_private_key(&key_pair.private)?; // Set prologue if provided if let Some(ref prologue) = config.prologue { - builder = builder.prologue(prologue); + builder = builder.prologue(prologue)?; } // Set remote public key for IK pattern initiator if (is_initiator) && (config.pattern) == handshake_constants::DHT_PATTERN { if let Some(ref remote_key) = config.remote_public_key { - builder = builder.remote_public_key(remote_key); + builder = builder.remote_public_key(remote_key)?; } else { return Err(SnowError::Input); } From 9cbf0fc223944da00090bf8c2879e898cca8d6bc Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 26 Jul 2025 16:49:01 -0400 Subject: [PATCH 147/206] Add error.rs --- Cargo.toml | 1 + src/error.rs | 5 +++++ src/lib.rs | 1 + 3 files changed, 7 insertions(+) create mode 100644 src/error.rs diff --git a/Cargo.toml b/Cargo.toml index 351ba0d..5695165 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ curve25519-dalek = "4" crypto_secretstream = "0.2" futures = "0.3.31" compact-encoding = "2" +thiserror = "2.0.12" [dependencies.hypercore] path = "../core" diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..6c210ed --- /dev/null +++ b/src/error.rs @@ -0,0 +1,5 @@ +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Error from `snow`: {0}")] + Snow(#[from] snow::Error), +} diff --git a/src/lib.rs b/src/lib.rs index 5dbd28d..4b23b25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,6 +119,7 @@ mod channels; mod constants; mod crypto; mod duplex; +mod error; mod framing; mod message; mod mqueue; From 05d577c22c132e8d50b2cc3da8952709636c2cef Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 26 Jul 2025 16:49:58 -0400 Subject: [PATCH 148/206] wip sstream --- src/lib.rs | 2 + src/sstream.rs | 133 ++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 107 insertions(+), 28 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4b23b25..5c0dcf6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,12 +125,14 @@ mod message; mod mqueue; mod noise; mod protocol; +mod sstream; #[cfg(test)] mod test_utils; mod util; /// The wire messages used by the protocol. pub mod schema; +use error::Error; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; diff --git a/src/sstream.rs b/src/sstream.rs index 6984edc..cff3b82 100644 --- a/src/sstream.rs +++ b/src/sstream.rs @@ -1,3 +1,5 @@ +use std::marker::PhantomData; + use crate::Error; use snow::{params::NoiseParams, Builder, HandshakeState}; const PARAMS: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2b"; @@ -11,28 +13,21 @@ impl InitiatorConfig { } } -struct IkSecretStream { +struct IkSecretStream { is_initiator: bool, state: HandshakeState, msg_buf: [u8; 1024], + _step: PhantomData, } -macro_rules! is_initiator { - ($self:expr) => { - if $self.is_initiator { - panic!(); - } - }; -} -macro_rules! is_not_initiator { - ($self:expr) => { - if !$self.is_initiator { - panic!(); - } - }; -} +struct InitiatorInitial; +struct ResponderInitial; +struct InitiatorInitialSent; +struct ResponderReplied; +struct InitiatorReady; +struct ResponderReady; -impl IkSecretStream { +impl IkSecretStream { fn new_initiator(config: InitiatorConfig) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; @@ -45,21 +40,35 @@ impl IkSecretStream { is_initiator: false, state, msg_buf: [0; 1024], + _step: PhantomData, }) } - fn make_first_msg(&mut self, prologue: &[u8]) -> Result, Error> { - is_initiator!(&self); + fn make_first_msg( + mut self, + prologue: &[u8], + ) -> Result<(IkSecretStream, Vec), Error> { let len = self.state.write_message(prologue, &mut self.msg_buf)?; - Ok(self.msg_buf[..len].to_vec()) - } - - fn read_first_msg(&mut self, msg: &[u8]) -> Result, Error> { - is_not_initiator!(&self); - let len = self.state.read_message(msg, &mut self.msg_buf)?; - Ok(self.msg_buf[..len].to_vec()) + let msg = self.msg_buf[..len].to_vec(); + let Self { + is_initiator, + state, + msg_buf, + .. + } = self; + Ok(( + IkSecretStream { + is_initiator, + state, + msg_buf, + _step: PhantomData, + }, + msg, + )) } +} +impl IkSecretStream { fn new_responder() -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; @@ -70,22 +79,90 @@ impl IkSecretStream { is_initiator: false, state, msg_buf: [0; 1024], + _step: PhantomData, }) } + + fn read_first_msg( + mut self, + msg: &[u8], + ) -> Result<(IkSecretStream, Vec), Error> { + let len = dbg!(self.state.read_message(msg, &mut self.msg_buf))?; + let msg = self.msg_buf[..len].to_vec(); + let Self { + is_initiator, + state, + msg_buf, + .. + } = self; + Ok(( + IkSecretStream { + is_initiator, + state, + msg_buf, + _step: PhantomData, + }, + msg, + )) + } +} + +impl IkSecretStream { + fn receive_msg( + mut self, + msg: &[u8], + ) -> Result<(IkSecretStream, Vec), Error> { + // create encrypt cipher here + // send response + todo!() + } +} + +impl IkSecretStream { + fn receive_msg( + mut self, + msg: &[u8], + ) -> Result<(IkSecretStream, Vec), Error> { + // create DecryptCipher + todo!() + } +} + +/* +impl IkSecretStream { + + fn make_first_msg(&mut self, prologue: &[u8]) -> Result, Error> { + is_initiator!(&self); + let len = self.state.write_message(prologue, &mut self.msg_buf)?; + Ok(self.msg_buf[..len].to_vec()) + } + + fn read_first_msg(&mut self, msg: &[u8]) -> Result, Error> { + is_not_initiator!(&self); + let len = self.state.read_message(msg, &mut self.msg_buf)?; + Ok(self.msg_buf[..len].to_vec()) + } } +*/ + #[cfg(test)] mod test { use super::*; #[tokio::test] - async fn sstream() -> std::result::Result<(), Box> { + async fn foosstream() -> std::result::Result<(), Box> { let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; let config = InitiatorConfig::new(kp.public.try_into().unwrap()); + dbg!(); let mut init = IkSecretStream::new_initiator(config)?; + dbg!(); let mut resp = IkSecretStream::new_responder()?; - let msg = init.make_first_msg(&[])?; - let msg = resp.read_first_msg(&msg)?; + dbg!(); + let (init, msg) = init.make_first_msg(&[])?; + dbg!(); + let (resp, msg) = resp.read_first_msg(&msg)?; + dbg!(); Ok(()) } From 8368fd8aeabd40088bda81bd7cd4dfe1b68de6fe Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 28 Jul 2025 00:44:14 -0400 Subject: [PATCH 149/206] work push & pull stream! --- src/sstream.rs | 251 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 191 insertions(+), 60 deletions(-) diff --git a/src/sstream.rs b/src/sstream.rs index cff3b82..1934096 100644 --- a/src/sstream.rs +++ b/src/sstream.rs @@ -1,8 +1,14 @@ -use std::marker::PhantomData; - -use crate::Error; +#![allow(dead_code)] +use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; +use rand::rngs::OsRng; use snow::{params::NoiseParams, Builder, HandshakeState}; + +use crate::{crypto::write_stream_id, Error}; + const PARAMS: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2b"; +const STREAM_ID_LENGTH: usize = 32; +const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; +const SNOW_CIPHERKEYLEN: usize = 32; struct InitiatorConfig { remote_public_key: [u8; 32], @@ -17,15 +23,39 @@ struct IkSecretStream { is_initiator: bool, state: HandshakeState, msg_buf: [u8; 1024], - _step: PhantomData, + _step: Step, +} + +impl IkSecretStream { + // split handshake into (tx, rx) + fn split_handshake(&mut self) -> ([u8; SNOW_CIPHERKEYLEN], [u8; SNOW_CIPHERKEYLEN]) { + let (a, b) = self.state.dangerously_get_raw_split(); + if self.is_initiator { + (a, b) + } else { + (b, a) + } + } } struct InitiatorInitial; struct ResponderInitial; struct InitiatorInitialSent; -struct ResponderReplied; -struct InitiatorReady; -struct ResponderReady; +struct ResponderReplied { + rx: Key, + pusher: PushStream, + handshake_hash: Vec, +} + +struct InitiatorEnc { + pusher: PushStream, + rx: Key, + handshake_hash: Vec, +} +struct Ready { + puller: PullStream, + pusher: PushStream, +} impl IkSecretStream { fn new_initiator(config: InitiatorConfig) -> Result { @@ -37,10 +67,10 @@ impl IkSecretStream { .build_initiator()?; Ok(Self { - is_initiator: false, + is_initiator: true, state, msg_buf: [0; 1024], - _step: PhantomData, + _step: InitiatorInitial, }) } @@ -61,7 +91,7 @@ impl IkSecretStream { is_initiator, state, msg_buf, - _step: PhantomData, + _step: InitiatorInitialSent, }, msg, )) @@ -69,26 +99,43 @@ impl IkSecretStream { } impl IkSecretStream { - fn new_responder() -> Result { + fn new_responder(private: &[u8]) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); - let kp = Builder::new(params.clone()).generate_keypair()?; let state = Builder::new(params.clone()) - .local_private_key(&kp.private)? + .local_private_key(&private)? .build_responder()?; Ok(Self { is_initiator: false, state, msg_buf: [0; 1024], - _step: PhantomData, + _step: ResponderInitial, }) } fn read_first_msg( mut self, msg: &[u8], - ) -> Result<(IkSecretStream, Vec), Error> { - let len = dbg!(self.state.read_message(msg, &mut self.msg_buf))?; - let msg = self.msg_buf[..len].to_vec(); + ) -> Result<(IkSecretStream, [Vec; 2]), Error> { + self.state.read_message(msg, &mut self.msg_buf)?; + let len = self.state.write_message(&[], &mut self.msg_buf)?; + let hs_msg = self.msg_buf[..len].to_vec(); + assert!(self.state.is_handshake_finished()); + + let handshake_hash = self.state.get_handshake_hash().to_vec(); + let mut pull_stream_msg: [u8; RAW_HEADER_MSG_LEN] = [0; RAW_HEADER_MSG_LEN]; + // write stream id to front of pull_stream_msg + write_stream_id( + &handshake_hash, + self.is_initiator, + &mut pull_stream_msg[..STREAM_ID_LENGTH], + ); + + let (tx, rx) = self.split_handshake(); + let (header, pusher) = PushStream::init(OsRng, &Key::from(tx)); + + // write push header to back of pull_stream_msg + pull_stream_msg[STREAM_ID_LENGTH..].copy_from_slice(header.as_ref()); + let Self { is_initiator, state, @@ -100,50 +147,130 @@ impl IkSecretStream { is_initiator, state, msg_buf, - _step: PhantomData, + _step: ResponderReplied { + rx: Key::from(rx), + pusher, + handshake_hash, + }, }, - msg, + [hs_msg, pull_stream_msg.to_vec()], )) } } impl IkSecretStream { - fn receive_msg( - mut self, - msg: &[u8], - ) -> Result<(IkSecretStream, Vec), Error> { - // create encrypt cipher here - // send response - todo!() + fn receive_msg(mut self, msg: &[u8]) -> Result<(IkSecretStream, Vec), Error> { + self.state.read_message(msg, &mut self.msg_buf)?; + + let (tx, rx) = self.split_handshake(); + let key: [u8; SNOW_CIPHERKEYLEN] = tx[..SNOW_CIPHERKEYLEN] + .try_into() + .expect("split_tx with incorrect length"); + let key = Key::from(key); + let handshake_hash = self.state.get_handshake_hash().to_vec(); + let (header, pusher) = PushStream::init(OsRng, &key); + + let mut msg: [u8; RAW_HEADER_MSG_LEN] = [0; RAW_HEADER_MSG_LEN]; + // write stream id to front of msg + write_stream_id( + &handshake_hash, + self.is_initiator, + &mut msg[..STREAM_ID_LENGTH], + ); + // write push header to back of msg + msg[STREAM_ID_LENGTH..].copy_from_slice(header.as_ref()); + + let Self { + is_initiator, + state, + msg_buf, + .. + } = self; + Ok(( + IkSecretStream { + is_initiator, + state, + msg_buf, + _step: InitiatorEnc { + pusher, + rx: Key::from(rx), + handshake_hash, + }, + }, + msg.to_vec(), + )) } } +impl IkSecretStream { + fn receive_msg(self, msg: &[u8]) -> Result, Error> { + let Self { + is_initiator, + _step: + InitiatorEnc { + pusher, + rx, + handshake_hash, + }, + state, + msg_buf, + } = self; + // Read the received message from the other peer + let mut expected_stream_id: [u8; STREAM_ID_LENGTH] = [0; STREAM_ID_LENGTH]; + write_stream_id(&handshake_hash, !is_initiator, &mut expected_stream_id); + if expected_stream_id != msg[..32] { + panic!() + } -impl IkSecretStream { - fn receive_msg( - mut self, - msg: &[u8], - ) -> Result<(IkSecretStream, Vec), Error> { - // create DecryptCipher - todo!() + let header: [u8; 24] = msg[32..].try_into().expect("TODO wrong size"); + let puller = PullStream::init(header.into(), &rx); + Ok(IkSecretStream { + is_initiator, + state, + msg_buf, + _step: Ready { pusher, puller }, + }) } } -/* -impl IkSecretStream { +impl IkSecretStream { + fn receive_msg(self, msg: &[u8]) -> Result, Error> { + let Self { + is_initiator, + _step: + ResponderReplied { + pusher, + rx, + handshake_hash, + }, + state, + msg_buf, + } = self; + // Read the received message from the other peer + let mut expected_stream_id: [u8; STREAM_ID_LENGTH] = [0; STREAM_ID_LENGTH]; + write_stream_id(&handshake_hash, !is_initiator, &mut expected_stream_id); + if expected_stream_id != msg[..32] { + panic!() + } - fn make_first_msg(&mut self, prologue: &[u8]) -> Result, Error> { - is_initiator!(&self); - let len = self.state.write_message(prologue, &mut self.msg_buf)?; - Ok(self.msg_buf[..len].to_vec()) + let header: [u8; 24] = msg[32..].try_into().expect("TODO wrong size"); + let puller = PullStream::init(header.into(), &rx); + Ok(IkSecretStream { + is_initiator, + state, + msg_buf, + _step: Ready { puller, pusher }, + }) } +} - fn read_first_msg(&mut self, msg: &[u8]) -> Result, Error> { - is_not_initiator!(&self); - let len = self.state.read_message(msg, &mut self.msg_buf)?; - Ok(self.msg_buf[..len].to_vec()) +impl IkSecretStream { + fn push(&mut self, msg: &mut Vec, associated_data: &[u8], tag: Tag) -> Result<(), Error> { + Ok(self._step.pusher.push(msg, associated_data, tag)?) + } + fn pull(&mut self, msg: &mut Vec, associated_data: &[u8]) -> Result { + Ok(self._step.puller.pull(msg, associated_data)?) } } -*/ #[cfg(test)] mod test { @@ -154,21 +281,29 @@ mod test { let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; let config = InitiatorConfig::new(kp.public.try_into().unwrap()); - dbg!(); - let mut init = IkSecretStream::new_initiator(config)?; - dbg!(); - let mut resp = IkSecretStream::new_responder()?; - dbg!(); + let init = IkSecretStream::new_initiator(config)?; + let resp = IkSecretStream::new_responder(&kp.private)?; let (init, msg) = init.make_first_msg(&[])?; - dbg!(); - let (resp, msg) = resp.read_first_msg(&msg)?; - dbg!(); + let (resp, [msg1, msg2]) = resp.read_first_msg(&msg)?; + let (init, to_resp) = init.receive_msg(&msg1)?; + let mut resp = resp.receive_msg(&to_resp)?; + let mut init = init.receive_msg(&msg2)?; + + let hello = b"hello".to_vec(); + let mut msg = hello.clone(); + + println!("msg {msg:?}"); + init.push(&mut msg, &[], Tag::Message)?; + println!("enc {msg:?}"); + let tag = resp.pull(&mut msg, &[])?; + println!("res {msg:?} tag: {tag:?}"); + assert_eq!(msg, hello); Ok(()) } #[tokio::test] async fn hs() -> std::result::Result<(), Box> { - let params: NoiseParams = "Noise_IK_25519_ChaChaPoly_BLAKE2b".parse()?; + let params: NoiseParams = PARAMS.parse()?; let initiator_kp = Builder::new(params.clone()).generate_keypair()?; let responder_kp = Builder::new(params.clone()).generate_keypair()?; @@ -187,22 +322,18 @@ mod test { // -> e, es, s, ss let first_len = initiator.write_message(&[], &mut first_msg)?; // responder processes the first message... - let read_len = responder.read_message(&first_msg[..first_len], &mut read_buf)?; - dbg!(&first_len); - dbg!(&read_buf[..read_len]); + responder.read_message(&first_msg[..first_len], &mut read_buf)?; // <- e, ee, se let second_len = responder.write_message(&[], &mut second_msg)?; + let mut resp_transport = responder.into_transport_mode()?; let _read_len = initiator.read_message(&second_msg[..second_len], &mut read_buf)?; - let mut resp_transport = responder.into_transport_mode()?; let mut init_transport = initiator.into_transport_mode()?; let msg = b"my message"; let elen = resp_transport.write_message(msg, &mut enc_buf)?; - println!("{:?}", &enc_buf[..elen]); - let rlen = init_transport.read_message(&enc_buf[..elen], &mut read_buf)?; - println!("{}", String::from_utf8_lossy(&read_buf[..rlen])); + init_transport.read_message(&enc_buf[..elen], &mut read_buf)?; Ok(()) } From 9c679a97bb31727291b784e7157280b8fc2394fe Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 29 Jul 2025 22:07:57 -0400 Subject: [PATCH 150/206] export write_stream_id --- src/crypto/cipher.rs | 3 ++- src/crypto/mod.rs | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 98fde28..dee19f6 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -87,7 +87,8 @@ const NS_RESPONDER: [u8; 32] = [ 0xcf, 0x29, 0x54, 0x42, 0xbb, 0xca, 0x18, 0x99, 0x6e, 0x59, 0x70, 0x97, 0x72, 0x3b, 0x10, 0x61, ]; -fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { +/// write hash of handsdake_hash in a domain sep constant to out (32 bytes) +pub(crate) fn write_stream_id(handshake_hash: &[u8], is_initiator: bool, out: &mut [u8]) { let mut hasher = Blake2bMac::::new_with_salt_and_personal(handshake_hash, &[], &[]).unwrap(); if is_initiator { diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index f7c9c65..af87688 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,6 +1,7 @@ mod cipher; mod curve; mod handshake; +pub(crate) use cipher::write_stream_id; pub use cipher::{DecryptCipher, EncryptCipher}; pub(crate) use handshake::HandshakeResult; pub use handshake::{handshake_constants, Handshake, HandshakeConfig}; From 39c7daecbaf96c554ce9e1da74ab44ee052f6de1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 29 Jul 2025 22:12:31 -0400 Subject: [PATCH 151/206] clean up sstream --- src/error.rs | 8 +++ src/lib.rs | 2 +- src/noise.rs | 2 +- src/sstream.rs | 143 ++++++++++++++++++++++++++-------------------- src/test_utils.rs | 2 +- 5 files changed, 91 insertions(+), 66 deletions(-) diff --git a/src/error.rs b/src/error.rs index 6c210ed..81764e3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,4 +2,12 @@ pub enum Error { #[error("Error from `snow`: {0}")] Snow(#[from] snow::Error), + #[error("Error from `crypto_secretstream`: {0}")] + SecretStream(crypto_secretstream::aead::Error), +} + +impl From for Error { + fn from(value: crypto_secretstream::aead::Error) -> Self { + Error::SecretStream(value) + } } diff --git a/src/lib.rs b/src/lib.rs index 5c0dcf6..08518f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,7 +125,7 @@ mod message; mod mqueue; mod noise; mod protocol; -mod sstream; +pub mod sstream; #[cfg(test)] mod test_utils; mod util; diff --git a/src/noise.rs b/src/noise.rs index a98108c..ec56ca6 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -514,7 +514,7 @@ mod test { panic!() } #[tokio::test] - async fn encrypted() -> Result<()> { + async fn encrypted_one() -> Result<()> { let hello = b"hello".to_vec(); let world = b"world".to_vec(); let (lc, rc) = create_result_connected(); diff --git a/src/sstream.rs b/src/sstream.rs index 1934096..3ca028b 100644 --- a/src/sstream.rs +++ b/src/sstream.rs @@ -1,7 +1,8 @@ -#![allow(dead_code)] +//! Create a sodium secret stream using the IK pattern use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; use rand::rngs::OsRng; use snow::{params::NoiseParams, Builder, HandshakeState}; +use std::fmt::Debug; use crate::{crypto::write_stream_id, Error}; @@ -9,17 +10,24 @@ const PARAMS: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2b"; const STREAM_ID_LENGTH: usize = 32; const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; const SNOW_CIPHERKEYLEN: usize = 32; +const PUBLIC_KEYLEN: usize = 32; -struct InitiatorConfig { - remote_public_key: [u8; 32], +#[derive(Debug)] +/// Data for creating an initiator +pub struct InitiatorConfig { + remote_public_key: [u8; PUBLIC_KEYLEN], } + impl InitiatorConfig { - fn new(remote_public_key: [u8; 32]) -> Self { + /// Create a new [`InitiatorConfig`] + pub fn new(remote_public_key: [u8; PUBLIC_KEYLEN]) -> Self { Self { remote_public_key } } } -struct IkSecretStream { +#[derive(Debug)] +/// Secret Stream protocol state +pub struct IkSecretStream { is_initiator: bool, state: HandshakeState, msg_buf: [u8; 1024], @@ -27,8 +35,8 @@ struct IkSecretStream { } impl IkSecretStream { - // split handshake into (tx, rx) - fn split_handshake(&mut self) -> ([u8; SNOW_CIPHERKEYLEN], [u8; SNOW_CIPHERKEYLEN]) { + /// split handshake into (tx, rx) + pub fn split_handshake(&mut self) -> ([u8; SNOW_CIPHERKEYLEN], [u8; SNOW_CIPHERKEYLEN]) { let (a, b) = self.state.dangerously_get_raw_split(); if self.is_initiator { (a, b) @@ -38,27 +46,51 @@ impl IkSecretStream { } } -struct InitiatorInitial; -struct ResponderInitial; -struct InitiatorInitialSent; -struct ResponderReplied { +/// Initial initiator state +#[derive(Debug)] +pub struct InitiatorInitial; + +/// Initial responder state +#[derive(Debug)] +pub struct ResponderInitial; + +/// Initiator has sent the first message +#[derive(Debug)] +pub struct InitiatorInitialSent; + +/// No decryptor yet +pub struct EncryptorReady { rx: Key, pusher: PushStream, handshake_hash: Vec, } -struct InitiatorEnc { - pusher: PushStream, - rx: Key, - handshake_hash: Vec, -} -struct Ready { +/// Encryptor and decryptor +pub struct Ready { puller: PullStream, pusher: PushStream, } +impl Debug for EncryptorReady { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InitiatorEnc") + .field("rx", &"Key(..)") + .field("pusher", &"PushStream(..)") + .field("handshake_hash", &self.handshake_hash) + .finish() + } +} +impl Debug for Ready { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Ready") + .field("pusher", &"PushStream(..)") + .field("puller", &"PullStream(..)") + .finish() + } +} impl IkSecretStream { - fn new_initiator(config: InitiatorConfig) -> Result { + /// Create an initiator of a secret stream + pub fn new_initiator(config: InitiatorConfig) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; let state = Builder::new(params.clone()) @@ -74,7 +106,8 @@ impl IkSecretStream { }) } - fn make_first_msg( + /// Create the first message the initiator sends to the responder + pub fn make_first_msg( mut self, prologue: &[u8], ) -> Result<(IkSecretStream, Vec), Error> { @@ -99,7 +132,8 @@ impl IkSecretStream { } impl IkSecretStream { - fn new_responder(private: &[u8]) -> Result { + /// Create a responder of a secret stream + pub fn new_responder(private: &[u8]) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); let state = Builder::new(params.clone()) .local_private_key(&private)? @@ -112,10 +146,11 @@ impl IkSecretStream { }) } - fn read_first_msg( + /// Read the first message of the protocol, create the next two messages to send to the initiator. + pub fn read_first_msg( mut self, msg: &[u8], - ) -> Result<(IkSecretStream, [Vec; 2]), Error> { + ) -> Result<(IkSecretStream, [Vec; 2]), Error> { self.state.read_message(msg, &mut self.msg_buf)?; let len = self.state.write_message(&[], &mut self.msg_buf)?; let hs_msg = self.msg_buf[..len].to_vec(); @@ -147,7 +182,7 @@ impl IkSecretStream { is_initiator, state, msg_buf, - _step: ResponderReplied { + _step: EncryptorReady { rx: Key::from(rx), pusher, handshake_hash, @@ -159,7 +194,11 @@ impl IkSecretStream { } impl IkSecretStream { - fn receive_msg(mut self, msg: &[u8]) -> Result<(IkSecretStream, Vec), Error> { + /// Recieve the last message to complet the handsake + pub fn receive_msg( + mut self, + msg: &[u8], + ) -> Result<(IkSecretStream, Vec), Error> { self.state.read_message(msg, &mut self.msg_buf)?; let (tx, rx) = self.split_handshake(); @@ -191,7 +230,7 @@ impl IkSecretStream { is_initiator, state, msg_buf, - _step: InitiatorEnc { + _step: EncryptorReady { pusher, rx: Key::from(rx), handshake_hash, @@ -201,12 +240,13 @@ impl IkSecretStream { )) } } -impl IkSecretStream { - fn receive_msg(self, msg: &[u8]) -> Result, Error> { +impl IkSecretStream { + /// Recieve message the last message, used to set up the decryption stream + pub fn receive_msg(self, msg: &[u8]) -> Result, Error> { let Self { is_initiator, _step: - InitiatorEnc { + EncryptorReady { pusher, rx, handshake_hash, @@ -217,11 +257,12 @@ impl IkSecretStream { // Read the received message from the other peer let mut expected_stream_id: [u8; STREAM_ID_LENGTH] = [0; STREAM_ID_LENGTH]; write_stream_id(&handshake_hash, !is_initiator, &mut expected_stream_id); - if expected_stream_id != msg[..32] { + if expected_stream_id != msg[..STREAM_ID_LENGTH] { panic!() } - let header: [u8; 24] = msg[32..].try_into().expect("TODO wrong size"); + let header: [u8; Header::BYTES] = + msg[STREAM_ID_LENGTH..].try_into().expect("TODO wrong size"); let puller = PullStream::init(header.into(), &rx); Ok(IkSecretStream { is_initiator, @@ -232,42 +273,18 @@ impl IkSecretStream { } } -impl IkSecretStream { - fn receive_msg(self, msg: &[u8]) -> Result, Error> { - let Self { - is_initiator, - _step: - ResponderReplied { - pusher, - rx, - handshake_hash, - }, - state, - msg_buf, - } = self; - // Read the received message from the other peer - let mut expected_stream_id: [u8; STREAM_ID_LENGTH] = [0; STREAM_ID_LENGTH]; - write_stream_id(&handshake_hash, !is_initiator, &mut expected_stream_id); - if expected_stream_id != msg[..32] { - panic!() - } - - let header: [u8; 24] = msg[32..].try_into().expect("TODO wrong size"); - let puller = PullStream::init(header.into(), &rx); - Ok(IkSecretStream { - is_initiator, - state, - msg_buf, - _step: Ready { puller, pusher }, - }) - } -} - impl IkSecretStream { - fn push(&mut self, msg: &mut Vec, associated_data: &[u8], tag: Tag) -> Result<(), Error> { + /// Encrypt a message in place + pub fn push( + &mut self, + msg: &mut Vec, + associated_data: &[u8], + tag: Tag, + ) -> Result<(), Error> { Ok(self._step.pusher.push(msg, associated_data, tag)?) } - fn pull(&mut self, msg: &mut Vec, associated_data: &[u8]) -> Result { + /// Decrypt a message in place + pub fn pull(&mut self, msg: &mut Vec, associated_data: &[u8]) -> Result { Ok(self._step.puller.pull(msg, associated_data)?) } } diff --git a/src/test_utils.rs b/src/test_utils.rs index 2e5e994..ac80d55 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -87,7 +87,7 @@ pub(crate) fn log() { .with_bracketed_fields(true) .with_indent_lines(true) .with_thread_ids(false) - .with_thread_names(true) + //.with_thread_names(true) //.with_span_modes(true) ; From b21e02d4b8d2592d59807b9faf67c689564520a2 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 29 Jul 2025 23:13:58 -0400 Subject: [PATCH 152/206] clippy --fix --- benches/throughput.rs | 2 +- src/channels.rs | 2 +- src/crypto/cipher.rs | 4 ++-- src/crypto/handshake.rs | 4 ++-- src/noise.rs | 2 +- src/sstream.rs | 2 +- src/test_utils.rs | 6 +++--- tests/_util.rs | 2 +- tests/js/mod.rs | 14 ++++---------- tests/js_interop.rs | 10 +++++----- 10 files changed, 21 insertions(+), 27 deletions(-) diff --git a/benches/throughput.rs b/benches/throughput.rs index b19167e..f14d7a0 100644 --- a/benches/throughput.rs +++ b/benches/throughput.rs @@ -21,7 +21,7 @@ const CLIENTS: usize = 1; fn bench_throughput(c: &mut Criterion) { test_utils::log(); - let address = format!("localhost:{}", PORT); + let address = format!("localhost:{PORT}"); let mut group = c.benchmark_group("throughput"); let data = vec![1u8; SIZE as usize]; diff --git a/src/channels.rs b/src/channels.rs index f16ac7f..64751ab 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -512,5 +512,5 @@ impl ChannelMap { } fn error(message: &str) -> Error { - Error::new(ErrorKind::Other, message) + Error::other(message) } diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index dee19f6..4fb42c9 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -67,7 +67,7 @@ impl DecryptCipher { pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> io::Result<(Vec, Tag)> { let mut to_decrypt = buf.to_vec(); let tag = &self.pull_stream.pull(&mut to_decrypt, &[]).map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("Decrypt failed: {err}")) + io::Error::other(format!("Decrypt failed: {err}")) })?; Ok((to_decrypt, *tag)) } @@ -147,7 +147,7 @@ impl EncryptCipher { self.push_stream .push(&mut out, &[], Tag::Message) .map_err(|err| { - io::Error::new(io::ErrorKind::Other, format!("Encrypt failed: {err}")) + io::Error::other(format!("Encrypt failed: {err}")) })?; Ok(out) } diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 9c130c4..8c7ccfd 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -183,7 +183,7 @@ impl Handshake { pub fn read_raw(&mut self, msg: &[u8]) -> Result>> { // eprintln!("hs read len {}", msg.len()); if self.complete() { - return Err(Error::new(ErrorKind::Other, "Handshake read after finish")); + return Err(Error::other("Handshake read after finish")); } let rx_len = self.recv(msg)?; @@ -231,7 +231,7 @@ impl Handshake { /// get the handshake result when it is completed pub fn get_result(&self) -> Result<&HandshakeResult> { if !self.complete() { - Err(Error::new(ErrorKind::Other, "Handshake is not complete")) + Err(Error::other("Handshake is not complete")) } else { Ok(&self.result) } diff --git a/src/noise.rs b/src/noise.rs index ec56ca6..5a76569 100644 --- a/src/noise.rs +++ b/src/noise.rs @@ -45,7 +45,7 @@ impl std::fmt::Display for Step { Step::SecretStream(_) => "SecretStream", Step::Established(_) => "Established", }; - write!(f, "{}", x) + write!(f, "{x}") } } diff --git a/src/sstream.rs b/src/sstream.rs index 3ca028b..bc7a058 100644 --- a/src/sstream.rs +++ b/src/sstream.rs @@ -136,7 +136,7 @@ impl IkSecretStream { pub fn new_responder(private: &[u8]) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); let state = Builder::new(params.clone()) - .local_private_key(&private)? + .local_private_key(private)? .build_responder()?; Ok(Self { is_initiator: false, diff --git a/src/test_utils.rs b/src/test_utils.rs index ac80d55..8afa83e 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] use std::{ - io::{self, ErrorKind}, + io::{self}, pin::Pin, task::{Context, Poll}, }; @@ -41,7 +41,7 @@ impl Sink> for Io { fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { Pin::new(&mut self.sender) .start_send(item) - .map_err(|_e| io::Error::new(ErrorKind::Other, "SendError")) + .map_err(|_e| io::Error::other("SendError")) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { @@ -125,7 +125,7 @@ impl + Unpin> Sink let this = self.get_mut(); Pin::new(&mut this.sender) .start_send(item) - .map_err(|_e| io::Error::new(ErrorKind::Other, "SendError")) + .map_err(|_e| io::Error::other("SendError")) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { diff --git a/tests/_util.rs b/tests/_util.rs index 78c89e4..fa5e406 100644 --- a/tests/_util.rs +++ b/tests/_util.rs @@ -93,7 +93,7 @@ pub async fn wait_for_localhost_port(port: u32) { loop { let timeout = async_std::future::timeout( Duration::from_millis(NO_RESPONSE_TIMEOUT), - TcpStream::connect(format!("localhost:{}", port)), + TcpStream::connect(format!("localhost:{port}")), ) .await; if timeout.is_err() { diff --git a/tests/js/mod.rs b/tests/js/mod.rs index b8cd6ec..39d83de 100644 --- a/tests/js/mod.rs +++ b/tests/js/mod.rs @@ -43,9 +43,9 @@ pub fn install() { } pub fn prepare_test_set(test_set: &str) -> (String, String, String) { - let path_result = format!("tests/js/work/{}/result.txt", test_set); - let path_writer = format!("tests/js/work/{}/writer", test_set); - let path_reader = format!("tests/js/work/{}/reader", test_set); + let path_result = format!("tests/js/work/{test_set}/result.txt"); + let path_writer = format!("tests/js/work/{test_set}/writer"); + let path_reader = format!("tests/js/work/{test_set}/reader"); create_dir_all(&path_writer).expect("Unable to create work writer directory"); create_dir_all(&path_reader).expect("Unable to create work reader directory"); (path_result, path_writer, path_reader) @@ -100,13 +100,7 @@ impl JavascriptServer { assert_eq!( Some(0), code, - "node server did not exit successfully, is_writer={}, port={}, data_count={}, data_size={}, data_char={}, test_set={}", - is_writer, - port, - data_count, - data_size, - data_char, - test_set, + "node server did not exit successfully, is_writer={is_writer}, port={port}, data_count={data_count}, data_size={data_size}, data_char={data_char}, test_set={test_set}", ); })); wait_for_localhost_port(port).await; diff --git a/tests/js_interop.rs b/tests/js_interop.rs index fe19db3..4a1a3d8 100644 --- a/tests/js_interop.rs +++ b/tests/js_interop.rs @@ -314,7 +314,7 @@ async fn assert_result( let expected_value = data_char.to_string().repeat(item_size); let mut line = String::new(); while reader.read_line(&mut line).await? != 0 { - assert_eq!(line, format!("{} {}\n", i, expected_value)); + assert_eq!(line, format!("{i} {expected_value}\n")); i += 1; line = String::new(); } @@ -737,7 +737,7 @@ async fn on_replication_message( let mut writer = BufWriter::new(File::create(result_path).await?); for i in 0..new_info.contiguous_length { let value = String::from_utf8(hypercore.get(i).await?.unwrap()).unwrap(); - let line = format!("{} {}\n", i, value); + let line = format!("{i} {value}\n"); let n_written = writer.write(line.as_bytes()).await?; if line.len() != n_written { panic!("Couldn't write all write all bytse"); @@ -770,7 +770,7 @@ async fn on_replication_message( } } _ => { - panic!("Received unexpected message {:?}", message); + panic!("Received unexpected message {message:?}"); } }; Ok(false) @@ -831,7 +831,7 @@ where F: Future> + Send, C: Clone + Send + 'static, { - let listener = TcpListener::bind(&format!("localhost:{}", port)).await?; + let listener = TcpListener::bind(&format!("localhost:{port}")).await?; while let Ok((stream, _peer_address)) = listener.accept().await { let context = context.clone(); @@ -853,6 +853,6 @@ where F: Future> + Send, C: Clone + Send + 'static, { - let stream = TcpStream::connect(&format!("localhost:{}", port)).await?; + let stream = TcpStream::connect(&format!("localhost:{port}")).await?; onconnection(stream, true, context).await } From a289805d8ace7c6edcf65e53fb6e9d970ce5a35d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 31 Jul 2025 00:03:01 -0400 Subject: [PATCH 153/206] split out initiator payload recv --- src/sstream.rs | 228 +++++++++++++++++++++++++++++++------------------ 1 file changed, 144 insertions(+), 84 deletions(-) diff --git a/src/sstream.rs b/src/sstream.rs index bc7a058..671da8e 100644 --- a/src/sstream.rs +++ b/src/sstream.rs @@ -1,8 +1,10 @@ -//! Create a sodium secret stream using the IK pattern +//! Create a sodium secret stream using the Noise IK pattern. +//! We use the "Typestate pattern" for the steps of the handshake. +//! Attempt at a diagram use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; use rand::rngs::OsRng; use snow::{params::NoiseParams, Builder, HandshakeState}; -use std::fmt::Debug; +use std::{fmt::Debug, marker::PhantomData}; use crate::{crypto::write_stream_id, Error}; @@ -12,8 +14,8 @@ const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; const SNOW_CIPHERKEYLEN: usize = 32; const PUBLIC_KEYLEN: usize = 32; -#[derive(Debug)] /// Data for creating an initiator +#[derive(Debug)] pub struct InitiatorConfig { remote_public_key: [u8; PUBLIC_KEYLEN], } @@ -25,16 +27,16 @@ impl InitiatorConfig { } } -#[derive(Debug)] /// Secret Stream protocol state -pub struct IkSecretStream { +#[derive(Debug)] +pub struct SecStream { is_initiator: bool, state: HandshakeState, msg_buf: [u8; 1024], - _step: Step, + step: Step, } -impl IkSecretStream { +impl SecStream { /// split handshake into (tx, rx) pub fn split_handshake(&mut self) -> ([u8; SNOW_CIPHERKEYLEN], [u8; SNOW_CIPHERKEYLEN]) { let (a, b) = self.state.dangerously_get_raw_split(); @@ -49,14 +51,23 @@ impl IkSecretStream { /// Initial initiator state #[derive(Debug)] pub struct InitiatorInitial; - -/// Initial responder state +/// Initiator has sent the first message #[derive(Debug)] -pub struct ResponderInitial; +pub struct InitiatorInitialSent { + _res_step: PhantomData, +} -/// Initiator has sent the first message +/// Initial responder state +/// This first is before it receives the first message. +/// The second is after it reads it and gets the payload, but before creating the encyptor and +/// emitting the next message. This distinction is necessary so we can handle the received payload +/// and send a new one #[derive(Debug)] -pub struct InitiatorInitialSent; +pub struct Responder { + _res_step: PhantomData, +} +struct One; +struct Two; /// No decryptor yet pub struct EncryptorReady { @@ -88,7 +99,7 @@ impl Debug for Ready { } } -impl IkSecretStream { +impl SecStream { /// Create an initiator of a secret stream pub fn new_initiator(config: InitiatorConfig) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); @@ -102,7 +113,7 @@ impl IkSecretStream { is_initiator: true, state, msg_buf: [0; 1024], - _step: InitiatorInitial, + step: InitiatorInitial, }) } @@ -110,7 +121,7 @@ impl IkSecretStream { pub fn make_first_msg( mut self, prologue: &[u8], - ) -> Result<(IkSecretStream, Vec), Error> { + ) -> Result<(SecStream>, Vec), Error> { let len = self.state.write_message(prologue, &mut self.msg_buf)?; let msg = self.msg_buf[..len].to_vec(); let Self { @@ -120,18 +131,20 @@ impl IkSecretStream { .. } = self; Ok(( - IkSecretStream { + SecStream { is_initiator, state, msg_buf, - _step: InitiatorInitialSent, + step: InitiatorInitialSent { + _res_step: PhantomData, + }, }, msg, )) } } -impl IkSecretStream { +impl SecStream> { /// Create a responder of a secret stream pub fn new_responder(private: &[u8]) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); @@ -142,17 +155,55 @@ impl IkSecretStream { is_initiator: false, state, msg_buf: [0; 1024], - _step: ResponderInitial, + step: Responder { + _res_step: PhantomData, + }, }) } + /// Read msg and return it's payload + pub fn read_first_msg_get_payload( + mut self, + msg: &[u8], + ) -> Result<(SecStream>, Vec), Error> { + let len = self.state.read_message(msg, &mut self.msg_buf)?; + let payload = &self.msg_buf[..len]; + let Self { + is_initiator, + state, + msg_buf, + .. + } = self; + Ok(( + SecStream { + is_initiator, + state, + msg_buf, + step: Responder { + _res_step: PhantomData, + }, + }, + payload.to_vec(), + )) + } /// Read the first message of the protocol, create the next two messages to send to the initiator. pub fn read_first_msg( - mut self, + self, msg: &[u8], - ) -> Result<(IkSecretStream, [Vec; 2]), Error> { - self.state.read_message(msg, &mut self.msg_buf)?; - let len = self.state.write_message(&[], &mut self.msg_buf)?; + ) -> Result<(SecStream, [Vec; 2]), Error> { + let (self2, _rx_payload) = self.read_first_msg_get_payload(msg)?; + self2.make_second_msg(&[]) + } +} + +impl SecStream> { + /// Make second message with the given payload. Returns two messages, the first completes the + /// Noise handshake. The second has the shared key for the remote to set up a Decryptor. + pub fn make_second_msg( + mut self, + payload: &[u8], + ) -> Result<(SecStream, [Vec; 2]), Error> { + let len = self.state.write_message(payload, &mut self.msg_buf)?; let hs_msg = self.msg_buf[..len].to_vec(); assert!(self.state.is_handshake_finished()); @@ -178,11 +229,11 @@ impl IkSecretStream { .. } = self; Ok(( - IkSecretStream { + SecStream { is_initiator, state, msg_buf, - _step: EncryptorReady { + step: EncryptorReady { rx: Key::from(rx), pusher, handshake_hash, @@ -193,14 +244,44 @@ impl IkSecretStream { } } -impl IkSecretStream { +impl SecStream> { /// Recieve the last message to complet the handsake + pub fn read_first_msg_get_payload( + mut self, + msg: &[u8], + ) -> Result<(SecStream>, Vec), Error> { + let len = self.state.read_message(msg, &mut self.msg_buf)?; + let payload = &self.msg_buf[..len]; + let Self { + is_initiator, + state, + msg_buf, + .. + } = self; + Ok(( + SecStream { + is_initiator, + state, + msg_buf, + step: InitiatorInitialSent { + _res_step: PhantomData, + }, + }, + payload.to_vec(), + )) + } + pub fn receive_msg( mut self, msg: &[u8], - ) -> Result<(IkSecretStream, Vec), Error> { - self.state.read_message(msg, &mut self.msg_buf)?; + ) -> Result<(SecStream, Vec), Error> { + let (mut self2, _payload) = self.read_first_msg_get_payload(msg)?; + self2.make_msg() + } +} +impl SecStream> { + fn make_msg(mut self) -> Result<(SecStream, Vec), Error> { let (tx, rx) = self.split_handshake(); let key: [u8; SNOW_CIPHERKEYLEN] = tx[..SNOW_CIPHERKEYLEN] .try_into() @@ -219,18 +300,18 @@ impl IkSecretStream { // write push header to back of msg msg[STREAM_ID_LENGTH..].copy_from_slice(header.as_ref()); - let Self { + let SecStream { is_initiator, state, msg_buf, .. } = self; Ok(( - IkSecretStream { + SecStream { is_initiator, state, msg_buf, - _step: EncryptorReady { + step: EncryptorReady { pusher, rx: Key::from(rx), handshake_hash, @@ -240,12 +321,13 @@ impl IkSecretStream { )) } } -impl IkSecretStream { + +impl SecStream { /// Recieve message the last message, used to set up the decryption stream - pub fn receive_msg(self, msg: &[u8]) -> Result, Error> { + pub fn receive_msg(self, msg: &[u8]) -> Result, Error> { let Self { is_initiator, - _step: + step: EncryptorReady { pusher, rx, @@ -264,16 +346,16 @@ impl IkSecretStream { let header: [u8; Header::BYTES] = msg[STREAM_ID_LENGTH..].try_into().expect("TODO wrong size"); let puller = PullStream::init(header.into(), &rx); - Ok(IkSecretStream { + Ok(SecStream { is_initiator, state, msg_buf, - _step: Ready { pusher, puller }, + step: Ready { pusher, puller }, }) } } -impl IkSecretStream { +impl SecStream { /// Encrypt a message in place pub fn push( &mut self, @@ -281,11 +363,11 @@ impl IkSecretStream { associated_data: &[u8], tag: Tag, ) -> Result<(), Error> { - Ok(self._step.pusher.push(msg, associated_data, tag)?) + Ok(self.step.pusher.push(msg, associated_data, tag)?) } /// Decrypt a message in place pub fn pull(&mut self, msg: &mut Vec, associated_data: &[u8]) -> Result { - Ok(self._step.puller.pull(msg, associated_data)?) + Ok(self.step.puller.pull(msg, associated_data)?) } } @@ -294,17 +376,32 @@ mod test { use super::*; #[tokio::test] - async fn foosstream() -> std::result::Result<(), Box> { + async fn set_up_secret_steram() -> std::result::Result<(), Box> { + /// Excessive typing to demonstrate flow through typestates let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; let config = InitiatorConfig::new(kp.public.try_into().unwrap()); - let init = IkSecretStream::new_initiator(config)?; - let resp = IkSecretStream::new_responder(&kp.private)?; - let (init, msg) = init.make_first_msg(&[])?; - let (resp, [msg1, msg2]) = resp.read_first_msg(&msg)?; - let (init, to_resp) = init.receive_msg(&msg1)?; - let mut resp = resp.receive_msg(&to_resp)?; - let mut init = init.receive_msg(&msg2)?; + let init: SecStream = SecStream::new_initiator(config)?; + let resp: SecStream> = SecStream::new_responder(&kp.private)?; + + let (init, msg): (SecStream>, Vec) = + init.make_first_msg(b"hello")?; + + let (resp, payload): (SecStream>, Vec) = + resp.read_first_msg_get_payload(&msg)?; + assert_eq!(payload, b"hello"); + + let payload2 = b"goodbye"; + let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = + resp.make_second_msg(payload2)?; + + let (init, payload_recv) = init.read_first_msg_get_payload(&msg1)?; + assert_eq!(payload_recv, b"goodbye"); + + let (init, to_resp): (SecStream, Vec) = init.make_msg()?; + + let mut init: SecStream = init.receive_msg(&msg2)?; + let mut resp: SecStream = resp.receive_msg(&to_resp)?; let hello = b"hello".to_vec(); let mut msg = hello.clone(); @@ -317,41 +414,4 @@ mod test { assert_eq!(msg, hello); Ok(()) } - - #[tokio::test] - async fn hs() -> std::result::Result<(), Box> { - let params: NoiseParams = PARAMS.parse()?; - - let initiator_kp = Builder::new(params.clone()).generate_keypair()?; - let responder_kp = Builder::new(params.clone()).generate_keypair()?; - - let mut initiator = Builder::new(params.clone()) - .local_private_key(&initiator_kp.private)? - .remote_public_key(&responder_kp.public)? - .build_initiator()?; - - let mut responder = Builder::new(params.clone()) - .local_private_key(&responder_kp.private)? - .build_responder()?; - let (mut read_buf, mut first_msg, mut second_msg, mut enc_buf) = - ([0u8; 1024], [0u8; 1024], [0u8; 1024], [0u8; 1024]); - - // -> e, es, s, ss - let first_len = initiator.write_message(&[], &mut first_msg)?; - // responder processes the first message... - responder.read_message(&first_msg[..first_len], &mut read_buf)?; - - // <- e, ee, se - let second_len = responder.write_message(&[], &mut second_msg)?; - let mut resp_transport = responder.into_transport_mode()?; - let _read_len = initiator.read_message(&second_msg[..second_len], &mut read_buf)?; - - let mut init_transport = initiator.into_transport_mode()?; - - let msg = b"my message"; - let elen = resp_transport.write_message(msg, &mut enc_buf)?; - init_transport.read_message(&enc_buf[..elen], &mut read_buf)?; - - Ok(()) - } } From 6658eb73710b1883f2a47879435755ad3ac8a7d5 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Thu, 31 Jul 2025 00:52:18 -0400 Subject: [PATCH 154/206] shorter names, more typestate --- src/sstream.rs | 104 +++++++++++++++++++++++++++--------------------- tests/js/mod.rs | 4 +- 2 files changed, 60 insertions(+), 48 deletions(-) diff --git a/src/sstream.rs b/src/sstream.rs index 671da8e..4a8b655 100644 --- a/src/sstream.rs +++ b/src/sstream.rs @@ -48,12 +48,9 @@ impl SecStream { } } -/// Initial initiator state +/// Initiator #[derive(Debug)] -pub struct InitiatorInitial; -/// Initiator has sent the first message -#[derive(Debug)] -pub struct InitiatorInitialSent { +pub struct Initiator { _res_step: PhantomData, } @@ -63,11 +60,17 @@ pub struct InitiatorInitialSent { /// emitting the next message. This distinction is necessary so we can handle the received payload /// and send a new one #[derive(Debug)] -pub struct Responder { - _res_step: PhantomData, +pub struct Responder { + _res_step: PhantomData, } -struct One; -struct Two; +/// The first step. We must send or receive a handshake message to proceed. +struct Start; +/// The handshake message has been sent. We must receive a handshake message to proceed to +/// [`HsDone`]. Only on [`Initiator`]. +struct HsMsgSent; +/// [`snow::HandshakeState::is_handshake_finished`] is `true`. +/// We are ready create a [`PushStream`] and proeed to [`EncryptorReady`]. +struct HsDone; /// No decryptor yet pub struct EncryptorReady { @@ -99,7 +102,7 @@ impl Debug for Ready { } } -impl SecStream { +impl SecStream> { /// Create an initiator of a secret stream pub fn new_initiator(config: InitiatorConfig) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); @@ -113,15 +116,17 @@ impl SecStream { is_initiator: true, state, msg_buf: [0; 1024], - step: InitiatorInitial, + step: Initiator { + _res_step: PhantomData, + }, }) } /// Create the first message the initiator sends to the responder - pub fn make_first_msg( + pub fn write_msg( mut self, prologue: &[u8], - ) -> Result<(SecStream>, Vec), Error> { + ) -> Result<(SecStream>, Vec), Error> { let len = self.state.write_message(prologue, &mut self.msg_buf)?; let msg = self.msg_buf[..len].to_vec(); let Self { @@ -135,7 +140,7 @@ impl SecStream { is_initiator, state, msg_buf, - step: InitiatorInitialSent { + step: Initiator { _res_step: PhantomData, }, }, @@ -144,7 +149,7 @@ impl SecStream { } } -impl SecStream> { +impl SecStream> { /// Create a responder of a secret stream pub fn new_responder(private: &[u8]) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); @@ -162,10 +167,10 @@ impl SecStream> { } /// Read msg and return it's payload - pub fn read_first_msg_get_payload( + pub fn read_msg( mut self, msg: &[u8], - ) -> Result<(SecStream>, Vec), Error> { + ) -> Result<(SecStream>, Vec), Error> { let len = self.state.read_message(msg, &mut self.msg_buf)?; let payload = &self.msg_buf[..len]; let Self { @@ -187,19 +192,19 @@ impl SecStream> { )) } /// Read the first message of the protocol, create the next two messages to send to the initiator. - pub fn read_first_msg( + pub fn read_and_write_msg( self, msg: &[u8], ) -> Result<(SecStream, [Vec; 2]), Error> { - let (self2, _rx_payload) = self.read_first_msg_get_payload(msg)?; - self2.make_second_msg(&[]) + let (self2, _rx_payload) = self.read_msg(msg)?; + self2.write_msg(&[]) } } -impl SecStream> { +impl SecStream> { /// Make second message with the given payload. Returns two messages, the first completes the /// Noise handshake. The second has the shared key for the remote to set up a Decryptor. - pub fn make_second_msg( + pub fn write_msg( mut self, payload: &[u8], ) -> Result<(SecStream, [Vec; 2]), Error> { @@ -244,12 +249,12 @@ impl SecStream> { } } -impl SecStream> { +impl SecStream> { /// Recieve the last message to complet the handsake - pub fn read_first_msg_get_payload( + pub fn read_msg( mut self, msg: &[u8], - ) -> Result<(SecStream>, Vec), Error> { + ) -> Result<(SecStream>, Vec), Error> { let len = self.state.read_message(msg, &mut self.msg_buf)?; let payload = &self.msg_buf[..len]; let Self { @@ -263,7 +268,7 @@ impl SecStream> { is_initiator, state, msg_buf, - step: InitiatorInitialSent { + step: Initiator { _res_step: PhantomData, }, }, @@ -271,17 +276,19 @@ impl SecStream> { )) } - pub fn receive_msg( - mut self, + /// read in a message, and write the next message. Any payload in the recieved message is + /// dropped. + pub fn read_and_write_msg( + self, msg: &[u8], ) -> Result<(SecStream, Vec), Error> { - let (mut self2, _payload) = self.read_first_msg_get_payload(msg)?; - self2.make_msg() + let (self2, _payload) = self.read_msg(msg)?; + self2.write_msg() } } -impl SecStream> { - fn make_msg(mut self) -> Result<(SecStream, Vec), Error> { +impl SecStream> { + fn write_msg(mut self) -> Result<(SecStream, Vec), Error> { let (tx, rx) = self.split_handshake(); let key: [u8; SNOW_CIPHERKEYLEN] = tx[..SNOW_CIPHERKEYLEN] .try_into() @@ -324,7 +331,7 @@ impl SecStream> { impl SecStream { /// Recieve message the last message, used to set up the decryption stream - pub fn receive_msg(self, msg: &[u8]) -> Result, Error> { + pub fn read_msg(self, msg: &[u8]) -> Result, Error> { let Self { is_initiator, step: @@ -377,31 +384,36 @@ mod test { #[tokio::test] async fn set_up_secret_steram() -> std::result::Result<(), Box> { - /// Excessive typing to demonstrate flow through typestates + // Excessive typing to demonstrate flow through typestates let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; let config = InitiatorConfig::new(kp.public.try_into().unwrap()); - let init: SecStream = SecStream::new_initiator(config)?; - let resp: SecStream> = SecStream::new_responder(&kp.private)?; + // Create an initiator and responder + let init: SecStream> = SecStream::new_initiator(config)?; + let resp: SecStream> = SecStream::new_responder(&kp.private)?; - let (init, msg): (SecStream>, Vec) = - init.make_first_msg(b"hello")?; + // initiator sends the first handshake message, a payload can be included to send extra data to the + // responder. + let (init, msg): (SecStream>, Vec) = init.write_msg(b"hello")?; - let (resp, payload): (SecStream>, Vec) = - resp.read_first_msg_get_payload(&msg)?; + // responder receives the hs message, extracts the payload + let (resp, payload): (SecStream>, Vec) = resp.read_msg(&msg)?; assert_eq!(payload, b"hello"); - let payload2 = b"goodbye"; + // responder sends a handshake message, which can include a payload. As well as a second + // message which contains the symmetric key needed to set up the decryptor let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = - resp.make_second_msg(payload2)?; + resp.write_msg(b"goodbye")?; - let (init, payload_recv) = init.read_first_msg_get_payload(&msg1)?; + // Initiator receives last handshake message, use handshake to create the extract payload. + let (init, payload_recv): (SecStream>, Vec) = init.read_msg(&msg1)?; assert_eq!(payload_recv, b"goodbye"); - let (init, to_resp): (SecStream, Vec) = init.make_msg()?; + // receive decryptor keey + let (init, to_resp): (SecStream, Vec) = init.write_msg()?; - let mut init: SecStream = init.receive_msg(&msg2)?; - let mut resp: SecStream = resp.receive_msg(&to_resp)?; + let mut init: SecStream = init.read_msg(&msg2)?; + let mut resp: SecStream = resp.read_msg(&to_resp)?; let hello = b"hello".to_vec(); let mut msg = hello.clone(); diff --git a/tests/js/mod.rs b/tests/js/mod.rs index 39d83de..41a347e 100644 --- a/tests/js/mod.rs +++ b/tests/js/mod.rs @@ -30,9 +30,9 @@ pub fn cleanup() { } pub fn install() { - let status = Command::new("npm") + let status = Command::new("cp") .current_dir("tests/js") - .args(["install"]) + .args(["-r", "node_modules_for_tests", "node_modules"]) .status() .expect("Unable to run npm install"); assert_eq!( From d719b13c7515693c0d43cf9f384e9856b360f5cc Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 5 Aug 2025 00:25:08 -0400 Subject: [PATCH 155/206] move test to doc-comment --- src/sstream.rs | 119 +++++++++++++++++++++++++------------------------ 1 file changed, 61 insertions(+), 58 deletions(-) diff --git a/src/sstream.rs b/src/sstream.rs index 4a8b655..fee407c 100644 --- a/src/sstream.rs +++ b/src/sstream.rs @@ -1,6 +1,56 @@ -//! Create a sodium secret stream using the Noise IK pattern. -//! We use the "Typestate pattern" for the steps of the handshake. -//! Attempt at a diagram +/*! + Create a sodium secret stream using the Noise IK pattern. + We use the "Typestate pattern" for the steps of the handshake. + ``` +// Excessive typing to demonstrate flow through typestates +# use hypercore_protocol::{sstream::{Ready, Start, SecStream, Initiator, InitiatorConfig, +# PARAMS, Responder, HsMsgSent, +# EncryptorReady, HsDone}}; +use snow::{params::NoiseParams, Builder}; +use crypto_secretstream::{Header, Tag}; +let params: NoiseParams = PARAMS.parse().expect("known to work"); +let kp = Builder::new(params.clone()).generate_keypair()?; +let config = InitiatorConfig::new(kp.public.try_into().unwrap()); +// Create an initiator and responder +let init: SecStream> = SecStream::new_initiator(config)?; +let resp: SecStream> = SecStream::new_responder(&kp.private)?; + +// initiator sends the first handshake message, a payload can be included to send extra data to the +// responder. +let (init, msg): (SecStream>, Vec) = init.write_msg(b"hello")?; + +// responder receives the hs message, extracts the payload +let (resp, payload): (SecStream>, Vec) = resp.read_msg(&msg)?; +assert_eq!(payload, b"hello"); + +// responder sends a handshake message, which can include a payload. As well as a second +// message which contains the symmetric key needed to set up the decryptor +let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = + resp.write_msg(b"goodbye")?; + +// Initiator receives last handshake message, use handshake to create the extract payload. +let (init, payload_recv): (SecStream>, Vec) = init.read_msg(&msg1)?; +assert_eq!(payload_recv, b"goodbye"); + +// receive decryptor keey +let (init, to_resp): (SecStream, Vec) = init.write_msg()?; + +let mut init: SecStream = init.read_msg(&msg2)?; +let mut resp: SecStream = resp.read_msg(&to_resp)?; + +let hello = b"hello".to_vec(); +let mut msg = hello.clone(); + +println!("msg {msg:?}"); +init.push(&mut msg, &[], Tag::Message)?; +println!("enc {msg:?}"); +let tag = resp.pull(&mut msg, &[])?; +println!("res {msg:?} tag: {tag:?}"); +assert_eq!(msg, hello); + +Ok::<(), Box>(()) + ``` +*/ use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; use rand::rngs::OsRng; use snow::{params::NoiseParams, Builder, HandshakeState}; @@ -8,7 +58,7 @@ use std::{fmt::Debug, marker::PhantomData}; use crate::{crypto::write_stream_id, Error}; -const PARAMS: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2b"; +pub const PARAMS: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2b"; const STREAM_ID_LENGTH: usize = 32; const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; const SNOW_CIPHERKEYLEN: usize = 32; @@ -64,13 +114,16 @@ pub struct Responder { _res_step: PhantomData, } /// The first step. We must send or receive a handshake message to proceed. -struct Start; +#[derive(Debug)] +pub struct Start; /// The handshake message has been sent. We must receive a handshake message to proceed to /// [`HsDone`]. Only on [`Initiator`]. -struct HsMsgSent; +#[derive(Debug)] +pub struct HsMsgSent; /// [`snow::HandshakeState::is_handshake_finished`] is `true`. /// We are ready create a [`PushStream`] and proeed to [`EncryptorReady`]. -struct HsDone; +#[derive(Debug)] +pub struct HsDone; /// No decryptor yet pub struct EncryptorReady { @@ -288,7 +341,7 @@ impl SecStream> { } impl SecStream> { - fn write_msg(mut self) -> Result<(SecStream, Vec), Error> { + pub fn write_msg(mut self) -> Result<(SecStream, Vec), Error> { let (tx, rx) = self.split_handshake(); let key: [u8; SNOW_CIPHERKEYLEN] = tx[..SNOW_CIPHERKEYLEN] .try_into() @@ -377,53 +430,3 @@ impl SecStream { Ok(self.step.puller.pull(msg, associated_data)?) } } - -#[cfg(test)] -mod test { - use super::*; - - #[tokio::test] - async fn set_up_secret_steram() -> std::result::Result<(), Box> { - // Excessive typing to demonstrate flow through typestates - let params: NoiseParams = PARAMS.parse().expect("known to work"); - let kp = Builder::new(params.clone()).generate_keypair()?; - let config = InitiatorConfig::new(kp.public.try_into().unwrap()); - // Create an initiator and responder - let init: SecStream> = SecStream::new_initiator(config)?; - let resp: SecStream> = SecStream::new_responder(&kp.private)?; - - // initiator sends the first handshake message, a payload can be included to send extra data to the - // responder. - let (init, msg): (SecStream>, Vec) = init.write_msg(b"hello")?; - - // responder receives the hs message, extracts the payload - let (resp, payload): (SecStream>, Vec) = resp.read_msg(&msg)?; - assert_eq!(payload, b"hello"); - - // responder sends a handshake message, which can include a payload. As well as a second - // message which contains the symmetric key needed to set up the decryptor - let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = - resp.write_msg(b"goodbye")?; - - // Initiator receives last handshake message, use handshake to create the extract payload. - let (init, payload_recv): (SecStream>, Vec) = init.read_msg(&msg1)?; - assert_eq!(payload_recv, b"goodbye"); - - // receive decryptor keey - let (init, to_resp): (SecStream, Vec) = init.write_msg()?; - - let mut init: SecStream = init.read_msg(&msg2)?; - let mut resp: SecStream = resp.read_msg(&to_resp)?; - - let hello = b"hello".to_vec(); - let mut msg = hello.clone(); - - println!("msg {msg:?}"); - init.push(&mut msg, &[], Tag::Message)?; - println!("enc {msg:?}"); - let tag = resp.pull(&mut msg, &[])?; - println!("res {msg:?} tag: {tag:?}"); - assert_eq!(msg, hello); - Ok(()) - } -} From 4b9a0f00cd306850cb74b17aa001b85b4ff64190 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 6 Aug 2025 23:03:25 -0400 Subject: [PATCH 156/206] format rustdoc comments --- src/lib.rs | 15 ++++++++++++--- src/sstream.rs | 19 ++++++++----------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 08518f0..1177d3e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -53,7 +53,9 @@ //! use async_std::prelude::*; //! use hypercore_protocol::{schema::*, Event, Message, ProtocolBuilder}; //! // Start a tcp server. -//! let listener = async_std::net::TcpListener::bind("localhost:8000").await.unwrap(); +//! let listener = async_std::net::TcpListener::bind("localhost:8000") +//! .await +//! .unwrap(); //! async_std::task::spawn(async move { //! let mut incoming = listener.incoming(); //! while let Some(Ok(stream)) = incoming.next().await { @@ -62,7 +64,9 @@ //! }); //! //! // Connect a client. -//! let stream = async_std::net::TcpStream::connect("localhost:8000").await.unwrap(); +//! let stream = async_std::net::TcpStream::connect("localhost:8000") +//! .await +//! .unwrap(); //! onconnection(stream, true).await; //! //! /// Start Hypercore protocol on a TcpStream. @@ -89,7 +93,12 @@ //! // A Channel can be sent to other tasks. //! async_std::task::spawn(async move { //! // A Channel can both send messages and is a stream of incoming messages. -//! channel.send(Message::Want(Want { start: 0, length: 1 })).await; +//! channel +//! .send(Message::Want(Want { +//! start: 0, +//! length: 1, +//! })) +//! .await; //! while let Some(message) = channel.next().await { //! eprintln!("{} received message: {:?}", name, message); //! } diff --git a/src/sstream.rs b/src/sstream.rs index fee407c..357539e 100644 --- a/src/sstream.rs +++ b/src/sstream.rs @@ -17,36 +17,33 @@ let resp: SecStream> = SecStream::new_responder(&kp.private)?; // initiator sends the first handshake message, a payload can be included to send extra data to the // responder. -let (init, msg): (SecStream>, Vec) = init.write_msg(b"hello")?; +let (init, msg): (SecStream>, Vec) = init.write_msg(b"one")?; // responder receives the hs message, extracts the payload let (resp, payload): (SecStream>, Vec) = resp.read_msg(&msg)?; -assert_eq!(payload, b"hello"); +assert_eq!(payload, b"one"); // responder sends a handshake message, which can include a payload. As well as a second // message which contains the symmetric key needed to set up the decryptor let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = - resp.write_msg(b"goodbye")?; + resp.write_msg(b"two")?; // Initiator receives last handshake message, use handshake to create the extract payload. let (init, payload_recv): (SecStream>, Vec) = init.read_msg(&msg1)?; -assert_eq!(payload_recv, b"goodbye"); +assert_eq!(payload_recv, b"two"); // receive decryptor keey let (init, to_resp): (SecStream, Vec) = init.write_msg()?; +// finalize both sides let mut init: SecStream = init.read_msg(&msg2)?; let mut resp: SecStream = resp.read_msg(&to_resp)?; -let hello = b"hello".to_vec(); -let mut msg = hello.clone(); - -println!("msg {msg:?}"); +// Now both sides can send and receive messages +let mut msg = b"three".to_vec(); init.push(&mut msg, &[], Tag::Message)?; -println!("enc {msg:?}"); let tag = resp.pull(&mut msg, &[])?; -println!("res {msg:?} tag: {tag:?}"); -assert_eq!(msg, hello); +assert_eq!(msg, b"three"); Ok::<(), Box>(()) ``` From 666fa340ee1cee330c28d5851d3e25956390ebb8 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 6 Aug 2025 23:04:11 -0400 Subject: [PATCH 157/206] rustdoc --- src/sstream.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sstream.rs b/src/sstream.rs index 357539e..17cd9ab 100644 --- a/src/sstream.rs +++ b/src/sstream.rs @@ -55,6 +55,7 @@ use std::{fmt::Debug, marker::PhantomData}; use crate::{crypto::write_stream_id, Error}; +/// Default pattern pub const PARAMS: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2b"; const STREAM_ID_LENGTH: usize = 32; const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; @@ -338,6 +339,7 @@ impl SecStream> { } impl SecStream> { + /// Write the final setup message pub fn write_msg(mut self) -> Result<(SecStream, Vec), Error> { let (tx, rx) = self.split_handshake(); let key: [u8; SNOW_CIPHERKEYLEN] = tx[..SNOW_CIPHERKEYLEN] From a030dfc32583ac63c834a8d47c22ce7426455610 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 6 Aug 2025 23:04:24 -0400 Subject: [PATCH 158/206] rustfmt --- src/crypto/cipher.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index 4fb42c9..5e8d438 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -66,9 +66,10 @@ impl DecryptCipher { } pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> io::Result<(Vec, Tag)> { let mut to_decrypt = buf.to_vec(); - let tag = &self.pull_stream.pull(&mut to_decrypt, &[]).map_err(|err| { - io::Error::other(format!("Decrypt failed: {err}")) - })?; + let tag = &self + .pull_stream + .pull(&mut to_decrypt, &[]) + .map_err(|err| io::Error::other(format!("Decrypt failed: {err}")))?; Ok((to_decrypt, *tag)) } } @@ -146,9 +147,7 @@ impl EncryptCipher { let mut out = msg.to_vec(); self.push_stream .push(&mut out, &[], Tag::Message) - .map_err(|err| { - io::Error::other(format!("Encrypt failed: {err}")) - })?; + .map_err(|err| io::Error::other(format!("Encrypt failed: {err}")))?; Ok(out) } } From 793c1132582c81ad800d01def10d641ede4c6714 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 6 Aug 2025 23:07:12 -0400 Subject: [PATCH 159/206] move sstream to mod --- src/{sstream.rs => sstream/mod.rs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/{sstream.rs => sstream/mod.rs} (100%) diff --git a/src/sstream.rs b/src/sstream/mod.rs similarity index 100% rename from src/sstream.rs rename to src/sstream/mod.rs From 9d6c2cc544494f95709b8061a07e2e2e41b3bfe7 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 10 Aug 2025 12:37:34 -0400 Subject: [PATCH 160/206] add statemachine mod & format docs --- src/sstream/mod.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index 17cd9ab..e76994c 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -3,11 +3,10 @@ We use the "Typestate pattern" for the steps of the handshake. ``` // Excessive typing to demonstrate flow through typestates -# use hypercore_protocol::{sstream::{Ready, Start, SecStream, Initiator, InitiatorConfig, -# PARAMS, Responder, HsMsgSent, -# EncryptorReady, HsDone}}; -use snow::{params::NoiseParams, Builder}; -use crypto_secretstream::{Header, Tag}; +use crate::sstream::{ + EncryptorReady, HsDone, HsMsgSent, Initiator, InitiatorConfig, Ready, Responder, SecStream, + Start, PARAMS, +}; let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; let config = InitiatorConfig::new(kp.public.try_into().unwrap()); @@ -25,8 +24,7 @@ assert_eq!(payload, b"one"); // responder sends a handshake message, which can include a payload. As well as a second // message which contains the symmetric key needed to set up the decryptor -let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = - resp.write_msg(b"two")?; +let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = resp.write_msg(b"two")?; // Initiator receives last handshake message, use handshake to create the extract payload. let (init, payload_recv): (SecStream>, Vec) = init.read_msg(&msg1)?; @@ -44,14 +42,16 @@ let mut msg = b"three".to_vec(); init.push(&mut msg, &[], Tag::Message)?; let tag = resp.pull(&mut msg, &[])?; assert_eq!(msg, b"three"); - Ok::<(), Box>(()) ``` */ +mod statemachine; +mod streamsink; + use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; use rand::rngs::OsRng; use snow::{params::NoiseParams, Builder, HandshakeState}; -use std::{fmt::Debug, marker::PhantomData}; +use std::{fmt::Debug, marker::PhantomData, mem::replace}; use crate::{crypto::write_stream_id, Error}; From aab7276384f2b9e4918439714ab610d29bb9a1c0 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 10 Aug 2025 12:39:27 -0400 Subject: [PATCH 161/206] Add WIP statemachine mod --- src/sstream/statemachine.rs | 135 ++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 src/sstream/statemachine.rs diff --git a/src/sstream/statemachine.rs b/src/sstream/statemachine.rs new file mode 100644 index 0000000..cc00e08 --- /dev/null +++ b/src/sstream/statemachine.rs @@ -0,0 +1,135 @@ +/*! +I need a pattern for this where I want to match an enum on a certain state. Take data from +within the state, then replace the state with a new thing. +enum in place. + +What doesn't work: +``` +match &self.state { + State::InitiatorStart(s) => { + // NB: write_msg takes ownership of self. It's signature is: (mut self, &[u8]) + // I + let (new_s, msg) = s.write_msg(prologue)?; + let _ = replace(&mut self.state, State::InitiatorSent(new_s)); + Ok(msg) + } + _ => todo!(), +} +``` +because `s` is behind a shared reference. + +A thing that does work: +``` +// We must be in self.state == InitiatorStart, we want to go to InitiatorSent +if let State::InitiatorStart(s) = replace(&mut self.state, State::Invalid) { + let (new_s, msg) = s.write_msg(prologue)?; + let _ = replace(&mut self.state, State::InitiatorSent(new_s)); + Ok(msg) +} else { + Err(Error::InvalidHandshakeState( + "Cannot start handshake in current state".into(), + )) +} +``` +*/ +use std::mem::replace; + +use super::{ + EncryptorReady, HsDone, HsMsgSent, Initiator, InitiatorConfig, Ready, Responder, SecStream, + Start, +}; +use crate::error::Error; + +enum State { + InitiatorStart(SecStream>), + InitiatorSent(SecStream>), + InitiatorDone(SecStream>), + ResponderStart(SecStream>), + ResponderReady(SecStream), + Ready(SecStream), + Invalid, +} + +struct Manager { + state: State, + is_initiator: bool, +} +impl Manager { + pub fn read_msg(&mut self, prologue: &[u8]) -> Result>, Error> { + match replace(&mut self.state, State::Invalid) { + State::ResponderStart(s) => { + let (s2, payload) = s.read_msg(prologue)?; + // handle payload + let mm = b""; + let (s3, two_msgs) = s2.write_msg(mm)?; + let _ = replace(&mut self.state, State::ResponderReady(s3)); + Ok(two_msgs.to_vec()) + } + State::InitiatorSent(s) => { + let (s2, payload) = s.read_msg(prologue)?; + // handle payload + let mm = b""; + //let (s3, outmsg) = s2.write_msg(mm)?; + //Ok(vec![outmsg]) + todo!() + } + + _ => Err(Error::InvalidHandshakeState( + "Cannot start handshake in current state".into(), + )), + } + } + + /// Create new initiator + pub fn new_initiator(config: InitiatorConfig) -> Result { + Ok(Self { + state: State::InitiatorStart(SecStream::new_initiator(config)?), + is_initiator: true, + }) + } + + /// Create new responder + pub fn new_responder(private: &[u8]) -> Result { + Ok(Self { + state: State::ResponderStart(SecStream::new_responder(private)?), + is_initiator: false, + }) + } + /// Start handshake (initiator only) + pub fn start_handshake(&mut self, prologue: &[u8]) -> Result, Error> { + match replace(&mut self.state, State::Invalid) { + State::InitiatorStart(s) => { + let (new_s, msg) = s.write_msg(prologue)?; + let _ = replace(&mut self.state, State::InitiatorSent(new_s)); + Ok(msg) + } + _ => Err(Error::InvalidHandshakeState( + "Cannot start handshake in current state".into(), + )), + } + } + fn transition(&mut self, f: F) -> Result + where + F: FnOnce(State) -> Result<(State, R), (State, Error)>, + { + let old_state = replace(&mut self.state, State::Invalid); + match f(old_state) { + Ok((new_state, result)) => { + self.state = new_state; + Ok(result) + } + Err((restored_state, error)) => { + self.state = restored_state; + Err(error) + } + } + } +} + +#[cfg(test)] +mod test { + use super::*; + fn start_hs() -> Result<(), Error> { + todo!() + } +} From 252f6ce3c3951b4980d46cad4ce2432c205a103f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 10 Aug 2025 12:40:19 -0400 Subject: [PATCH 162/206] WIP streamsink --- src/sstream/streamsink.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 src/sstream/streamsink.rs diff --git a/src/sstream/streamsink.rs b/src/sstream/streamsink.rs new file mode 100644 index 0000000..a8fa72c --- /dev/null +++ b/src/sstream/streamsink.rs @@ -0,0 +1,16 @@ +use std::collections::VecDeque; + +use crate::error::Error; + +pub enum Event {} + +pub struct Encrypted { + io: IO, + step: (), + is_initiator: bool, + encrypted_tx: VecDeque>, + encrypted_rx: VecDeque, Error>>, + plain_tx: VecDeque>, + plain_rx: VecDeque, + flush: bool, +} From 419a72209f16b5a7695dae951edb0c1ef26bfa10 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 10 Aug 2025 12:40:35 -0400 Subject: [PATCH 163/206] Add invalid HS state error --- src/error.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/error.rs b/src/error.rs index 81764e3..c11a76f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,6 +4,8 @@ pub enum Error { Snow(#[from] snow::Error), #[error("Error from `crypto_secretstream`: {0}")] SecretStream(crypto_secretstream::aead::Error), + #[error("Invalid Handshake State: {0}")] + InvalidHandshakeState(String), } impl From for Error { From 184092c59636d70477728f4b7f8b732f17e17b24 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 10 Aug 2025 12:41:43 -0400 Subject: [PATCH 164/206] add .rustfmt.toml --- .rustfmt.toml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .rustfmt.toml diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..adbe5db --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,5 @@ +# groups 'use' statements by crate +imports_granularity = "crate" +# formats code within doc tests +# requires: cargo +nightly fmt (otherwise rustfmt will warn, but pass) +format_code_in_doc_comments = true From f587a45a32eda082488bc841fdb5d1ac777b456e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 10 Aug 2025 14:20:20 -0400 Subject: [PATCH 165/206] fix doctest --- src/sstream/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index e76994c..df5590f 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -3,10 +3,12 @@ We use the "Typestate pattern" for the steps of the handshake. ``` // Excessive typing to demonstrate flow through typestates -use crate::sstream::{ +use hypercore_protocol::sstream::{ EncryptorReady, HsDone, HsMsgSent, Initiator, InitiatorConfig, Ready, Responder, SecStream, Start, PARAMS, }; +use crypto_secretstream::Tag; +use snow::{Builder, params::NoiseParams}; let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; let config = InitiatorConfig::new(kp.public.try_into().unwrap()); From f986c396f9c1394900b517029724e3fc770d0303 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 10 Aug 2025 15:19:33 -0400 Subject: [PATCH 166/206] make typestate payload/prologue optional --- src/sstream/mod.rs | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index df5590f..7aa62b6 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -18,7 +18,7 @@ let resp: SecStream> = SecStream::new_responder(&kp.private)?; // initiator sends the first handshake message, a payload can be included to send extra data to the // responder. -let (init, msg): (SecStream>, Vec) = init.write_msg(b"one")?; +let (init, msg): (SecStream>, Vec) = init.write_msg(Some(b"one"))?; // responder receives the hs message, extracts the payload let (resp, payload): (SecStream>, Vec) = resp.read_msg(&msg)?; @@ -26,18 +26,18 @@ assert_eq!(payload, b"one"); // responder sends a handshake message, which can include a payload. As well as a second // message which contains the symmetric key needed to set up the decryptor -let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = resp.write_msg(b"two")?; +let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = resp.write_msg(Some(b"two"))?; // Initiator receives last handshake message, use handshake to create the extract payload. let (init, payload_recv): (SecStream>, Vec) = init.read_msg(&msg1)?; assert_eq!(payload_recv, b"two"); // receive decryptor keey -let (init, to_resp): (SecStream, Vec) = init.write_msg()?; +let (init, to_resp_final): (SecStream, Vec) = init.write_msg()?; // finalize both sides let mut init: SecStream = init.read_msg(&msg2)?; -let mut resp: SecStream = resp.read_msg(&to_resp)?; +let mut resp: SecStream = resp.read_msg(&to_resp_final)?; // Now both sides can send and receive messages let mut msg = b"three".to_vec(); @@ -53,7 +53,7 @@ mod streamsink; use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; use rand::rngs::OsRng; use snow::{params::NoiseParams, Builder, HandshakeState}; -use std::{fmt::Debug, marker::PhantomData, mem::replace}; +use std::{fmt::Debug, marker::PhantomData}; use crate::{crypto::write_stream_id, Error}; @@ -178,8 +178,9 @@ impl SecStream> { /// Create the first message the initiator sends to the responder pub fn write_msg( mut self, - prologue: &[u8], + prologue: Option<&[u8]>, ) -> Result<(SecStream>, Vec), Error> { + let prologue = prologue.unwrap_or_default(); let len = self.state.write_message(prologue, &mut self.msg_buf)?; let msg = self.msg_buf[..len].to_vec(); let Self { @@ -250,7 +251,7 @@ impl SecStream> { msg: &[u8], ) -> Result<(SecStream, [Vec; 2]), Error> { let (self2, _rx_payload) = self.read_msg(msg)?; - self2.write_msg(&[]) + self2.write_msg(Some(&[])) } } @@ -259,8 +260,9 @@ impl SecStream> { /// Noise handshake. The second has the shared key for the remote to set up a Decryptor. pub fn write_msg( mut self, - payload: &[u8], + payload: Option<&[u8]>, ) -> Result<(SecStream, [Vec; 2]), Error> { + let payload = payload.unwrap_or_default(); let len = self.state.write_message(payload, &mut self.msg_buf)?; let hs_msg = self.msg_buf[..len].to_vec(); assert!(self.state.is_handshake_finished()); @@ -414,6 +416,15 @@ impl SecStream { step: Ready { pusher, puller }, }) } + + pub fn push( + &mut self, + msg: &mut Vec, + associated_data: &[u8], + tag: Tag, + ) -> Result<(), Error> { + Ok(self.step.pusher.push(msg, associated_data, tag)?) + } } impl SecStream { From 7ec7a2ac311467ea9d564beefb9ef04ebfed828c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 10 Aug 2025 15:26:18 -0400 Subject: [PATCH 167/206] WIP secretstream statemachine --- src/sstream/mod.rs | 2 +- src/sstream/statemachine.rs | 124 ++++++++++++++++-------------------- 2 files changed, 56 insertions(+), 70 deletions(-) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index 7aa62b6..0e58aa1 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -343,7 +343,7 @@ impl SecStream> { } impl SecStream> { - /// Write the final setup message + /// Write the final setup message pub fn write_msg(mut self) -> Result<(SecStream, Vec), Error> { let (tx, rx) = self.split_handshake(); let key: [u8; SNOW_CIPHERKEYLEN] = tx[..SNOW_CIPHERKEYLEN] diff --git a/src/sstream/statemachine.rs b/src/sstream/statemachine.rs index cc00e08..3b0ee52 100644 --- a/src/sstream/statemachine.rs +++ b/src/sstream/statemachine.rs @@ -1,51 +1,20 @@ -/*! -I need a pattern for this where I want to match an enum on a certain state. Take data from -within the state, then replace the state with a new thing. -enum in place. - -What doesn't work: -``` -match &self.state { - State::InitiatorStart(s) => { - // NB: write_msg takes ownership of self. It's signature is: (mut self, &[u8]) - // I - let (new_s, msg) = s.write_msg(prologue)?; - let _ = replace(&mut self.state, State::InitiatorSent(new_s)); - Ok(msg) - } - _ => todo!(), -} -``` -because `s` is behind a shared reference. - -A thing that does work: -``` -// We must be in self.state == InitiatorStart, we want to go to InitiatorSent -if let State::InitiatorStart(s) = replace(&mut self.state, State::Invalid) { - let (new_s, msg) = s.write_msg(prologue)?; - let _ = replace(&mut self.state, State::InitiatorSent(new_s)); - Ok(msg) -} else { - Err(Error::InvalidHandshakeState( - "Cannot start handshake in current state".into(), - )) -} -``` -*/ +//! A state machine for the building secre stream. use std::mem::replace; +use crypto_secretstream::Tag; + use super::{ EncryptorReady, HsDone, HsMsgSent, Initiator, InitiatorConfig, Ready, Responder, SecStream, Start, }; use crate::error::Error; +#[derive(Debug)] enum State { InitiatorStart(SecStream>), - InitiatorSent(SecStream>), - InitiatorDone(SecStream>), ResponderStart(SecStream>), - ResponderReady(SecStream), + InitiatorSent(SecStream>), + EncReady(SecStream), Ready(SecStream), Invalid, } @@ -55,31 +24,64 @@ struct Manager { is_initiator: bool, } impl Manager { - pub fn read_msg(&mut self, prologue: &[u8]) -> Result>, Error> { + pub fn read_msg(&mut self, msg: &[u8]) -> Result>, Error> { match replace(&mut self.state, State::Invalid) { State::ResponderStart(s) => { - let (s2, payload) = s.read_msg(prologue)?; - // handle payload - let mm = b""; - let (s3, two_msgs) = s2.write_msg(mm)?; - let _ = replace(&mut self.state, State::ResponderReady(s3)); + let (s2, payload) = s.read_msg(msg)?; + // TODO handle payload, pass something to write_msg + let (s3, two_msgs) = s2.write_msg(None)?; + self.state = State::EncReady(s3); Ok(two_msgs.to_vec()) } State::InitiatorSent(s) => { - let (s2, payload) = s.read_msg(prologue)?; + let (s2, payload) = s.read_msg(msg)?; // handle payload - let mm = b""; - //let (s3, outmsg) = s2.write_msg(mm)?; - //Ok(vec![outmsg]) - todo!() + // TODO handle payload + let (s3, outmsg) = s2.write_msg()?; + // TODO if we want to add another msg, we can create it here with `let extra_out = s3.push(msg)` + self.state = State::EncReady(s3); // Update state + Ok(vec![outmsg]) + } + State::EncReady(s) => { + let s2 = s.read_msg(msg)?; + self.state = State::Ready(s2); // Update state + Ok(vec![]) + } + s => { + let msg = format!("Cannot start handshake in state: {s:?}"); + self.state = s; + Err(Error::InvalidHandshakeState(msg)) } - - _ => Err(Error::InvalidHandshakeState( - "Cannot start handshake in current state".into(), - )), } } + /// Encrypt a message in place + pub fn push( + &mut self, + msg: &mut Vec, + associated_data: &[u8], + tag: Tag, + ) -> Result<(), Error> { + Ok(match &mut self.state { + State::Ready(s) => s.push(msg, associated_data, tag)?, + s => { + return Err(Error::InvalidHandshakeState(format!( + "Cannot Encrypt a message while in state: {s:?}" + ))); + } + }) + } + /// Decrypt a message in place + pub fn pull(&mut self, msg: &mut Vec, associated_data: &[u8]) -> Result { + Ok(match &mut self.state { + State::Ready(s) => s.pull(msg, associated_data)?, + s => { + return Err(Error::InvalidHandshakeState(format!( + "Cannot Decrypet a message while in state: {s:?}" + ))); + } + }) + } /// Create new initiator pub fn new_initiator(config: InitiatorConfig) -> Result { Ok(Self { @@ -99,7 +101,7 @@ impl Manager { pub fn start_handshake(&mut self, prologue: &[u8]) -> Result, Error> { match replace(&mut self.state, State::Invalid) { State::InitiatorStart(s) => { - let (new_s, msg) = s.write_msg(prologue)?; + let (new_s, msg) = s.write_msg(Some(prologue))?; let _ = replace(&mut self.state, State::InitiatorSent(new_s)); Ok(msg) } @@ -108,22 +110,6 @@ impl Manager { )), } } - fn transition(&mut self, f: F) -> Result - where - F: FnOnce(State) -> Result<(State, R), (State, Error)>, - { - let old_state = replace(&mut self.state, State::Invalid); - match f(old_state) { - Ok((new_state, result)) => { - self.state = new_state; - Ok(result) - } - Err((restored_state, error)) => { - self.state = restored_state; - Err(error) - } - } - } } #[cfg(test)] From 40bda39f86850d56cee51f5dae556b460ad84da1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 10 Aug 2025 16:20:08 -0400 Subject: [PATCH 168/206] add prologue --- src/sstream/mod.rs | 11 ++++++----- src/sstream/statemachine.rs | 12 ++---------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index 0e58aa1..5b23bbb 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -13,7 +13,7 @@ let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; let config = InitiatorConfig::new(kp.public.try_into().unwrap()); // Create an initiator and responder -let init: SecStream> = SecStream::new_initiator(config)?; +let init: SecStream> = SecStream::new_initiator(config, &[])?; let resp: SecStream> = SecStream::new_responder(&kp.private)?; // initiator sends the first handshake message, a payload can be included to send extra data to the @@ -157,11 +157,12 @@ impl Debug for Ready { impl SecStream> { /// Create an initiator of a secret stream - pub fn new_initiator(config: InitiatorConfig) -> Result { + pub fn new_initiator(config: InitiatorConfig, prologue: &[u8]) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; let state = Builder::new(params.clone()) .local_private_key(&kp.private)? + .prologue(prologue)? .remote_public_key(&config.remote_public_key)? .build_initiator()?; @@ -178,10 +179,10 @@ impl SecStream> { /// Create the first message the initiator sends to the responder pub fn write_msg( mut self, - prologue: Option<&[u8]>, + payload: Option<&[u8]>, ) -> Result<(SecStream>, Vec), Error> { - let prologue = prologue.unwrap_or_default(); - let len = self.state.write_message(prologue, &mut self.msg_buf)?; + let payload = payload.unwrap_or_default(); + let len = self.state.write_message(payload, &mut self.msg_buf)?; let msg = self.msg_buf[..len].to_vec(); let Self { is_initiator, diff --git a/src/sstream/statemachine.rs b/src/sstream/statemachine.rs index 3b0ee52..f08ab6d 100644 --- a/src/sstream/statemachine.rs +++ b/src/sstream/statemachine.rs @@ -83,9 +83,9 @@ impl Manager { }) } /// Create new initiator - pub fn new_initiator(config: InitiatorConfig) -> Result { + pub fn new_initiator(config: InitiatorConfig, prologue: &[u8]) -> Result { Ok(Self { - state: State::InitiatorStart(SecStream::new_initiator(config)?), + state: State::InitiatorStart(SecStream::new_initiator(config, prologue)?), is_initiator: true, }) } @@ -111,11 +111,3 @@ impl Manager { } } } - -#[cfg(test)] -mod test { - use super::*; - fn start_hs() -> Result<(), Error> { - todo!() - } -} From ccc2f35aa16dd45d1637fae55b3cb040651a4018 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 12 Aug 2025 23:24:07 -0400 Subject: [PATCH 169/206] bump edition to 2024 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 5695165..ad88ad8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ authors = [ documentation = "https://docs.rs/hypercore-protocol" repository = "https://github.com/datrs/hypercore-protocol-rs" readme = "README.md" -edition = "2021" +edition = "2024" keywords = ["dat", "p2p", "replication", "hypercore", "protocol"] categories = [ "asynchronous", From 63443263553e002653124545be591c622634fcba Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 12 Aug 2025 23:24:29 -0400 Subject: [PATCH 170/206] rm outdated forbids --- src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 1177d3e..e22713a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -119,7 +119,7 @@ //! [AsyncWrite]: futures_lite::AsyncWrite //! [examples]: https://github.com/datrs/hypercore-protocol-rs#examples -#![forbid(unsafe_code, future_incompatible, rust_2018_idioms)] +#![forbid(unsafe_code)] #![deny(missing_debug_implementations, nonstandard_style)] #![warn(missing_docs, unreachable_pub)] From 996bb33280d82b4bd5028f59801455306bf588dc Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 12 Aug 2025 23:25:08 -0400 Subject: [PATCH 171/206] intial commit of sm2 --- src/sstream/sm2.rs | 134 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 src/sstream/sm2.rs diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs new file mode 100644 index 0000000..c09ebf2 --- /dev/null +++ b/src/sstream/sm2.rs @@ -0,0 +1,134 @@ +use std::{collections::VecDeque, mem::replace}; + +use crypto_secretstream::Tag; +use futures::{Sink, Stream}; + +use crate::{ + error::Error, + sstream::{EncryptorReady, HsMsgSent, Initiator, Ready, Responder, SecStream, Start}, +}; + +#[derive(Debug)] +pub(crate) enum State { + InitiatorStart(SecStream>), + InitiatorSent(SecStream>), + RespStart(SecStream>), + EncReady(SecStream), + Ready(SecStream), + Invalid, +} + +#[derive(Debug)] +enum Event { + HandshakePayload(Vec), + Message(Vec), +} + +#[derive(Debug)] +/// For each tx/rx VecDeque messages go in with `.push_back` then taken out with `.pop_front`. +/// If a message should skip the line it should be inserted with `.push_front`. +pub(crate) struct Machine { + io: IO, + is_initiator: bool, + state: State, + encrypted_tx: VecDeque>, + encrypted_rx: VecDeque, Error>>, + plain_tx: VecDeque>, + plain_rx: VecDeque, +} + +impl Machine +where + IO: Stream, Error>> + Sink> + Send + Unpin + 'static, +{ + fn new_init(io: IO, state: SecStream>) -> Self { + Self { + io, + state: State::InitiatorStart(state), + is_initiator: true, + encrypted_tx: Default::default(), + encrypted_rx: Default::default(), + plain_tx: Default::default(), + plain_rx: Default::default(), + } + } + fn new_resp(io: IO, state: SecStream>) -> Self { + Self { + io, + state: State::RespStart(state), + is_initiator: false, + encrypted_tx: Default::default(), + encrypted_rx: Default::default(), + plain_tx: Default::default(), + plain_rx: Default::default(), + } + } + fn handshake_start(&mut self, payload: &[u8]) -> Result<(), Error> { + match replace(&mut self.state, State::Invalid) { + State::InitiatorStart(s) => { + let (s2, out) = s.write_msg(Some(payload))?; + self.encrypted_tx.push_back(out); + self.state = State::InitiatorSent(s2); + Ok(()) + } + e => todo!(), + } + } + + fn poll_tx_rx(&mut self) -> Result, Error> { + match replace(&mut self.state, State::Invalid) { + State::InitiatorSent(s) => { + let Some(msg) = self.encrypted_rx.pop_front() else { + return Ok(None); + }; + let (s2, payload) = s.read_msg(&msg?)?; + // Ensure payload jumps to the front of the line + self.plain_rx.push_front(Event::HandshakePayload(payload)); + // Ensure payload jumps to the front of the line + let (s3, out) = s2.write_msg()?; + self.encrypted_tx.push_front(out); + self.state = State::EncReady(s3); + Ok(Some(())) + } + State::RespStart(s) => { + let Some(msg) = self.encrypted_rx.pop_front() else { + return Ok(None); + }; + let (s2, payload) = s.read_msg(&msg?)?; + // Ensure payload jumps to the front of the line + self.plain_rx.push_front(Event::HandshakePayload(payload)); + let next_tx = self.plain_tx.pop_front(); + let (s3, [msg1, msg2]) = s2.write_msg(next_tx.as_deref())?; + self.plain_tx.push_front(msg2); + self.plain_tx.push_front(msg1); + self.state = State::EncReady(s3); + Ok(Some(())) + } + State::EncReady(s) => { + let Some(msg) = self.encrypted_rx.pop_front() else { + return Ok(None); + }; + self.state = State::Ready(s.read_msg(&msg?)?); + Ok(Some(())) + } + State::Ready(mut s) => { + match self.encrypted_rx.pop_front() { + Some(Ok(mut m)) => { + s.pull(&mut m, &[])?; + self.plain_rx.push_back(Event::Message(m)); + } + Some(Err(_e)) => todo!(), + None => todo!(), + } + if let Some(mut m) = self.plain_tx.pop_front() { + s.push(&mut m, &[], Tag::Message)?; + self.encrypted_tx.push_back(m); + } + + todo!() + } + State::InitiatorStart(_sec_stream) => todo!(), + State::Invalid => todo!(), + } + } +} From c7fdf46c2a6965115fcd1e2f77c9570a25b63c77 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 12 Aug 2025 23:25:34 -0400 Subject: [PATCH 172/206] rm initiator config --- src/sstream/mod.rs | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index 5b23bbb..3c229b0 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -4,16 +4,15 @@ ``` // Excessive typing to demonstrate flow through typestates use hypercore_protocol::sstream::{ - EncryptorReady, HsDone, HsMsgSent, Initiator, InitiatorConfig, Ready, Responder, SecStream, + EncryptorReady, HsDone, HsMsgSent, Initiator, Ready, Responder, SecStream, Start, PARAMS, }; use crypto_secretstream::Tag; use snow::{Builder, params::NoiseParams}; let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; -let config = InitiatorConfig::new(kp.public.try_into().unwrap()); // Create an initiator and responder -let init: SecStream> = SecStream::new_initiator(config, &[])?; +let init: SecStream> = SecStream::new_initiator(&kp.public.try_into().unwrap(), &[])?; let resp: SecStream> = SecStream::new_responder(&kp.private)?; // initiator sends the first handshake message, a payload can be included to send extra data to the @@ -64,19 +63,6 @@ const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; const SNOW_CIPHERKEYLEN: usize = 32; const PUBLIC_KEYLEN: usize = 32; -/// Data for creating an initiator -#[derive(Debug)] -pub struct InitiatorConfig { - remote_public_key: [u8; PUBLIC_KEYLEN], -} - -impl InitiatorConfig { - /// Create a new [`InitiatorConfig`] - pub fn new(remote_public_key: [u8; PUBLIC_KEYLEN]) -> Self { - Self { remote_public_key } - } -} - /// Secret Stream protocol state #[derive(Debug)] pub struct SecStream { @@ -157,13 +143,16 @@ impl Debug for Ready { impl SecStream> { /// Create an initiator of a secret stream - pub fn new_initiator(config: InitiatorConfig, prologue: &[u8]) -> Result { + pub fn new_initiator( + remote_public_key: &[u8; PUBLIC_KEYLEN], + prologue: &[u8], + ) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); let kp = Builder::new(params.clone()).generate_keypair()?; let state = Builder::new(params.clone()) .local_private_key(&kp.private)? .prologue(prologue)? - .remote_public_key(&config.remote_public_key)? + .remote_public_key(remote_public_key.as_slice())? .build_initiator()?; Ok(Self { From 9f9fb838b70f300084eea39f854e2ed3e9251a2c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 12 Aug 2025 23:25:45 -0400 Subject: [PATCH 173/206] add sm2 module --- src/sstream/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index 3c229b0..40f56a7 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -46,6 +46,7 @@ assert_eq!(msg, b"three"); Ok::<(), Box>(()) ``` */ +mod sm2; mod statemachine; mod streamsink; From 3c5ec7fe5e6b75d9725996dc906c2f8de9e043aa Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 12 Aug 2025 23:26:32 -0400 Subject: [PATCH 174/206] statemachine wip --- src/sstream/statemachine.rs | 137 +++++++++++++++++++++++++++++------- 1 file changed, 110 insertions(+), 27 deletions(-) diff --git a/src/sstream/statemachine.rs b/src/sstream/statemachine.rs index f08ab6d..9d4c370 100644 --- a/src/sstream/statemachine.rs +++ b/src/sstream/statemachine.rs @@ -1,19 +1,19 @@ +#![allow(unused)] //! A state machine for the building secre stream. use std::mem::replace; use crypto_secretstream::Tag; -use super::{ - EncryptorReady, HsDone, HsMsgSent, Initiator, InitiatorConfig, Ready, Responder, SecStream, - Start, -}; -use crate::error::Error; +use super::{EncryptorReady, HsDone, HsMsgSent, Initiator, Ready, Responder, SecStream, Start}; +use crate::{error::Error, sstream::PUBLIC_KEYLEN}; #[derive(Debug)] enum State { InitiatorStart(SecStream>), - ResponderStart(SecStream>), InitiatorSent(SecStream>), + InitiatorHsDone(SecStream>), + ResponderStart(SecStream>), + ResponderHsDone(SecStream>), EncReady(SecStream), Ready(SecStream), Invalid, @@ -24,28 +24,63 @@ struct Manager { is_initiator: bool, } impl Manager { - pub fn read_msg(&mut self, msg: &[u8]) -> Result>, Error> { + pub(crate) fn read_msg2(&mut self, msg: &[u8]) -> Result, Error> { match replace(&mut self.state, State::Invalid) { State::ResponderStart(s) => { let (s2, payload) = s.read_msg(msg)?; + self.state = State::ResponderHsDone(s2); + Ok(payload) + } + State::InitiatorSent(s) => { + let (s2, payload) = s.read_msg(msg)?; + self.state = State::InitiatorHsDone(s2); + Ok(payload) + } + State::EncReady(s) => { + let _s2 = s.read_msg(msg)?; + todo!() + } + s => { + let msg = format!("Cannot start handshake in state: {s:?}"); + self.state = s; + Err(Error::InvalidHandshakeState(msg)) + } + } + } + // takes the received encrypted message, and the next payload that is to be encrypted and sent. + // returns the decrypted payload from the received encrypted message, and Vec of messages to + // be sent. + pub(crate) fn read_msg( + &mut self, + ciphertxt_msg_rx: &[u8], + plaintext_msg_tx: Option<&[u8]>, + ) -> Result<(Vec, Vec>), Error> { + match replace(&mut self.state, State::Invalid) { + State::ResponderStart(s) => { + let (s2, payload) = s.read_msg(ciphertxt_msg_rx)?; // TODO handle payload, pass something to write_msg - let (s3, two_msgs) = s2.write_msg(None)?; + let (s3, two_msgs) = s2.write_msg(plaintext_msg_tx)?; self.state = State::EncReady(s3); - Ok(two_msgs.to_vec()) + Ok((payload, two_msgs.to_vec())) } State::InitiatorSent(s) => { - let (s2, payload) = s.read_msg(msg)?; + let (s2, payload) = s.read_msg(ciphertxt_msg_rx)?; // handle payload - // TODO handle payload - let (s3, outmsg) = s2.write_msg()?; - // TODO if we want to add another msg, we can create it here with `let extra_out = s3.push(msg)` + let (mut s3, msg1) = s2.write_msg()?; + let out = if let Some(payload) = plaintext_msg_tx { + let mut msg2 = payload.to_vec(); + s3.push(&mut msg2, &[], Tag::Message)?; + vec![msg1, msg2] + } else { + vec![msg1] + }; self.state = State::EncReady(s3); // Update state - Ok(vec![outmsg]) + Ok((payload, out)) } State::EncReady(s) => { - let s2 = s.read_msg(msg)?; + let s2 = s.read_msg(ciphertxt_msg_rx)?; self.state = State::Ready(s2); // Update state - Ok(vec![]) + Ok((vec![], vec![])) } s => { let msg = format!("Cannot start handshake in state: {s:?}"); @@ -54,9 +89,12 @@ impl Manager { } } } + fn ready(&self) -> bool { + matches!(self.state, State::Ready(_)) + } /// Encrypt a message in place - pub fn push( + pub(crate) fn push( &mut self, msg: &mut Vec, associated_data: &[u8], @@ -72,7 +110,7 @@ impl Manager { }) } /// Decrypt a message in place - pub fn pull(&mut self, msg: &mut Vec, associated_data: &[u8]) -> Result { + pub(crate) fn pull(&mut self, msg: &mut Vec, associated_data: &[u8]) -> Result { Ok(match &mut self.state { State::Ready(s) => s.pull(msg, associated_data)?, s => { @@ -82,32 +120,77 @@ impl Manager { } }) } + /// Create new initiator - pub fn new_initiator(config: InitiatorConfig, prologue: &[u8]) -> Result { + pub(crate) fn new_initiator(state: SecStream>) -> Result { Ok(Self { - state: State::InitiatorStart(SecStream::new_initiator(config, prologue)?), + state: State::InitiatorStart(state), is_initiator: true, }) } /// Create new responder - pub fn new_responder(private: &[u8]) -> Result { + pub(crate) fn new_responder(state: SecStream>) -> Result { Ok(Self { - state: State::ResponderStart(SecStream::new_responder(private)?), + state: State::ResponderStart(state), is_initiator: false, }) } + /// Start handshake (initiator only) - pub fn start_handshake(&mut self, prologue: &[u8]) -> Result, Error> { + pub(crate) fn start_handshake(&mut self, payload: &[u8]) -> Result, Error> { match replace(&mut self.state, State::Invalid) { State::InitiatorStart(s) => { - let (new_s, msg) = s.write_msg(Some(prologue))?; + let (new_s, msg) = s.write_msg(Some(payload))?; let _ = replace(&mut self.state, State::InitiatorSent(new_s)); Ok(msg) } - _ => Err(Error::InvalidHandshakeState( - "Cannot start handshake in current state".into(), - )), + s => Err(Error::InvalidHandshakeState(format!( + "Cannot start handshake in current state {s:?}" + ))), } } } + +#[cfg(test)] +mod test { + use snow::Builder; + + use crate::sstream::PARAMS; + + use super::*; + + fn new_paired() -> (SecStream>, SecStream>) { + let kp = Builder::new(PARAMS.parse().unwrap()) + .generate_keypair() + .unwrap(); + ( + SecStream::new_initiator(&kp.public.try_into().unwrap(), &[]).unwrap(), + SecStream::new_responder(&kp.private).unwrap(), + ) + } + #[test] + fn test_read1() -> Result<(), Error> { + let (ini, res) = new_paired(); + let mut ini = Manager::new_initiator(ini)?; + let mut res = Manager::new_responder(res)?; + let msg = ini.start_handshake(&[])?; + let (payload, msgs) = res.read_msg(&msg, None)?; + assert!(payload.is_empty()); + let [one, two] = msgs.try_into().unwrap(); + let (payload, mut to_resp_msgs1) = ini.read_msg(&one, None)?; + assert!(to_resp_msgs1.len() == 1); + assert!(payload.is_empty()); + let (payload, to_resp_msgs2) = ini.read_msg(&two, None)?; + assert!(to_resp_msgs2.is_empty()); + assert!(payload.is_empty()); + + let (payload, msgs) = res.read_msg(&to_resp_msgs1.remove(0), None)?; + assert!(msgs.is_empty()); + assert!(payload.is_empty()); + assert!(res.ready()); + assert!(ini.ready()); + + Ok(()) + } +} From b0261b0371189b0caca5dfe24af55fbb58ab7057 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 12 Aug 2025 23:26:42 -0400 Subject: [PATCH 175/206] wip streamsink --- src/sstream/streamsink.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/sstream/streamsink.rs b/src/sstream/streamsink.rs index a8fa72c..776b8ee 100644 --- a/src/sstream/streamsink.rs +++ b/src/sstream/streamsink.rs @@ -1,10 +1,11 @@ +#![allow(unused)] use std::collections::VecDeque; use crate::error::Error; -pub enum Event {} +pub(crate) enum Event {} -pub struct Encrypted { +pub(crate) struct Encrypted { io: IO, step: (), is_initiator: bool, From 0aa80b33baaef2ca48d38c9993bf5c24c395fa45 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 12 Aug 2025 23:48:28 -0400 Subject: [PATCH 176/206] implment poll rx/tx --- src/error.rs | 6 ++++ src/sstream/sm2.rs | 81 +++++++++++++++++++++++++++++++++++++--------- 2 files changed, 72 insertions(+), 15 deletions(-) diff --git a/src/error.rs b/src/error.rs index c11a76f..5c35e58 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,6 +6,12 @@ pub enum Error { SecretStream(crypto_secretstream::aead::Error), #[error("Invalid Handshake State: {0}")] InvalidHandshakeState(String), + // TODO added by claude + #[error("Invalid Encryption Statemachine State: {0}")] + InvalidState(String), + // TODO added by claude + #[error("IoError: {0}")] + IoError(String), } impl From for Error { diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index c09ebf2..49397d4 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -1,7 +1,12 @@ -use std::{collections::VecDeque, mem::replace}; +use std::{ + collections::VecDeque, + mem::replace, + pin::Pin, + task::{Context, Poll}, +}; use crypto_secretstream::Tag; -use futures::{Sink, Stream}; +use futures::{Sink, Stream, StreamExt}; use crate::{ error::Error, @@ -40,6 +45,7 @@ pub(crate) struct Machine { impl Machine where IO: Stream, Error>> + Sink> + Send + Unpin + 'static, + >>::Error: std::fmt::Debug, { fn new_init(io: IO, state: SecStream>) -> Self { Self { @@ -84,7 +90,6 @@ where let (s2, payload) = s.read_msg(&msg?)?; // Ensure payload jumps to the front of the line self.plain_rx.push_front(Event::HandshakePayload(payload)); - // Ensure payload jumps to the front of the line let (s3, out) = s2.write_msg()?; self.encrypted_tx.push_front(out); self.state = State::EncReady(s3); @@ -112,23 +117,69 @@ where Ok(Some(())) } State::Ready(mut s) => { - match self.encrypted_rx.pop_front() { - Some(Ok(mut m)) => { - s.pull(&mut m, &[])?; - self.plain_rx.push_back(Event::Message(m)); + let mut made_progress = false; + + if let Some(encrypted_result) = self.encrypted_rx.pop_front() { + match encrypted_result { + Ok(mut encrypted_msg) => { + let _tag = s.pull(&mut encrypted_msg, &[])?; + self.plain_rx.push_back(Event::Message(encrypted_msg)); + made_progress = true; + } + Err(e) => return Err(e), } - Some(Err(_e)) => todo!(), - None => todo!(), } - if let Some(mut m) = self.plain_tx.pop_front() { - s.push(&mut m, &[], Tag::Message)?; - self.encrypted_tx.push_back(m); + + // send outgoing messages + if let Some(mut plain_msg) = self.plain_tx.pop_front() { + s.push(&mut plain_msg, &[], Tag::Message)?; + self.encrypted_tx.push_back(plain_msg); + made_progress = true; } - todo!() + self.state = State::Ready(s); + Ok(if made_progress { Some(()) } else { None }) + } + State::InitiatorStart(_sec_stream) => { + // not started yet... Error? + Ok(None) + } + State::Invalid => Err(Error::InvalidState("Invalid state".into())), + } + } + + /// pull in new incomming encrypted messages + fn poll_incoming_encrypted(&mut self, cx: &mut Context<'_>) -> Poll<()> { + while let Poll::Ready(Some(result)) = Pin::new(&mut self.io).poll_next(cx) { + self.encrypted_rx.push_back(result); + } + Poll::Ready(()) + } + + fn poll_outgoing_encrypted(&mut self, cx: &mut Context<'_>) -> Poll> { + while let Some(msg) = self.encrypted_tx.pop_front() { + match Pin::new(&mut self.io).poll_ready(cx) { + Poll::Ready(Ok(())) => { + if let Err(e) = Pin::new(&mut self.io).start_send(msg) { + return Poll::Ready(Err(Error::IoError(format!("Send failed: {:?}", e)))); + } + } + Poll::Ready(Err(e)) => { + return Poll::Ready(Err(Error::IoError(format!("IO error: {:?}", e)))); + } + Poll::Pending => { + self.encrypted_tx.push_front(msg); + return Poll::Pending; + } + } + } + + match Pin::new(&mut self.io).poll_flush(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(e)) => { + Poll::Ready(Err(Error::IoError(format!("Flush failed: {:?}", e)))) } - State::InitiatorStart(_sec_stream) => todo!(), - State::Invalid => todo!(), + Poll::Pending => Poll::Pending, } } } From bb9f605a87207851909959334a82c13995c30462 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 12 Aug 2025 23:55:24 -0400 Subject: [PATCH 177/206] impl Stream --- src/sstream/sm2.rs | 62 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 49397d4..1a63a07 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -6,7 +6,7 @@ use std::{ }; use crypto_secretstream::Tag; -use futures::{Sink, Stream, StreamExt}; +use futures::{Sink, Stream}; use crate::{ error::Error, @@ -77,7 +77,7 @@ where self.state = State::InitiatorSent(s2); Ok(()) } - e => todo!(), + _e => todo!(), } } @@ -183,3 +183,61 @@ where } } } + +impl Stream for Machine +where + IO: Stream, Error>> + Sink> + Send + Unpin + 'static, + >>::Error: std::fmt::Debug, +{ + type Item = Result, Error>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + // 1. First, try to return any ready plaintext messages + while let Some(event) = self.plain_rx.pop_front() { + match event { + Event::Message(data) => { + return Poll::Ready(Some(Ok(data))); + } + Event::HandshakePayload(_) => { + // Skip handshake payloads - they're for internal use only + continue; + } + } + } + + // 2. Pull new encrypted data from IO into our queue + let _ = self.poll_incoming_encrypted(cx); + + // 3. Send any pending encrypted data to IO + match self.poll_outgoing_encrypted(cx) { + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Pending => { + // IO is busy, we can't make progress on sending + // but we can still try to process incoming messages + } + Poll::Ready(Ok(())) => { + // Successfully sent outgoing data + } + } + + // 4. Process crypto operations (handshake, encrypt/decrypt) + match self.poll_tx_rx() { + Ok(Some(())) => { + // Made progress, loop again to check for more work + continue; + } + Ok(None) => { + // No progress made, no more work available + break; + } + Err(e) => { + return Poll::Ready(Some(Err(e))); + } + } + } + + // No messages ready and no progress can be made + Poll::Pending + } +} From 63d51d0b4dc9e2e80766664b9a5131a3fcda39b1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 12 Aug 2025 23:59:19 -0400 Subject: [PATCH 178/206] impl Sink --- src/sstream/sm2.rs | 85 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 1a63a07..2dfd08c 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -241,3 +241,88 @@ where Poll::Pending } } + +impl Sink> for Machine +where + IO: Stream, Error>> + Sink> + Send + Unpin + 'static, + >>::Error: std::fmt::Debug, +{ + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Process any pending work to make space in queues + self.poll_incoming_encrypted(cx); + + match self.poll_outgoing_encrypted(cx) { + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + // IO is busy, but we can still accept messages for queuing + return Poll::Ready(Ok(())); + } + Poll::Ready(Ok(())) => { + // IO is ready + } + } + + // Process crypto operations to make progress + match self.poll_tx_rx() { + Ok(_) => { + // Always ready to accept more plaintext messages for queuing + Poll::Ready(Ok(())) + } + Err(e) => Poll::Ready(Err(e)), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + // Queue the plaintext message for encryption + self.plain_tx.push_back(item); + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + // Process crypto operations to encrypt any pending plaintext + match self.poll_tx_rx() { + Ok(Some(())) => { + // Made progress, continue processing + continue; + } + Ok(None) => { + // No more crypto work to do + break; + } + Err(e) => return Poll::Ready(Err(e)), + } + } + + // Send any pending encrypted data to IO + match self.poll_outgoing_encrypted(cx) { + Poll::Ready(Ok(())) => { + // Check if we have any pending plaintext that hasn't been encrypted yet + if self.plain_tx.is_empty() { + Poll::Ready(Ok(())) + } else { + // Still have pending plaintext, not fully flushed + Poll::Pending + } + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // First flush any pending data + match self.as_mut().poll_flush(cx) { + Poll::Ready(Ok(())) => { + // Now close the underlying IO + Pin::new(&mut self.io) + .poll_close(cx) + .map_err(|e| Error::IoError(format!("Close failed: {:?}", e))) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} From 880b9fc0119106d3d831a6ef07209fbae081766e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 13 Aug 2025 00:09:02 -0400 Subject: [PATCH 179/206] wip tests --- src/sstream/sm2.rs | 197 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 2dfd08c..fb9e2c9 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -326,3 +326,200 @@ where } } } + +#[cfg(test)] +mod tests { + use super::*; + use futures::{SinkExt, StreamExt, channel::mpsc}; + + // Mock IO that implements Stream + Sink for testing + #[derive(Debug)] + struct MockIo { + receiver: mpsc::UnboundedReceiver, Error>>, + sender: mpsc::UnboundedSender>, + } + + impl Stream for MockIo { + type Item = Result, Error>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.receiver).poll_next(cx) + } + } + + impl Sink> for MockIo { + type Error = Error; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { + self.sender + .unbounded_send(item) + .map_err(|_| Error::InvalidState("Send failed".into())) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + fn create_mock_io_pair() -> ( + MockIo, + mpsc::UnboundedSender, Error>>, + mpsc::UnboundedReceiver>, + ) { + let (io_tx, io_rx) = mpsc::unbounded(); + let (out_tx, out_rx) = mpsc::unbounded(); + + let mock_io = MockIo { + receiver: io_rx, + sender: out_tx, + }; + + (mock_io, io_tx, out_rx) + } + + #[tokio::test] + async fn test_machine_sink_queues_messages() -> Result<(), Error> { + let remote_key = [1u8; 32]; + let initiator_state = SecStream::new_initiator(&remote_key, &[])?; + + let (mock_io, _io_tx, _out_rx) = create_mock_io_pair(); + let mut machine = Machine::new_init(mock_io, initiator_state); + + // Test that we can send messages to the sink + let test_msg = b"Hello, World!".to_vec(); + machine.send(test_msg.clone()).await?; + + // Message should be queued in plain_tx + assert_eq!(machine.plain_tx.len(), 1); + assert_eq!(machine.plain_tx.front(), Some(&test_msg)); + + Ok(()) + } + + #[tokio::test] + async fn test_machine_sink_multiple_messages() -> Result<(), Error> { + let remote_key = [2u8; 32]; + let initiator_state = SecStream::new_initiator(&remote_key, &[])?; + + let (mock_io, _io_tx, _out_rx) = create_mock_io_pair(); + let mut machine = Machine::new_init(mock_io, initiator_state); + + // Send multiple messages + for i in 0..5 { + let msg = format!("Message {}", i).into_bytes(); + machine.send(msg).await?; + } + + // All messages should be queued + assert_eq!(machine.plain_tx.len(), 5); + + Ok(()) + } + + #[tokio::test] + async fn test_machine_stream_returns_pending_when_no_data() -> Result<(), Error> { + let remote_key = [3u8; 32]; + let initiator_state = SecStream::new_initiator(&remote_key, &[])?; + + let (mock_io, _io_tx, _out_rx) = create_mock_io_pair(); + let mut machine = Machine::new_init(mock_io, initiator_state); + + // Test that stream returns None when no data is available + let mut stream = Box::pin(&mut machine); + + // Use a timeout to ensure we don't wait forever + let result = + tokio::time::timeout(std::time::Duration::from_millis(100), stream.next()).await; + + // Should timeout because no data is available + assert!(result.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_machine_handshake_start() -> Result<(), Error> { + let remote_key = [4u8; 32]; + let initiator_state = SecStream::new_initiator(&remote_key, &[])?; + + let (mock_io, _io_tx, mut out_rx) = create_mock_io_pair(); + let mut machine = Machine::new_init(mock_io, initiator_state); + + // Start handshake + let payload = b"handshake payload"; + machine.handshake_start(payload)?; + + // Should have transitioned to InitiatorSent state + assert!(matches!(machine.state, State::InitiatorSent(_))); + + // Should have queued encrypted handshake message + assert!(!machine.encrypted_tx.is_empty()); + + // Process outgoing to send the handshake message + let _result = futures::poll!(machine.poll_outgoing_encrypted( + &mut std::task::Context::from_waker(futures::task::noop_waker_ref()) + )); + + // Should have sent handshake message to IO + let sent_msg = out_rx.try_next().unwrap(); + assert!(sent_msg.is_some()); + + Ok(()) + } + + #[tokio::test] + async fn test_machine_ready_state_processing() -> Result<(), Error> { + // This test would require more complex setup to reach Ready state + // For now, test that we can create a machine in different states + + let remote_key = [5u8; 32]; + let initiator_state = SecStream::new_initiator(&remote_key, &[])?; + + let (mock_io, _io_tx, _out_rx) = create_mock_io_pair(); + let machine = Machine::new_init(mock_io, initiator_state); + + // Verify initial state + assert!(matches!(machine.state, State::InitiatorStart(_))); + assert!(machine.is_initiator); + assert!(machine.plain_tx.is_empty()); + assert!(machine.plain_rx.is_empty()); + + Ok(()) + } + + #[tokio::test] + async fn test_machine_poll_ready_always_succeeds() -> Result<(), Error> { + let remote_key = [6u8; 32]; + let initiator_state = SecStream::new_initiator(&remote_key, &[])?; + + let (mock_io, _io_tx, _out_rx) = create_mock_io_pair(); + let mut machine = Machine::new_init(mock_io, initiator_state); + + // poll_ready should always succeed since we queue internally + let mut sink = Box::pin(&mut machine); + let ready_result = futures::poll!(sink.as_mut().poll_ready( + &mut std::task::Context::from_waker(futures::task::noop_waker_ref()) + )); + + assert!(matches!(ready_result, Poll::Ready(Ok(())))); + + Ok(()) + } +} From c4818896a5b00b44e57244c88e42d22e792f2fff Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Wed, 13 Aug 2025 00:12:59 -0400 Subject: [PATCH 180/206] fix test errors --- src/sstream/sm2.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index fb9e2c9..09ce1e6 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -473,9 +473,9 @@ mod tests { assert!(!machine.encrypted_tx.is_empty()); // Process outgoing to send the handshake message - let _result = futures::poll!(machine.poll_outgoing_encrypted( - &mut std::task::Context::from_waker(futures::task::noop_waker_ref()) - )); + let waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&waker); + let _result = machine.poll_outgoing_encrypted(&mut cx); // Should have sent handshake message to IO let sent_msg = out_rx.try_next().unwrap(); @@ -514,9 +514,9 @@ mod tests { // poll_ready should always succeed since we queue internally let mut sink = Box::pin(&mut machine); - let ready_result = futures::poll!(sink.as_mut().poll_ready( - &mut std::task::Context::from_waker(futures::task::noop_waker_ref()) - )); + let waker = futures::task::noop_waker(); + let mut cx = std::task::Context::from_waker(&waker); + let ready_result = sink.as_mut().poll_ready(&mut cx); assert!(matches!(ready_result, Poll::Ready(Ok(())))); From 321c4be6fbc36c55359704ef2176dbbbf9165eea Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 17 Aug 2025 01:11:45 -0400 Subject: [PATCH 181/206] Error add PartialEq --- src/error.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/error.rs b/src/error.rs index 5c35e58..a48fbb5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -#[derive(Debug, thiserror::Error)] +#[derive(Debug, thiserror::Error, PartialEq)] pub enum Error { #[error("Error from `snow`: {0}")] Snow(#[from] snow::Error), From 824d29d41f9dece3b8c11c96954e39c2c1c50104 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 17 Aug 2025 01:12:46 -0400 Subject: [PATCH 182/206] wip Machine Stream/Sink --- src/sstream/mod.rs | 28 +++++-- src/sstream/sm2.rs | 204 ++++++++++++++++++++++++++++++++++++--------- 2 files changed, 185 insertions(+), 47 deletions(-) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index 40f56a7..b669d65 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -52,10 +52,10 @@ mod streamsink; use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; use rand::rngs::OsRng; -use snow::{params::NoiseParams, Builder, HandshakeState}; +use snow::{Builder, HandshakeState, params::NoiseParams}; use std::{fmt::Debug, marker::PhantomData}; -use crate::{crypto::write_stream_id, Error}; +use crate::{Error, crypto::write_stream_id}; /// Default pattern pub const PARAMS: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2b"; @@ -65,7 +65,6 @@ const SNOW_CIPHERKEYLEN: usize = 32; const PUBLIC_KEYLEN: usize = 32; /// Secret Stream protocol state -#[derive(Debug)] pub struct SecStream { is_initiator: bool, state: HandshakeState, @@ -73,15 +72,22 @@ pub struct SecStream { step: Step, } +impl std::fmt::Debug for SecStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SecStream") + .field("is_initiator", &self.is_initiator) + .field("state", &self.state) + .field("msg_buf", &"[...]") + .field("step", &self.step) + .finish() + } +} + impl SecStream { /// split handshake into (tx, rx) pub fn split_handshake(&mut self) -> ([u8; SNOW_CIPHERKEYLEN], [u8; SNOW_CIPHERKEYLEN]) { let (a, b) = self.state.dangerously_get_raw_split(); - if self.is_initiator { - (a, b) - } else { - (b, a) - } + if self.is_initiator { (a, b) } else { (b, a) } } } @@ -394,7 +400,10 @@ impl SecStream { let mut expected_stream_id: [u8; STREAM_ID_LENGTH] = [0; STREAM_ID_LENGTH]; write_stream_id(&handshake_hash, !is_initiator, &mut expected_stream_id); if expected_stream_id != msg[..STREAM_ID_LENGTH] { - panic!() + panic!( + "stream ID's don't match\n{expected_stream_id:?}\n != \n{:?}", + &msg[..STREAM_ID_LENGTH] + ); } let header: [u8; Header::BYTES] = @@ -408,6 +417,7 @@ impl SecStream { }) } + /// Encrypt a message in place pub fn push( &mut self, msg: &mut Vec, diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 09ce1e6..7fdb0b5 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -1,5 +1,6 @@ use std::{ collections::VecDeque, + fmt::Debug, mem::replace, pin::Pin, task::{Context, Poll}, @@ -7,13 +8,13 @@ use std::{ use crypto_secretstream::Tag; use futures::{Sink, Stream}; +use tracing::{instrument, trace, warn}; use crate::{ error::Error, sstream::{EncryptorReady, HsMsgSent, Initiator, Ready, Responder, SecStream, Start}, }; -#[derive(Debug)] pub(crate) enum State { InitiatorStart(SecStream>), InitiatorSent(SecStream>), @@ -23,13 +24,29 @@ pub(crate) enum State { Invalid, } +impl Debug for State { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::InitiatorStart(_) => "InitiatorStart", + Self::InitiatorSent(_) => "InitiatorSent", + Self::RespStart(_) => "RespStart", + Self::EncReady(_) => "EncReady", + Self::Ready(_) => "Ready", + Self::Invalid => "Invalid", + } + ) + } +} + #[derive(Debug)] enum Event { HandshakePayload(Vec), Message(Vec), } -#[derive(Debug)] /// For each tx/rx VecDeque messages go in with `.push_back` then taken out with `.pop_front`. /// If a message should skip the line it should be inserted with `.push_front`. pub(crate) struct Machine { @@ -42,10 +59,23 @@ pub(crate) struct Machine { plain_rx: VecDeque, } +impl Debug for Machine { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Machine") + .field("io", &"") + .field("is_initiator", &self.is_initiator) + .field("state", &self.state) + .field("encrypted_tx", &self.encrypted_tx.len()) + .field("encrypted_rx", &self.encrypted_rx.len()) + .field("plain_tx", &self.plain_tx.len()) + .field("plain_rx", &self.plain_rx.len()) + .finish() + } +} + impl Machine where IO: Stream, Error>> + Sink> + Send + Unpin + 'static, - >>::Error: std::fmt::Debug, { fn new_init(io: IO, state: SecStream>) -> Self { Self { @@ -69,6 +99,7 @@ where plain_rx: Default::default(), } } + #[instrument(skip_all, err)] fn handshake_start(&mut self, payload: &[u8]) -> Result<(), Error> { match replace(&mut self.state, State::Invalid) { State::InitiatorStart(s) => { @@ -81,10 +112,21 @@ where } } - fn poll_tx_rx(&mut self) -> Result, Error> { + #[instrument(skip_all, err)] + fn poll_encrypt_decrypt(&mut self) -> Result, Error> { + trace!( + state =? self.state, + init = self.is_initiator, + plain_tx = self.plain_tx.len(), + plain_rx = self.plain_rx.len(), + enc_tx = self.encrypted_tx.len(), + enc_rx = self.encrypted_rx.len(), + ); + match replace(&mut self.state, State::Invalid) { State::InitiatorSent(s) => { let Some(msg) = self.encrypted_rx.pop_front() else { + self.state = State::InitiatorSent(s); return Ok(None); }; let (s2, payload) = s.read_msg(&msg?)?; @@ -97,6 +139,8 @@ where } State::RespStart(s) => { let Some(msg) = self.encrypted_rx.pop_front() else { + // Not ready + self.state = State::RespStart(s); return Ok(None); }; let (s2, payload) = s.read_msg(&msg?)?; @@ -104,13 +148,18 @@ where self.plain_rx.push_front(Event::HandshakePayload(payload)); let next_tx = self.plain_tx.pop_front(); let (s3, [msg1, msg2]) = s2.write_msg(next_tx.as_deref())?; - self.plain_tx.push_front(msg2); - self.plain_tx.push_front(msg1); + self.encrypted_tx.push_front(msg2); + self.encrypted_tx.push_front(msg1); self.state = State::EncReady(s3); Ok(Some(())) } - State::EncReady(s) => { + State::EncReady(mut s) => { + while let Some(mut msg) = self.plain_tx.pop_front() { + s.push(&mut msg, &[], Tag::Message)?; + self.encrypted_tx.push_back(msg); + } let Some(msg) = self.encrypted_rx.pop_front() else { + self.state = State::EncReady(s); return Ok(None); }; self.state = State::Ready(s.read_msg(&msg?)?); @@ -126,7 +175,10 @@ where self.plain_rx.push_back(Event::Message(encrypted_msg)); made_progress = true; } - Err(e) => return Err(e), + Err(e) => { + warn!("INVALID STATE FROM Ready"); + return Err(e); + } } } @@ -140,32 +192,45 @@ where self.state = State::Ready(s); Ok(if made_progress { Some(()) } else { None }) } - State::InitiatorStart(_sec_stream) => { - // not started yet... Error? - Ok(None) + State::InitiatorStart(s) => { + // no handshake message.. We use first thing in plain_tx, but maybe it should be an + // error bc we might want the payload to be handled explicitly + let payload = self.plain_tx.pop_front(); + let (s2, out) = s.write_msg(payload.as_deref())?; + self.encrypted_tx.push_back(out); + self.state = State::InitiatorSent(s2); + Ok(Some(())) } State::Invalid => Err(Error::InvalidState("Invalid state".into())), } } /// pull in new incomming encrypted messages + #[instrument(skip_all, fields(init = self.is_initiator))] fn poll_incoming_encrypted(&mut self, cx: &mut Context<'_>) -> Poll<()> { while let Poll::Ready(Some(result)) = Pin::new(&mut self.io).poll_next(cx) { + trace!("RX {result:?}"); self.encrypted_rx.push_back(result); } Poll::Ready(()) } + #[instrument(skip_all)] fn poll_outgoing_encrypted(&mut self, cx: &mut Context<'_>) -> Poll> { while let Some(msg) = self.encrypted_tx.pop_front() { match Pin::new(&mut self.io).poll_ready(cx) { Poll::Ready(Ok(())) => { + dbg!(); if let Err(e) = Pin::new(&mut self.io).start_send(msg) { - return Poll::Ready(Err(Error::IoError(format!("Send failed: {:?}", e)))); + return Poll::Ready(Err(Error::IoError(format!( + "Send failed: TODO handle Debug" + )))); } } Poll::Ready(Err(e)) => { - return Poll::Ready(Err(Error::IoError(format!("IO error: {:?}", e)))); + return Poll::Ready(Err(Error::IoError(format!( + "IO error: TODO dbg error here" + )))); } Poll::Pending => { self.encrypted_tx.push_front(msg); @@ -176,9 +241,9 @@ where match Pin::new(&mut self.io).poll_flush(cx) { Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => { - Poll::Ready(Err(Error::IoError(format!("Flush failed: {:?}", e)))) - } + Poll::Ready(Err(e)) => Poll::Ready(Err(Error::IoError(format!( + "Flush failed: TODO dbg error here" + )))), Poll::Pending => Poll::Pending, } } @@ -187,10 +252,10 @@ where impl Stream for Machine where IO: Stream, Error>> + Sink> + Send + Unpin + 'static, - >>::Error: std::fmt::Debug, { type Item = Result, Error>; + #[instrument(skip_all, fields(init = self.is_initiator))] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { // 1. First, try to return any ready plaintext messages @@ -222,7 +287,7 @@ where } // 4. Process crypto operations (handshake, encrypt/decrypt) - match self.poll_tx_rx() { + match self.poll_encrypt_decrypt() { Ok(Some(())) => { // Made progress, loop again to check for more work continue; @@ -245,10 +310,10 @@ where impl Sink> for Machine where IO: Stream, Error>> + Sink> + Send + Unpin + 'static, - >>::Error: std::fmt::Debug, { type Error = Error; + #[instrument(skip_all)] fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // Process any pending work to make space in queues self.poll_incoming_encrypted(cx); @@ -265,7 +330,7 @@ where } // Process crypto operations to make progress - match self.poll_tx_rx() { + match self.poll_encrypt_decrypt() { Ok(_) => { // Always ready to accept more plaintext messages for queuing Poll::Ready(Ok(())) @@ -274,16 +339,19 @@ where } } + #[instrument(skip_all)] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { // Queue the plaintext message for encryption self.plain_tx.push_back(item); Ok(()) } + #[instrument(skip_all)] fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let _always_poll_ready_but_why = self.poll_incoming_encrypted(cx); loop { // Process crypto operations to encrypt any pending plaintext - match self.poll_tx_rx() { + match self.poll_encrypt_decrypt() { Ok(Some(())) => { // Made progress, continue processing continue; @@ -304,6 +372,7 @@ where Poll::Ready(Ok(())) } else { // Still have pending plaintext, not fully flushed + cx.waker().wake_by_ref(); Poll::Pending } } @@ -312,6 +381,7 @@ where } } + #[instrument(skip_all)] fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // First flush any pending data match self.as_mut().poll_flush(cx) { @@ -319,7 +389,7 @@ where // Now close the underlying IO Pin::new(&mut self.io) .poll_close(cx) - .map_err(|e| Error::IoError(format!("Close failed: {:?}", e))) + .map_err(|e| Error::IoError(format!("Close failed TODO dbg err here"))) } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, @@ -329,17 +399,21 @@ where #[cfg(test)] mod tests { + use crate::{sstream::PARAMS, test_utils::log}; + use super::*; - use futures::{SinkExt, StreamExt, channel::mpsc}; + use futures::{SinkExt, StreamExt, channel::mpsc, join}; + use snow::Builder; + use tracing::instrument; // Mock IO that implements Stream + Sink for testing #[derive(Debug)] - struct MockIo { - receiver: mpsc::UnboundedReceiver, Error>>, + struct MockIo, Error>>> { + receiver: S, sender: mpsc::UnboundedSender>, } - impl Stream for MockIo { + impl, Error>> + Unpin> Stream for MockIo { type Item = Result, Error>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -347,7 +421,7 @@ mod tests { } } - impl Sink> for MockIo { + impl, Error>>> Sink> for MockIo { type Error = Error; fn poll_ready( @@ -379,7 +453,7 @@ mod tests { } fn create_mock_io_pair() -> ( - MockIo, + MockIo, Error>>>, mpsc::UnboundedSender, Error>>, mpsc::UnboundedReceiver>, ) { @@ -394,21 +468,75 @@ mod tests { (mock_io, io_tx, out_rx) } - #[tokio::test] - async fn test_machine_sink_queues_messages() -> Result<(), Error> { - let remote_key = [1u8; 32]; - let initiator_state = SecStream::new_initiator(&remote_key, &[])?; + fn new_connected_secret_stream() -> (SecStream>, SecStream>) { + let kp = Builder::new(PARAMS.parse().unwrap()) + .generate_keypair() + .unwrap(); + ( + SecStream::new_initiator(&kp.public.try_into().unwrap(), &[]).unwrap(), + SecStream::new_responder(&kp.private).unwrap(), + ) + } - let (mock_io, _io_tx, _out_rx) = create_mock_io_pair(); - let mut machine = Machine::new_init(mock_io, initiator_state); + fn new_connected_streams() -> ( + impl Stream, Error>> + Sink> + Debug, + impl Stream, Error>> + Sink> + Debug, + ) { + let (left_tx, left_rx) = mpsc::unbounded(); + let res_left_rx = left_rx.map(|msg: Vec| Ok::<_, Error>(msg)); + + let (right_tx, right_rx) = mpsc::unbounded(); + let res_right_rx = right_rx.map(|msg: Vec| Ok::<_, Error>(msg)); + + let left = MockIo { + sender: left_tx, + receiver: res_right_rx, + }; + let right = MockIo { + sender: right_tx, + receiver: res_left_rx, + }; + (left, right) + } + + fn connected_machines() -> ( + Machine, Error>> + Sink> + Debug>, + Machine, Error>> + Sink> + Debug>, + ) { + let (lss, rss) = new_connected_secret_stream(); + let (lio, rio) = new_connected_streams(); + let (lm, rm) = (Machine::new_init(lio, lss), Machine::new_resp(rio, rss)); + (lm, rm) + } + + #[tokio::test] + async fn test_streams() -> Result<(), Error> { + let (mut l, mut r) = new_connected_streams(); + let (a, b) = join!(l.send(b"yo".to_vec()), r.next()); + assert!(a.is_ok()); + assert_eq!(b, Some(Ok(b"yo".to_vec()))); + + let (a, b) = join!(r.send(b"yo".to_vec()), l.next()); + assert!(a.is_ok()); + assert_eq!(b, Some(Ok(b"yo".to_vec()))); + Ok(()) + } + #[tokio::test] + async fn test_machine_io() -> Result<(), Error> { + let (mut lm, mut rm) = connected_machines(); - // Test that we can send messages to the sink let test_msg = b"Hello, World!".to_vec(); - machine.send(test_msg.clone()).await?; + log(); + lm.handshake_start(b"payload")?; + let lmfut = lm.send(b"next".to_vec()); + let rmfut = rm.next(); + let (lres, rres) = join!(lmfut, rmfut); // Message should be queued in plain_tx - assert_eq!(machine.plain_tx.len(), 1); - assert_eq!(machine.plain_tx.front(), Some(&test_msg)); + assert_eq!(rres, Some(Ok(b"next".into()))); + + //assert_eq!(machine.plain_tx.len(), 1); + //assert_eq!(machine.plain_tx.front(), Some(b"next".to_vec().as_ref())); Ok(()) } From 73d5bb201a9334ee5c684f3067cf4f9342d42944 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 17 Aug 2025 13:17:45 -0400 Subject: [PATCH 183/206] Stream should emit Event --- src/sstream/sm2.rs | 51 ++++++++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 7fdb0b5..cf4e500 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -41,10 +41,11 @@ impl Debug for State { } } -#[derive(Debug)] -enum Event { +#[derive(Debug, PartialEq)] +pub(crate) enum Event { HandshakePayload(Vec), Message(Vec), + ErrStuff(Error), } /// For each tx/rx VecDeque messages go in with `.push_back` then taken out with `.pop_front`. @@ -220,16 +221,15 @@ where while let Some(msg) = self.encrypted_tx.pop_front() { match Pin::new(&mut self.io).poll_ready(cx) { Poll::Ready(Ok(())) => { - dbg!(); if let Err(e) = Pin::new(&mut self.io).start_send(msg) { return Poll::Ready(Err(Error::IoError(format!( - "Send failed: TODO handle Debug" + "Send failed: TODO Error should have fmt::Debug here" )))); } } Poll::Ready(Err(e)) => { return Poll::Ready(Err(Error::IoError(format!( - "IO error: TODO dbg error here" + "IO error: TODO Error should have fmt::Debug here" )))); } Poll::Pending => { @@ -242,7 +242,7 @@ where match Pin::new(&mut self.io).poll_flush(cx) { Poll::Ready(Ok(())) => Poll::Ready(Ok(())), Poll::Ready(Err(e)) => Poll::Ready(Err(Error::IoError(format!( - "Flush failed: TODO dbg error here" + "Flush failed: TODO Error should have fmt::Debug here" )))), Poll::Pending => Poll::Pending, } @@ -253,22 +253,14 @@ impl Stream for Machine where IO: Stream, Error>> + Sink> + Send + Unpin + 'static, { - type Item = Result, Error>; + type Item = Event; #[instrument(skip_all, fields(init = self.is_initiator))] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { // 1. First, try to return any ready plaintext messages while let Some(event) = self.plain_rx.pop_front() { - match event { - Event::Message(data) => { - return Poll::Ready(Some(Ok(data))); - } - Event::HandshakePayload(_) => { - // Skip handshake payloads - they're for internal use only - continue; - } - } + return Poll::Ready(Some(event)); } // 2. Pull new encrypted data from IO into our queue @@ -276,7 +268,9 @@ where // 3. Send any pending encrypted data to IO match self.poll_outgoing_encrypted(cx) { - Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Ready(Err(e)) => { + return Poll::Ready(Some(Event::ErrStuff(e))); + } Poll::Pending => { // IO is busy, we can't make progress on sending // but we can still try to process incoming messages @@ -297,7 +291,7 @@ where break; } Err(e) => { - return Poll::Ready(Some(Err(e))); + return Poll::Ready(Some(Event::ErrStuff(e))); } } } @@ -387,9 +381,11 @@ where match self.as_mut().poll_flush(cx) { Poll::Ready(Ok(())) => { // Now close the underlying IO - Pin::new(&mut self.io) - .poll_close(cx) - .map_err(|e| Error::IoError(format!("Close failed TODO dbg err here"))) + Pin::new(&mut self.io).poll_close(cx).map_err(|e| { + Error::IoError(format!( + "Close failed TODO Error should have fmt::debug here" + )) + }) } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, @@ -525,15 +521,12 @@ mod tests { async fn test_machine_io() -> Result<(), Error> { let (mut lm, mut rm) = connected_machines(); - let test_msg = b"Hello, World!".to_vec(); - log(); - lm.handshake_start(b"payload")?; - let lmfut = lm.send(b"next".to_vec()); - let rmfut = rm.next(); - - let (lres, rres) = join!(lmfut, rmfut); + let payload = b"Hello, World!".to_vec(); + lm.handshake_start(&payload); + //let (lres, rres) = join!(lm.send(b"foo"), rm.next()); // TODO this hangs + let (lres, rres) = join!(lm.flush(), rm.next()); // Message should be queued in plain_tx - assert_eq!(rres, Some(Ok(b"next".into()))); + assert_eq!(rres, Some(Event::HandshakePayload(payload))); //assert_eq!(machine.plain_tx.len(), 1); //assert_eq!(machine.plain_tx.front(), Some(b"next".to_vec().as_ref())); From 6911901b3d81920cf752b4a2fc282c070245ee14 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 1 Sep 2025 15:09:10 -0400 Subject: [PATCH 184/206] Add Machine::complete_handshake function --- src/sstream/sm2.rs | 80 ++++++++++++++++++++++++++++++------- src/sstream/statemachine.rs | 5 ++- 2 files changed, 68 insertions(+), 17 deletions(-) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index cf4e500..4cabf65 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -100,6 +100,23 @@ where plain_rx: Default::default(), } } + async fn complete_handshake(&mut self) -> Result<(), Error> { + use futures::SinkExt; + use futures::StreamExt; + + loop { + if !matches!(self.state, State::Ready(_)) { + self.send(vec![]).await?; + if matches!(self.state, State::Ready(_)) { + return Ok(()); + } + self.next().await; + } else { + return Ok(()); + } + } + } + #[instrument(skip_all, err)] fn handshake_start(&mut self, payload: &[u8]) -> Result<(), Error> { match replace(&mut self.state, State::Invalid) { @@ -210,7 +227,6 @@ where #[instrument(skip_all, fields(init = self.is_initiator))] fn poll_incoming_encrypted(&mut self, cx: &mut Context<'_>) -> Poll<()> { while let Poll::Ready(Some(result)) = Pin::new(&mut self.io).poll_next(cx) { - trace!("RX {result:?}"); self.encrypted_rx.push_back(result); } Poll::Ready(()) @@ -259,7 +275,7 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { // 1. First, try to return any ready plaintext messages - while let Some(event) = self.plain_rx.pop_front() { + if let Some(event) = self.plain_rx.pop_front() { return Poll::Ready(Some(event)); } @@ -518,38 +534,72 @@ mod tests { Ok(()) } #[tokio::test] - async fn test_machine_io() -> Result<(), Error> { + async fn test_machine_io_l_to_r() -> Result<(), Error> { let (mut lm, mut rm) = connected_machines(); let payload = b"Hello, World!".to_vec(); lm.handshake_start(&payload); //let (lres, rres) = join!(lm.send(b"foo"), rm.next()); // TODO this hangs let (lres, rres) = join!(lm.flush(), rm.next()); - // Message should be queued in plain_tx assert_eq!(rres, Some(Event::HandshakePayload(payload))); + assert_eq!(lres, Ok(())); + Ok(()) + } + + #[tokio::test] + async fn test_machine_io_both_ways() -> Result<(), Error> { + let (mut lm, mut rm) = connected_machines(); - //assert_eq!(machine.plain_tx.len(), 1); - //assert_eq!(machine.plain_tx.front(), Some(b"next".to_vec().as_ref())); + let res = join!(lm.send(b"ltor".into()), rm.send(b"rtol".into())); + assert_eq!((res.0?, res.1?), ((), ())); + + let (Some(lr), Some(rr)) = join!(lm.next(), rm.next()) else { + panic!() + }; + assert_eq!( + (lr, rr), + ( + Event::HandshakePayload(vec![]), + Event::HandshakePayload(vec![]) + ), + ); + + let (Some(lr), Some(rr)) = join!(lm.next(), rm.next()) else { + panic!() + }; + assert_eq!( + (lr, rr), + ( + Event::Message(b"rtol".to_vec()), + Event::Message(b"ltor".to_vec()) + ) + ); Ok(()) } - #[tokio::test] async fn test_machine_sink_multiple_messages() -> Result<(), Error> { - let remote_key = [2u8; 32]; - let initiator_state = SecStream::new_initiator(&remote_key, &[])?; + let (mut lm, mut rm) = connected_machines(); - let (mock_io, _io_tx, _out_rx) = create_mock_io_pair(); - let mut machine = Machine::new_init(mock_io, initiator_state); + let (rl, rr) = join!(lm.complete_handshake(), rm.complete_handshake()); + rl?; + rr?; - // Send multiple messages + let mut msgs = vec![]; for i in 0..5 { let msg = format!("Message {}", i).into_bytes(); - machine.send(msg).await?; + msgs.push(msg.clone()); + lm.send(msg).await?; } - // All messages should be queued - assert_eq!(machine.plain_tx.len(), 5); + let mut results = vec![]; + for _ in 0..5 { + let Event::Message(m) = rm.next().await.unwrap() else { + panic!(); + }; + results.push(m); + } + assert_eq!(results, msgs); Ok(()) } diff --git a/src/sstream/statemachine.rs b/src/sstream/statemachine.rs index 9d4c370..3bf6b9b 100644 --- a/src/sstream/statemachine.rs +++ b/src/sstream/statemachine.rs @@ -100,14 +100,15 @@ impl Manager { associated_data: &[u8], tag: Tag, ) -> Result<(), Error> { - Ok(match &mut self.state { + match &mut self.state { State::Ready(s) => s.push(msg, associated_data, tag)?, s => { return Err(Error::InvalidHandshakeState(format!( "Cannot Encrypt a message while in state: {s:?}" ))); } - }) + }; + Ok(()) } /// Decrypt a message in place pub(crate) fn pull(&mut self, msg: &mut Vec, associated_data: &[u8]) -> Result { From 114c622c47091201f3076212b34ddfdd154390f4 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 1 Sep 2025 15:14:34 -0400 Subject: [PATCH 185/206] Docs, TODO, etc --- src/sstream/sm2.rs | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 4cabf65..35c2862 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -131,6 +131,8 @@ where } #[instrument(skip_all, err)] + /// Encrypt outgoing messages, and decrypt encomming messages. + /// This also processes messages to complete the handshake. fn poll_encrypt_decrypt(&mut self) -> Result, Error> { trace!( state =? self.state, @@ -194,13 +196,12 @@ where made_progress = true; } Err(e) => { - warn!("INVALID STATE FROM Ready"); - return Err(e); + return todo!("How should we handle an error in receiving a message?"); } } } - // send outgoing messages + // encrypt outgoing messages if let Some(mut plain_msg) = self.plain_tx.pop_front() { s.push(&mut plain_msg, &[], Tag::Message)?; self.encrypted_tx.push_back(plain_msg); @@ -238,15 +239,15 @@ where match Pin::new(&mut self.io).poll_ready(cx) { Poll::Ready(Ok(())) => { if let Err(e) = Pin::new(&mut self.io).start_send(msg) { - return Poll::Ready(Err(Error::IoError(format!( - "Send failed: TODO Error should have fmt::Debug here" - )))); + return Poll::Ready(Err(Error::IoError( + "Send failed: TODO Error should have fmt::Debug here".into(), + ))); } } Poll::Ready(Err(e)) => { - return Poll::Ready(Err(Error::IoError(format!( - "IO error: TODO Error should have fmt::Debug here" - )))); + return Poll::Ready(Err(Error::IoError( + "IO error: TODO Error should have fmt::Debug here".into(), + ))); } Poll::Pending => { self.encrypted_tx.push_front(msg); @@ -257,9 +258,9 @@ where match Pin::new(&mut self.io).poll_flush(cx) { Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(Error::IoError(format!( - "Flush failed: TODO Error should have fmt::Debug here" - )))), + Poll::Ready(Err(e)) => Poll::Ready(Err(Error::IoError( + "Flush failed: TODO Error should have fmt::Debug here".into(), + ))), Poll::Pending => Poll::Pending, } } @@ -521,6 +522,17 @@ mod tests { (lm, rm) } + #[tokio::test] + async fn test_complete_handshake() -> Result<(), Error> { + let (mut lm, mut rm) = connected_machines(); + let (rl, rr) = join!(lm.complete_handshake(), rm.complete_handshake()); + rl?; + rr?; + assert!(matches!(lm.state, State::Ready(_))); + assert!(matches!(rm.state, State::Ready(_))); + Ok(()) + } + #[tokio::test] async fn test_streams() -> Result<(), Error> { let (mut l, mut r) = new_connected_streams(); From 909c91a2d369320e2c59ce9d30b92c7b53c57be7 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 1 Sep 2025 18:04:27 -0400 Subject: [PATCH 186/206] split out SansIoMachine --- src/sstream/sm2.rs | 229 ++++++++++++++++++++++++++++++--------------- 1 file changed, 156 insertions(+), 73 deletions(-) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 35c2862..025a1df 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -1,3 +1,4 @@ +#[allow(unused)] use std::{ collections::VecDeque, fmt::Debug, @@ -41,17 +42,8 @@ impl Debug for State { } } -#[derive(Debug, PartialEq)] -pub(crate) enum Event { - HandshakePayload(Vec), - Message(Vec), - ErrStuff(Error), -} - -/// For each tx/rx VecDeque messages go in with `.push_back` then taken out with `.pop_front`. -/// If a message should skip the line it should be inserted with `.push_front`. -pub(crate) struct Machine { - io: IO, +/// Like [`Machine`] but no IO +struct SansIoMachine { is_initiator: bool, state: State, encrypted_tx: VecDeque>, @@ -60,27 +52,9 @@ pub(crate) struct Machine { plain_rx: VecDeque, } -impl Debug for Machine { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Machine") - .field("io", &"") - .field("is_initiator", &self.is_initiator) - .field("state", &self.state) - .field("encrypted_tx", &self.encrypted_tx.len()) - .field("encrypted_rx", &self.encrypted_rx.len()) - .field("plain_tx", &self.plain_tx.len()) - .field("plain_rx", &self.plain_rx.len()) - .finish() - } -} - -impl Machine -where - IO: Stream, Error>> + Sink> + Send + Unpin + 'static, -{ - fn new_init(io: IO, state: SecStream>) -> Self { +impl SansIoMachine { + fn new_init(state: SecStream>) -> Self { Self { - io, state: State::InitiatorStart(state), is_initiator: true, encrypted_tx: Default::default(), @@ -89,9 +63,9 @@ where plain_rx: Default::default(), } } - fn new_resp(io: IO, state: SecStream>) -> Self { + + fn new_resp(state: SecStream>) -> Self { Self { - io, state: State::RespStart(state), is_initiator: false, encrypted_tx: Default::default(), @@ -100,22 +74,6 @@ where plain_rx: Default::default(), } } - async fn complete_handshake(&mut self) -> Result<(), Error> { - use futures::SinkExt; - use futures::StreamExt; - - loop { - if !matches!(self.state, State::Ready(_)) { - self.send(vec![]).await?; - if matches!(self.state, State::Ready(_)) { - return Ok(()); - } - self.next().await; - } else { - return Ok(()); - } - } - } #[instrument(skip_all, err)] fn handshake_start(&mut self, payload: &[u8]) -> Result<(), Error> { @@ -174,13 +132,15 @@ where Ok(Some(())) } State::EncReady(mut s) => { + let mut made_progress = false; while let Some(mut msg) = self.plain_tx.pop_front() { s.push(&mut msg, &[], Tag::Message)?; self.encrypted_tx.push_back(msg); + made_progress = true; } let Some(msg) = self.encrypted_rx.pop_front() else { self.state = State::EncReady(s); - return Ok(None); + return Ok(made_progress.then_some(())); }; self.state = State::Ready(s.read_msg(&msg?)?); Ok(Some(())) @@ -195,9 +155,7 @@ where self.plain_rx.push_back(Event::Message(encrypted_msg)); made_progress = true; } - Err(e) => { - return todo!("How should we handle an error in receiving a message?"); - } + Err(_e) => todo!("How should we handle an error in receiving a message?"), } } @@ -224,33 +182,158 @@ where } } + fn get_next_msg_to_send(&mut self) -> Result>, Error> { + if let Some(out) = self.encrypted_tx.pop_front() { + return Ok(Some(out)); + } + while let Some(()) = self.poll_encrypt_decrypt()? { + if let Some(out) = self.encrypted_tx.pop_front() { + return Ok(Some(out)); + } + } + Ok(None) + } + + fn receive_next(&mut self, encrypted_msg: Vec) { + self.encrypted_rx.push_back(Ok(encrypted_msg)); + } + + fn queue_msg(&mut self, msg: Vec) { + self.plain_tx.push_back(msg); + } + + fn next_decrypted_message(&mut self) -> Result, Error> { + if let Some(out) = self.plain_rx.pop_front() { + return Ok(Some(out)); + } + while let Some(()) = self.poll_encrypt_decrypt()? { + if let Some(out) = self.plain_rx.pop_front() { + return Ok(Some(out)); + } + } + Ok(None) + } + + fn ready(&self) -> bool { + matches!(self.state, State::Ready(_)) + } +} +impl Debug for SansIoMachine { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SansIoMachine") + .field("is_initiator", &self.is_initiator) + .field("state", &self.state) + .field("encrypted_tx", &self.encrypted_tx.len()) + .field("encrypted_rx", &self.encrypted_rx.len()) + .field("plain_tx", &self.plain_tx.len()) + .field("plain_rx", &self.plain_rx.len()) + .finish() + } +} + +#[derive(Debug, PartialEq)] +pub(crate) enum Event { + HandshakePayload(Vec), + Message(Vec), + ErrStuff(Error), +} + +/// For each tx/rx VecDeque messages go in with `.push_back` then taken out with `.pop_front`. +/// If a message should skip the line it should be inserted with `.push_front`. +pub(crate) struct Machine { + io: IO, + inner: SansIoMachine, +} + +impl Debug for Machine { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Machine") + .field("io", &"") + .field("inner", &self.inner) + .finish() + } +} + +impl Machine +where + IO: Stream, Error>> + Sink> + Send + Unpin + 'static, +{ + fn new_init(io: IO, state: SecStream>) -> Self { + Self { + io, + inner: SansIoMachine::new_init(state), + } + } + fn new_resp(io: IO, state: SecStream>) -> Self { + Self { + io, + inner: SansIoMachine::new_resp(state), + } + } + async fn complete_handshake(&mut self) -> Result<(), Error> { + use futures::SinkExt; + use futures::StreamExt; + + loop { + if !self.inner.ready() { + self.send(vec![]).await?; + if self.inner.ready() { + return Ok(()); + } + let _ = self.next().await; + } else { + return Ok(()); + } + } + } + + #[instrument(skip_all, err)] + fn handshake_start(&mut self, payload: &[u8]) -> Result<(), Error> { + self.inner.handshake_start(payload) + } + + #[instrument(skip_all, err)] + /// Encrypt outgoing messages, and decrypt encomming messages. + /// This also processes messages to complete the handshake. + fn poll_encrypt_decrypt(&mut self) -> Result, Error> { + trace!( + state =? self.inner.state, + init = self.inner.is_initiator, + plain_tx = self.inner.plain_tx.len(), + plain_rx = self.inner.plain_rx.len(), + enc_tx = self.inner.encrypted_tx.len(), + enc_rx = self.inner.encrypted_rx.len(), + ); + self.inner.poll_encrypt_decrypt() + } + /// pull in new incomming encrypted messages - #[instrument(skip_all, fields(init = self.is_initiator))] + #[instrument(skip_all, fields(init = self.inner.is_initiator))] fn poll_incoming_encrypted(&mut self, cx: &mut Context<'_>) -> Poll<()> { while let Poll::Ready(Some(result)) = Pin::new(&mut self.io).poll_next(cx) { - self.encrypted_rx.push_back(result); + self.inner.encrypted_rx.push_back(result); } Poll::Ready(()) } #[instrument(skip_all)] fn poll_outgoing_encrypted(&mut self, cx: &mut Context<'_>) -> Poll> { - while let Some(msg) = self.encrypted_tx.pop_front() { + while let Some(msg) = self.inner.encrypted_tx.pop_front() { match Pin::new(&mut self.io).poll_ready(cx) { Poll::Ready(Ok(())) => { - if let Err(e) = Pin::new(&mut self.io).start_send(msg) { + if let Err(_e) = Pin::new(&mut self.io).start_send(msg) { return Poll::Ready(Err(Error::IoError( "Send failed: TODO Error should have fmt::Debug here".into(), ))); } } - Poll::Ready(Err(e)) => { + Poll::Ready(Err(_e)) => { return Poll::Ready(Err(Error::IoError( "IO error: TODO Error should have fmt::Debug here".into(), ))); } Poll::Pending => { - self.encrypted_tx.push_front(msg); + self.inner.encrypted_tx.push_front(msg); return Poll::Pending; } } @@ -258,7 +341,7 @@ where match Pin::new(&mut self.io).poll_flush(cx) { Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(e)) => Poll::Ready(Err(Error::IoError( + Poll::Ready(Err(_e)) => Poll::Ready(Err(Error::IoError( "Flush failed: TODO Error should have fmt::Debug here".into(), ))), Poll::Pending => Poll::Pending, @@ -272,11 +355,11 @@ where { type Item = Event; - #[instrument(skip_all, fields(init = self.is_initiator))] + #[instrument(skip_all, fields(init = self.inner.is_initiator))] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { // 1. First, try to return any ready plaintext messages - if let Some(event) = self.plain_rx.pop_front() { + if let Some(event) = self.inner.plain_rx.pop_front() { return Poll::Ready(Some(event)); } @@ -353,7 +436,7 @@ where #[instrument(skip_all)] fn start_send(mut self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { // Queue the plaintext message for encryption - self.plain_tx.push_back(item); + self.inner.plain_tx.push_back(item); Ok(()) } @@ -379,7 +462,7 @@ where match self.poll_outgoing_encrypted(cx) { Poll::Ready(Ok(())) => { // Check if we have any pending plaintext that hasn't been encrypted yet - if self.plain_tx.is_empty() { + if self.inner.plain_tx.is_empty() { Poll::Ready(Ok(())) } else { // Still have pending plaintext, not fully flushed @@ -398,7 +481,7 @@ where match self.as_mut().poll_flush(cx) { Poll::Ready(Ok(())) => { // Now close the underlying IO - Pin::new(&mut self.io).poll_close(cx).map_err(|e| { + Pin::new(&mut self.io).poll_close(cx).map_err(|_e| { Error::IoError(format!( "Close failed TODO Error should have fmt::debug here" )) @@ -528,8 +611,8 @@ mod tests { let (rl, rr) = join!(lm.complete_handshake(), rm.complete_handshake()); rl?; rr?; - assert!(matches!(lm.state, State::Ready(_))); - assert!(matches!(rm.state, State::Ready(_))); + assert!(lm.inner.ready()); + assert!(rm.inner.ready()); Ok(()) } @@ -650,10 +733,10 @@ mod tests { machine.handshake_start(payload)?; // Should have transitioned to InitiatorSent state - assert!(matches!(machine.state, State::InitiatorSent(_))); + assert!(matches!(machine.inner.state, State::InitiatorSent(_))); // Should have queued encrypted handshake message - assert!(!machine.encrypted_tx.is_empty()); + assert!(!machine.inner.encrypted_tx.is_empty()); // Process outgoing to send the handshake message let waker = futures::task::noop_waker(); @@ -679,10 +762,10 @@ mod tests { let machine = Machine::new_init(mock_io, initiator_state); // Verify initial state - assert!(matches!(machine.state, State::InitiatorStart(_))); - assert!(machine.is_initiator); - assert!(machine.plain_tx.is_empty()); - assert!(machine.plain_rx.is_empty()); + assert!(matches!(machine.inner.state, State::InitiatorStart(_))); + assert!(machine.inner.is_initiator); + assert!(machine.inner.plain_tx.is_empty()); + assert!(machine.inner.plain_rx.is_empty()); Ok(()) } From e5895c19697e9bc09d093e45e4d48e893d820bfe Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 1 Sep 2025 22:41:56 -0400 Subject: [PATCH 187/206] Add vectorized methods for SansIo --- src/sstream/sm2.rs | 52 ++++++++++++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 025a1df..ec5ae6c 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -182,16 +182,23 @@ impl SansIoMachine { } } - fn get_next_msg_to_send(&mut self) -> Result>, Error> { - if let Some(out) = self.encrypted_tx.pop_front() { - return Ok(Some(out)); + /// Do as much work as possible encrypting plaintext and decrypting ciphertext + fn poll_all_enc_dec(&mut self) -> Result, Error> { + let mut made_progress = false; + while self.poll_encrypt_decrypt()?.is_some() { + made_progress = true; } - while let Some(()) = self.poll_encrypt_decrypt()? { - if let Some(out) = self.encrypted_tx.pop_front() { - return Ok(Some(out)); - } - } - Ok(None) + Ok(made_progress.then_some(())) + } + + fn get_sendable_messages(&mut self) -> Result>, Error> { + self.poll_all_enc_dec()?; + Ok(self.encrypted_tx.drain(..).collect()) + } + + fn receive_next_messages(&mut self, encrypted_messages: Vec>) { + self.encrypted_rx + .extend(encrypted_messages.into_iter().map(|x| Ok(x))); } fn receive_next(&mut self, encrypted_msg: Vec) { @@ -203,21 +210,15 @@ impl SansIoMachine { } fn next_decrypted_message(&mut self) -> Result, Error> { - if let Some(out) = self.plain_rx.pop_front() { - return Ok(Some(out)); - } - while let Some(()) = self.poll_encrypt_decrypt()? { - if let Some(out) = self.plain_rx.pop_front() { - return Ok(Some(out)); - } - } - Ok(None) + self.poll_all_enc_dec()?; + Ok(self.plain_rx.pop_front()) } fn ready(&self) -> bool { matches!(self.state, State::Ready(_)) } } + impl Debug for SansIoMachine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SansIoMachine") @@ -605,6 +606,21 @@ mod tests { (lm, rm) } + #[test] + fn sans_io() -> Result<(), Error> { + let (lss, rss) = new_connected_secret_stream(); + let (mut l, mut r) = (SansIoMachine::new_init(lss), SansIoMachine::new_resp(rss)); + + r.receive_next_messages(l.get_sendable_messages()?); + l.receive_next_messages(r.get_sendable_messages()?); + r.receive_next_messages(l.get_sendable_messages()?); + assert!(l.ready()); + l.receive_next_messages(r.get_sendable_messages()?); + assert!(r.ready()); + //assert!(dbg!(l.next_decrypted_message()?).is_none()); + Ok(()) + } + #[tokio::test] async fn test_complete_handshake() -> Result<(), Error> { let (mut lm, mut rm) = connected_machines(); From cf96bcde90cd050d091ffae4b234228fd17beb9a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Tue, 2 Sep 2025 00:48:53 -0400 Subject: [PATCH 188/206] Add constructors, Box IO, and make IO optional --- src/error.rs | 3 + src/lib.rs | 6 +- src/sstream/mod.rs | 2 +- src/sstream/sm2.rs | 145 +++++++++++++++++++++++++++++---------------- 4 files changed, 100 insertions(+), 56 deletions(-) diff --git a/src/error.rs b/src/error.rs index a48fbb5..d58e2d6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,6 +12,9 @@ pub enum Error { // TODO added by claude #[error("IoError: {0}")] IoError(String), + // Missing IO + #[error("No IO available")] + NoIoError, } impl From for Error { diff --git a/src/lib.rs b/src/lib.rs index e22713a..168400e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -141,12 +141,12 @@ mod util; /// The wire messages used by the protocol. pub mod schema; -use error::Error; +pub use error::Error; pub use builder::Builder as ProtocolBuilder; pub use channels::Channel; pub use framing::Uint24LELengthPrefixedFraming; -pub use noise::{encrypted_framed_message_channel, Encrypted, Event as NoiseEvent}; +pub use noise::{Encrypted, Event as NoiseEvent, encrypted_framed_message_channel}; // Export the needed types for Channel::take_receiver, and Channel::local_sender() pub use async_channel::{ Receiver as ChannelReceiver, SendError as ChannelSendError, Sender as ChannelSender, @@ -157,4 +157,4 @@ pub use message::Message; pub use protocol::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; pub use util::discovery_key; // Export DHT-related crypto functionality -pub use crypto::{handshake_constants, DecryptCipher, EncryptCipher, Handshake, HandshakeConfig}; +pub use crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeConfig, handshake_constants}; diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index b669d65..b774989 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -46,7 +46,7 @@ assert_eq!(msg, b"three"); Ok::<(), Box>(()) ``` */ -mod sm2; +pub mod sm2; mod statemachine; mod streamsink; diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index ec5ae6c..4689cbf 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -13,7 +13,9 @@ use tracing::{instrument, trace, warn}; use crate::{ error::Error, - sstream::{EncryptorReady, HsMsgSent, Initiator, Ready, Responder, SecStream, Start}, + sstream::{ + EncryptorReady, HsMsgSent, Initiator, PUBLIC_KEYLEN, Ready, Responder, SecStream, Start, + }, }; pub(crate) enum State { @@ -44,7 +46,6 @@ impl Debug for State { /// Like [`Machine`] but no IO struct SansIoMachine { - is_initiator: bool, state: State, encrypted_tx: VecDeque>, encrypted_rx: VecDeque, Error>>, @@ -53,10 +54,18 @@ struct SansIoMachine { } impl SansIoMachine { + fn new(state: State) -> Self { + Self { + state, + encrypted_tx: Default::default(), + encrypted_rx: Default::default(), + plain_tx: Default::default(), + plain_rx: Default::default(), + } + } fn new_init(state: SecStream>) -> Self { Self { state: State::InitiatorStart(state), - is_initiator: true, encrypted_tx: Default::default(), encrypted_rx: Default::default(), plain_tx: Default::default(), @@ -67,7 +76,6 @@ impl SansIoMachine { fn new_resp(state: SecStream>) -> Self { Self { state: State::RespStart(state), - is_initiator: false, encrypted_tx: Default::default(), encrypted_rx: Default::default(), plain_tx: Default::default(), @@ -94,7 +102,6 @@ impl SansIoMachine { fn poll_encrypt_decrypt(&mut self) -> Result, Error> { trace!( state =? self.state, - init = self.is_initiator, plain_tx = self.plain_tx.len(), plain_rx = self.plain_rx.len(), enc_tx = self.encrypted_tx.len(), @@ -222,7 +229,6 @@ impl SansIoMachine { impl Debug for SansIoMachine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SansIoMachine") - .field("is_initiator", &self.is_initiator) .field("state", &self.state) .field("encrypted_tx", &self.encrypted_tx.len()) .field("encrypted_rx", &self.encrypted_rx.len()) @@ -233,20 +239,29 @@ impl Debug for SansIoMachine { } #[derive(Debug, PartialEq)] -pub(crate) enum Event { +pub enum Event { HandshakePayload(Vec), Message(Vec), ErrStuff(Error), } +pub trait MachineIo: + Stream, Error>> + Sink> + Send + Unpin + 'static +{ +} + +impl, Error>> + Sink> + Send + Unpin + 'static> MachineIo + for T +{ +} /// For each tx/rx VecDeque messages go in with `.push_back` then taken out with `.pop_front`. /// If a message should skip the line it should be inserted with `.push_front`. -pub(crate) struct Machine { - io: IO, +pub struct Machine { + io: Option>>, inner: SansIoMachine, } -impl Debug for Machine { +impl Debug for Machine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Machine") .field("io", &"") @@ -255,22 +270,46 @@ impl Debug for Machine { } } -impl Machine -where - IO: Stream, Error>> + Sink> + Send + Unpin + 'static, -{ - fn new_init(io: IO, state: SecStream>) -> Self { +impl Machine { + pub fn new_dht_init( + io: Option>>, + remote_pub_key: &[u8; PUBLIC_KEYLEN], + prologue: &[u8], + ) -> Result { + let ss = SecStream::new_initiator(remote_pub_key, prologue)?; + let state = State::InitiatorStart(ss); + let inner = SansIoMachine::new(state); + Ok(Self { io, inner }) + } + + fn new_dht_resp( + io: Option>>, + private: &[u8], + ) -> Result { + let ss = SecStream::new_responder(private)?; + let state = State::RespStart(ss); + let inner = SansIoMachine::new(state); + Ok(Self { io, inner }) + } + + fn new(io: Option>>, inner: SansIoMachine) -> Self { + Self { io, inner } + } + + fn new_init(io: Box>, state: SecStream>) -> Self { Self { - io, + io: Some(io), inner: SansIoMachine::new_init(state), } } - fn new_resp(io: IO, state: SecStream>) -> Self { + + fn new_resp(io: Box>, state: SecStream>) -> Self { Self { - io, + io: Some(io), inner: SansIoMachine::new_resp(state), } } + async fn complete_handshake(&mut self) -> Result<(), Error> { use futures::SinkExt; use futures::StreamExt; @@ -293,13 +332,19 @@ where self.inner.handshake_start(payload) } + fn get_io(&mut self) -> Result<&mut Box>, Error> { + if let Some(io) = self.io.as_mut() { + return Ok(io); + } + Err(Error::NoIoError) + } + #[instrument(skip_all, err)] /// Encrypt outgoing messages, and decrypt encomming messages. /// This also processes messages to complete the handshake. fn poll_encrypt_decrypt(&mut self) -> Result, Error> { trace!( state =? self.inner.state, - init = self.inner.is_initiator, plain_tx = self.inner.plain_tx.len(), plain_rx = self.inner.plain_rx.len(), enc_tx = self.inner.encrypted_tx.len(), @@ -309,9 +354,11 @@ where } /// pull in new incomming encrypted messages - #[instrument(skip_all, fields(init = self.inner.is_initiator))] + #[instrument(skip_all)] fn poll_incoming_encrypted(&mut self, cx: &mut Context<'_>) -> Poll<()> { - while let Poll::Ready(Some(result)) = Pin::new(&mut self.io).poll_next(cx) { + while let Poll::Ready(Some(result)) = + Pin::new(&mut self.get_io().expect("Missing IO")).poll_next(cx) + { self.inner.encrypted_rx.push_back(result); } Poll::Ready(()) @@ -320,9 +367,11 @@ where #[instrument(skip_all)] fn poll_outgoing_encrypted(&mut self, cx: &mut Context<'_>) -> Poll> { while let Some(msg) = self.inner.encrypted_tx.pop_front() { - match Pin::new(&mut self.io).poll_ready(cx) { + match Pin::new(&mut self.get_io().unwrap()).poll_ready(cx) { Poll::Ready(Ok(())) => { - if let Err(_e) = Pin::new(&mut self.io).start_send(msg) { + if let Err(_e) = + Pin::new(&mut self.get_io().expect("Missing IO")).start_send(msg) + { return Poll::Ready(Err(Error::IoError( "Send failed: TODO Error should have fmt::Debug here".into(), ))); @@ -340,7 +389,7 @@ where } } - match Pin::new(&mut self.io).poll_flush(cx) { + match Pin::new(&mut self.get_io().expect("Missing IO")).poll_flush(cx) { Poll::Ready(Ok(())) => Poll::Ready(Ok(())), Poll::Ready(Err(_e)) => Poll::Ready(Err(Error::IoError( "Flush failed: TODO Error should have fmt::Debug here".into(), @@ -350,13 +399,10 @@ where } } -impl Stream for Machine -where - IO: Stream, Error>> + Sink> + Send + Unpin + 'static, -{ +impl Stream for Machine { type Item = Event; - #[instrument(skip_all, fields(init = self.inner.is_initiator))] + #[instrument(skip_all)] fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { // 1. First, try to return any ready plaintext messages @@ -402,10 +448,7 @@ where } } -impl Sink> for Machine -where - IO: Stream, Error>> + Sink> + Send + Unpin + 'static, -{ +impl Sink> for Machine { type Error = Error; #[instrument(skip_all)] @@ -482,11 +525,13 @@ where match self.as_mut().poll_flush(cx) { Poll::Ready(Ok(())) => { // Now close the underlying IO - Pin::new(&mut self.io).poll_close(cx).map_err(|_e| { - Error::IoError(format!( - "Close failed TODO Error should have fmt::debug here" - )) - }) + Pin::new(&mut self.get_io().expect("Missing IO")) + .poll_close(cx) + .map_err(|_e| { + Error::IoError(format!( + "Close failed TODO Error should have fmt::debug here" + )) + }) } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, @@ -575,10 +620,7 @@ mod tests { ) } - fn new_connected_streams() -> ( - impl Stream, Error>> + Sink> + Debug, - impl Stream, Error>> + Sink> + Debug, - ) { + fn new_connected_streams() -> (impl MachineIo, impl MachineIo) { let (left_tx, left_rx) = mpsc::unbounded(); let res_left_rx = left_rx.map(|msg: Vec| Ok::<_, Error>(msg)); @@ -596,13 +638,13 @@ mod tests { (left, right) } - fn connected_machines() -> ( - Machine, Error>> + Sink> + Debug>, - Machine, Error>> + Sink> + Debug>, - ) { + fn connected_machines() -> (Machine, Machine) { let (lss, rss) = new_connected_secret_stream(); let (lio, rio) = new_connected_streams(); - let (lm, rm) = (Machine::new_init(lio, lss), Machine::new_resp(rio, rss)); + let (lm, rm) = ( + Machine::new_init(Box::new(lio), lss), + Machine::new_resp(Box::new(rio), rss), + ); (lm, rm) } @@ -721,7 +763,7 @@ mod tests { let initiator_state = SecStream::new_initiator(&remote_key, &[])?; let (mock_io, _io_tx, _out_rx) = create_mock_io_pair(); - let mut machine = Machine::new_init(mock_io, initiator_state); + let mut machine = Machine::new_init(Box::new(mock_io), initiator_state); // Test that stream returns None when no data is available let mut stream = Box::pin(&mut machine); @@ -742,7 +784,7 @@ mod tests { let initiator_state = SecStream::new_initiator(&remote_key, &[])?; let (mock_io, _io_tx, mut out_rx) = create_mock_io_pair(); - let mut machine = Machine::new_init(mock_io, initiator_state); + let mut machine = Machine::new_init(Box::new(mock_io), initiator_state); // Start handshake let payload = b"handshake payload"; @@ -775,11 +817,10 @@ mod tests { let initiator_state = SecStream::new_initiator(&remote_key, &[])?; let (mock_io, _io_tx, _out_rx) = create_mock_io_pair(); - let machine = Machine::new_init(mock_io, initiator_state); + let machine = Machine::new_init(Box::new(mock_io), initiator_state); // Verify initial state assert!(matches!(machine.inner.state, State::InitiatorStart(_))); - assert!(machine.inner.is_initiator); assert!(machine.inner.plain_tx.is_empty()); assert!(machine.inner.plain_rx.is_empty()); @@ -792,7 +833,7 @@ mod tests { let initiator_state = SecStream::new_initiator(&remote_key, &[])?; let (mock_io, _io_tx, _out_rx) = create_mock_io_pair(); - let mut machine = Machine::new_init(mock_io, initiator_state); + let mut machine = Machine::new_init(Box::new(mock_io), initiator_state); // poll_ready should always succeed since we queue internally let mut sink = Box::pin(&mut machine); From 5d910cf22a07ab3a0ee0eac06967edafe3f6ca09 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 6 Sep 2025 17:31:36 -0400 Subject: [PATCH 189/206] rust fmt --- src/test_utils.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_utils.rs b/src/test_utils.rs index 8afa83e..add405a 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -6,8 +6,8 @@ use std::{ }; use futures::{ - channel::mpsc::{unbounded, UnboundedReceiver as Receiver, UnboundedSender as Sender}, Sink, Stream, StreamExt, + channel::mpsc::{UnboundedReceiver as Receiver, UnboundedSender as Sender, unbounded}, }; #[derive(Debug)] @@ -77,7 +77,7 @@ pub(crate) fn log() { static START_LOGS: std::sync::OnceLock<()> = std::sync::OnceLock::new(); START_LOGS.get_or_init(|| { use tracing_subscriber::{ - layer::SubscriberExt as _, util::SubscriberInitExt as _, EnvFilter, + EnvFilter, layer::SubscriberExt as _, util::SubscriberInitExt as _, }; let env_filter = EnvFilter::from_default_env(); // Reads `RUST_LOG` environment variable From b44c07293e2bfd7a8c416a5cf01933ec5a3f93dc Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 7 Sep 2025 14:11:23 -0400 Subject: [PATCH 190/206] lint --- src/crypto/curve.rs | 4 ++-- src/crypto/handshake.rs | 12 ++++++------ src/crypto/mod.rs | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/crypto/curve.rs b/src/crypto/curve.rs index e71d582..4b1f482 100644 --- a/src/crypto/curve.rs +++ b/src/crypto/curve.rs @@ -1,4 +1,4 @@ -use hypercore::{generate_signing_key, SecretKey, SigningKey, VerifyingKey}; +use hypercore::{SecretKey, SigningKey, VerifyingKey, generate_signing_key}; use sha2::Digest; use snow::{ params::{CipherChoice, DHChoice, HashChoice}, @@ -78,7 +78,7 @@ impl Dh for Ed25519 { } #[derive(Default)] -pub(super) struct CurveResolver; +pub struct CurveResolver; impl CryptoResolver for CurveResolver { fn resolve_dh(&self, choice: &DHChoice) -> Option> { diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 8c7ccfd..8fdc27d 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -1,13 +1,13 @@ use super::curve::CurveResolver; use blake2::{ - digest::{typenum::U32, FixedOutput, Update}, Blake2bMac, + digest::{FixedOutput, Update, typenum::U32}, }; use handshake_constants::name_from_pattern; use snow::{ + Builder, Error as SnowError, HandshakeState, params::HandshakePattern, resolvers::{DefaultResolver, FallbackResolver}, - Builder, Error as SnowError, HandshakeState, }; use std::io::{Error, ErrorKind, Result}; use tracing::instrument; @@ -282,7 +282,7 @@ fn new_handshake_state( let hs_name = name_from_pattern(&config.pattern)?; - let noise_params = NoiseParams::new( + let params = NoiseParams::new( hs_name.to_string(), BaseChoice::Noise, HandshakeChoice { @@ -295,7 +295,7 @@ fn new_handshake_state( ); let builder: Builder<'_> = Builder::with_resolver( - noise_params, + params, Box::new(FallbackResolver::new( Box::::default(), Box::::default(), @@ -319,7 +319,7 @@ fn new_handshake_state( } } - let handshake_state = if is_initiator { + let state = if is_initiator { tracing::debug!("building initiator with pattern {:?}", config.pattern); builder.build_initiator()? } else { @@ -327,7 +327,7 @@ fn new_handshake_state( builder.build_responder()? }; - Ok((handshake_state, key_pair.public)) + Ok((state, key_pair.public)) } fn map_err(e: SnowError) -> Error { diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index af87688..38fd25e 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -4,4 +4,4 @@ mod handshake; pub(crate) use cipher::write_stream_id; pub use cipher::{DecryptCipher, EncryptCipher}; pub(crate) use handshake::HandshakeResult; -pub use handshake::{handshake_constants, Handshake, HandshakeConfig}; +pub use handshake::{Handshake, HandshakeConfig, handshake_constants}; From 1bbda8bde57f5af07514e97ebd53bec7f971ffdb Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 7 Sep 2025 16:06:32 -0400 Subject: [PATCH 191/206] rm PartialEq on Error --- src/error.rs | 8 +++++--- src/framing.rs | 2 +- src/sstream/sm2.rs | 30 +++++++++++------------------- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/src/error.rs b/src/error.rs index d58e2d6..d7dbf0c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -#[derive(Debug, thiserror::Error, PartialEq)] +#[derive(Debug, thiserror::Error)] pub enum Error { #[error("Error from `snow`: {0}")] Snow(#[from] snow::Error), @@ -13,8 +13,10 @@ pub enum Error { #[error("IoError: {0}")] IoError(String), // Missing IO - #[error("No IO available")] - NoIoError, + #[error("Machine IO is not set.")] + NoIoSetError, + #[error("{0}")] + StdIoError(#[from] std::io::Error), } impl From for Error { diff --git a/src/framing.rs b/src/framing.rs index 760b7c6..fcb0e48 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -197,7 +197,7 @@ where } #[cfg(test)] pub(crate) mod test { - use crate::{test_utils::log, Duplex}; + use crate::{Duplex, test_utils::log}; use super::*; use futures::{SinkExt, StreamExt}; diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 4689cbf..5ac74e6 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -238,7 +238,7 @@ impl Debug for SansIoMachine { } } -#[derive(Debug, PartialEq)] +#[derive(Debug)] pub enum Event { HandshakePayload(Vec), Message(Vec), @@ -679,11 +679,11 @@ mod tests { let (mut l, mut r) = new_connected_streams(); let (a, b) = join!(l.send(b"yo".to_vec()), r.next()); assert!(a.is_ok()); - assert_eq!(b, Some(Ok(b"yo".to_vec()))); + assert_eq!(b.unwrap()?, b"yo".to_vec()); let (a, b) = join!(r.send(b"yo".to_vec()), l.next()); assert!(a.is_ok()); - assert_eq!(b, Some(Ok(b"yo".to_vec()))); + assert_eq!(b.unwrap()?, b"yo".to_vec()); Ok(()) } #[tokio::test] @@ -694,8 +694,8 @@ mod tests { lm.handshake_start(&payload); //let (lres, rres) = join!(lm.send(b"foo"), rm.next()); // TODO this hangs let (lres, rres) = join!(lm.flush(), rm.next()); - assert_eq!(rres, Some(Event::HandshakePayload(payload))); - assert_eq!(lres, Ok(())); + assert!(matches!(rres, Some(Event::HandshakePayload(payload)))); + assert_eq!(lres?, ()); Ok(()) } @@ -709,24 +709,16 @@ mod tests { let (Some(lr), Some(rr)) = join!(lm.next(), rm.next()) else { panic!() }; - assert_eq!( - (lr, rr), - ( - Event::HandshakePayload(vec![]), - Event::HandshakePayload(vec![]) - ), - ); + + let (empty, rtol, ltor): (Vec, _, _) = (vec![], b"rtol".to_vec(), b"ltor".to_vec()); + assert!(matches!(lr, Event::HandshakePayload(empty))); + assert!(matches!(rr, Event::HandshakePayload(empty))); let (Some(lr), Some(rr)) = join!(lm.next(), rm.next()) else { panic!() }; - assert_eq!( - (lr, rr), - ( - Event::Message(b"rtol".to_vec()), - Event::Message(b"ltor".to_vec()) - ) - ); + assert!(matches!(lr, Event::Message(rtol))); + assert!(matches!(rr, Event::Message(ltor))); Ok(()) } From 31eca88b27de943e16479f7d23a648b8892cab43 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 7 Sep 2025 16:06:46 -0400 Subject: [PATCH 192/206] add ext_test_utils --- src/lib.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 168400e..15527da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -158,3 +158,15 @@ pub use protocol::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; pub use util::discovery_key; // Export DHT-related crypto functionality pub use crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeConfig, handshake_constants}; + +pub mod ext_test_utils { + use crate::sstream::PARAMS; + use snow::{Builder, Keypair}; + + pub fn new_key_pair() -> Keypair { + let kp = Builder::new(PARAMS.parse().unwrap()) + .generate_keypair() + .unwrap(); + kp + } +} From de3f0a7d4dd49cdff3d9e73c374d3f34d4fec005 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 7 Sep 2025 16:07:09 -0400 Subject: [PATCH 193/206] Add methods for Machine manual tx/rx handling --- src/sstream/mod.rs | 7 ++++--- src/sstream/sm2.rs | 24 ++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index b774989..c5d48af 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -148,6 +148,8 @@ impl Debug for Ready { } } +pub const DHT_PARAM_STR: &str = "Noise_IK_Ed25519_ChaChaPoly_BLAKE2b"; + impl SecStream> { /// Create an initiator of a secret stream pub fn new_initiator( @@ -155,10 +157,10 @@ impl SecStream> { prologue: &[u8], ) -> Result { let params: NoiseParams = PARAMS.parse().expect("known to work"); - let kp = Builder::new(params.clone()).generate_keypair()?; + let key_pair = Builder::new(params.clone()).generate_keypair()?; let state = Builder::new(params.clone()) - .local_private_key(&kp.private)? .prologue(prologue)? + .local_private_key(&key_pair.private)? .remote_public_key(remote_public_key.as_slice())? .build_initiator()?; @@ -171,7 +173,6 @@ impl SecStream> { }, }) } - /// Create the first message the initiator sends to the responder pub fn write_msg( mut self, diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 5ac74e6..b8f2f6c 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -202,6 +202,10 @@ impl SansIoMachine { self.poll_all_enc_dec()?; Ok(self.encrypted_tx.drain(..).collect()) } + fn get_next_sendable_message(&mut self) -> Result>, Error> { + self.poll_all_enc_dec()?; + Ok(self.encrypted_tx.pop_front()) + } fn receive_next_messages(&mut self, encrypted_messages: Vec>) { self.encrypted_rx @@ -328,15 +332,27 @@ impl Machine { } #[instrument(skip_all, err)] - fn handshake_start(&mut self, payload: &[u8]) -> Result<(), Error> { + pub fn handshake_start(&mut self, payload: &[u8]) -> Result<(), Error> { self.inner.handshake_start(payload) } + pub fn get_next_sendable_message(&mut self) -> Result>, Error> { + self.inner.get_next_sendable_message() + } + pub fn receive_next(&mut self, encrypted_msg: Vec) { + self.inner.receive_next(encrypted_msg) + } + pub fn next_decrypted_message(&mut self) -> Result, Error> { + self.inner.next_decrypted_message() + } fn get_io(&mut self) -> Result<&mut Box>, Error> { if let Some(io) = self.io.as_mut() { return Ok(io); } - Err(Error::NoIoError) + Err(Error::NoIoSetError) + } + pub fn set_io(&mut self, io: impl MachineIo) { + self.io = Some(Box::new(io)); } #[instrument(skip_all, err)] @@ -397,6 +413,10 @@ impl Machine { Poll::Pending => Poll::Pending, } } + + pub fn ready(&self) -> bool { + self.inner.ready() + } } impl Stream for Machine { From 2a559c6e0bd765c2b24b6c58051727468e9c1ee2 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 7 Sep 2025 18:42:59 -0400 Subject: [PATCH 194/206] WIP refactoring Machine to use Sink::Error = std::io::Error cargo check works, cargo c --all-targets fails in tests. MockIo is fucked. --- src/sstream/sm2.rs | 58 ++++++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index b8f2f6c..692676c 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -254,14 +254,16 @@ pub trait MachineIo: { } -impl, Error>> + Sink> + Send + Unpin + 'static> MachineIo - for T +impl MachineIo for T +where + T: Stream, Error>> + Sink> + Send + Unpin + 'static, + >>::Error: Into + std::fmt::Debug, { } /// For each tx/rx VecDeque messages go in with `.push_back` then taken out with `.pop_front`. /// If a message should skip the line it should be inserted with `.push_front`. pub struct Machine { - io: Option>>, + io: Option>>, inner: SansIoMachine, } @@ -276,7 +278,7 @@ impl Debug for Machine { impl Machine { pub fn new_dht_init( - io: Option>>, + io: Option>>, remote_pub_key: &[u8; PUBLIC_KEYLEN], prologue: &[u8], ) -> Result { @@ -287,7 +289,7 @@ impl Machine { } fn new_dht_resp( - io: Option>>, + io: Option>>, private: &[u8], ) -> Result { let ss = SecStream::new_responder(private)?; @@ -296,18 +298,24 @@ impl Machine { Ok(Self { io, inner }) } - fn new(io: Option>>, inner: SansIoMachine) -> Self { + fn new(io: Option>>, inner: SansIoMachine) -> Self { Self { io, inner } } - fn new_init(io: Box>, state: SecStream>) -> Self { + fn new_init( + io: Box>, + state: SecStream>, + ) -> Self { Self { io: Some(io), inner: SansIoMachine::new_init(state), } } - fn new_resp(io: Box>, state: SecStream>) -> Self { + fn new_resp( + io: Box>, + state: SecStream>, + ) -> Self { Self { io: Some(io), inner: SansIoMachine::new_resp(state), @@ -345,14 +353,14 @@ impl Machine { pub fn next_decrypted_message(&mut self) -> Result, Error> { self.inner.next_decrypted_message() } - fn get_io(&mut self) -> Result<&mut Box>, Error> { + fn get_io(&mut self) -> Result<&mut Box>, Error> { if let Some(io) = self.io.as_mut() { return Ok(io); } Err(Error::NoIoSetError) } - pub fn set_io(&mut self, io: impl MachineIo) { - self.io = Some(Box::new(io)); + pub fn set_io(&mut self, io: Box>) { + self.io = Some(io); } #[instrument(skip_all, err)] @@ -570,7 +578,7 @@ mod tests { // Mock IO that implements Stream + Sink for testing #[derive(Debug)] - struct MockIo, Error>>> { + struct MockIo, std::io::Error>>> { receiver: S, sender: mpsc::UnboundedSender>, } @@ -583,8 +591,8 @@ mod tests { } } - impl, Error>>> Sink> for MockIo { - type Error = Error; + impl, std::io::Error>>> Sink> for MockIo { + type Error = std::io::Error; fn poll_ready( self: Pin<&mut Self>, @@ -594,9 +602,12 @@ mod tests { } fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { - self.sender - .unbounded_send(item) - .map_err(|_| Error::InvalidState("Send failed".into())) + self.sender.unbounded_send(item).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::Other, + Error::InvalidState("Send failed".into()), + ) + }) } fn poll_flush( @@ -615,8 +626,8 @@ mod tests { } fn create_mock_io_pair() -> ( - MockIo, Error>>>, - mpsc::UnboundedSender, Error>>, + MockIo, std::io::Error>>>, + mpsc::UnboundedSender, std::io::Error>>, mpsc::UnboundedReceiver>, ) { let (io_tx, io_rx) = mpsc::unbounded(); @@ -640,12 +651,15 @@ mod tests { ) } - fn new_connected_streams() -> (impl MachineIo, impl MachineIo) { + fn new_connected_streams() -> ( + impl MachineIo, + impl MachineIo, + ) { let (left_tx, left_rx) = mpsc::unbounded(); - let res_left_rx = left_rx.map(|msg: Vec| Ok::<_, Error>(msg)); + let res_left_rx = left_rx.map(|msg: Vec| Ok::<_, std::io::Error>(msg)); let (right_tx, right_rx) = mpsc::unbounded(); - let res_right_rx = right_rx.map(|msg: Vec| Ok::<_, Error>(msg)); + let res_right_rx = right_rx.map(|msg: Vec| Ok::<_, std::io::Error>(msg)); let left = MockIo { sender: left_tx, From 648ee0071f57c7f546aac9ff8cd51793d412490a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sun, 7 Sep 2025 23:01:26 -0400 Subject: [PATCH 195/206] Replace errors in Machine with std::io::Error --- src/error.rs | 6 ++++ src/sstream/sm2.rs | 74 ++++++++++++++++++++++++---------------------- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/src/error.rs b/src/error.rs index d7dbf0c..e12a932 100644 --- a/src/error.rs +++ b/src/error.rs @@ -24,3 +24,9 @@ impl From for Error { Error::SecretStream(value) } } + +impl From for std::io::Error { + fn from(value: Error) -> Self { + std::io::Error::new(std::io::ErrorKind::Other, value) + } +} diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 692676c..fe1b466 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -2,6 +2,7 @@ use std::{ collections::VecDeque, fmt::Debug, + io::{Error as IoError, ErrorKind}, mem::replace, pin::Pin, task::{Context, Poll}, @@ -18,6 +19,10 @@ use crate::{ }, }; +fn err(msg: &str) -> IoError { + IoError::new(ErrorKind::Other, Error::IoError(msg.to_string())) +} + pub(crate) enum State { InitiatorStart(SecStream>), InitiatorSent(SecStream>), @@ -48,7 +53,7 @@ impl Debug for State { struct SansIoMachine { state: State, encrypted_tx: VecDeque>, - encrypted_rx: VecDeque, Error>>, + encrypted_rx: VecDeque, std::io::Error>>, plain_tx: VecDeque>, plain_rx: VecDeque, } @@ -84,7 +89,7 @@ impl SansIoMachine { } #[instrument(skip_all, err)] - fn handshake_start(&mut self, payload: &[u8]) -> Result<(), Error> { + fn handshake_start(&mut self, payload: &[u8]) -> Result<(), std::io::Error> { match replace(&mut self.state, State::Invalid) { State::InitiatorStart(s) => { let (s2, out) = s.write_msg(Some(payload))?; @@ -99,7 +104,7 @@ impl SansIoMachine { #[instrument(skip_all, err)] /// Encrypt outgoing messages, and decrypt encomming messages. /// This also processes messages to complete the handshake. - fn poll_encrypt_decrypt(&mut self) -> Result, Error> { + fn poll_encrypt_decrypt(&mut self) -> Result, std::io::Error> { trace!( state =? self.state, plain_tx = self.plain_tx.len(), @@ -185,12 +190,15 @@ impl SansIoMachine { self.state = State::InitiatorSent(s2); Ok(Some(())) } - State::Invalid => Err(Error::InvalidState("Invalid state".into())), + State::Invalid => Err(IoError::new( + ErrorKind::Other, + Error::InvalidState("Invalid state".into()), + )), } } /// Do as much work as possible encrypting plaintext and decrypting ciphertext - fn poll_all_enc_dec(&mut self) -> Result, Error> { + fn poll_all_enc_dec(&mut self) -> Result, IoError> { let mut made_progress = false; while self.poll_encrypt_decrypt()?.is_some() { made_progress = true; @@ -198,11 +206,11 @@ impl SansIoMachine { Ok(made_progress.then_some(())) } - fn get_sendable_messages(&mut self) -> Result>, Error> { + fn get_sendable_messages(&mut self) -> Result>, IoError> { self.poll_all_enc_dec()?; Ok(self.encrypted_tx.drain(..).collect()) } - fn get_next_sendable_message(&mut self) -> Result>, Error> { + fn get_next_sendable_message(&mut self) -> Result>, IoError> { self.poll_all_enc_dec()?; Ok(self.encrypted_tx.pop_front()) } @@ -220,7 +228,7 @@ impl SansIoMachine { self.plain_tx.push_back(msg); } - fn next_decrypted_message(&mut self) -> Result, Error> { + fn next_decrypted_message(&mut self) -> Result, IoError> { self.poll_all_enc_dec()?; Ok(self.plain_rx.pop_front()) } @@ -246,17 +254,17 @@ impl Debug for SansIoMachine { pub enum Event { HandshakePayload(Vec), Message(Vec), - ErrStuff(Error), + ErrStuff(IoError), } pub trait MachineIo: - Stream, Error>> + Sink> + Send + Unpin + 'static + Stream, IoError>> + Sink> + Send + Unpin + 'static { } impl MachineIo for T where - T: Stream, Error>> + Sink> + Send + Unpin + 'static, + T: Stream, IoError>> + Sink> + Send + Unpin + 'static, >>::Error: Into + std::fmt::Debug, { } @@ -322,7 +330,7 @@ impl Machine { } } - async fn complete_handshake(&mut self) -> Result<(), Error> { + async fn complete_handshake(&mut self) -> Result<(), IoError> { use futures::SinkExt; use futures::StreamExt; @@ -340,24 +348,24 @@ impl Machine { } #[instrument(skip_all, err)] - pub fn handshake_start(&mut self, payload: &[u8]) -> Result<(), Error> { + pub fn handshake_start(&mut self, payload: &[u8]) -> Result<(), IoError> { self.inner.handshake_start(payload) } - pub fn get_next_sendable_message(&mut self) -> Result>, Error> { + pub fn get_next_sendable_message(&mut self) -> Result>, IoError> { self.inner.get_next_sendable_message() } pub fn receive_next(&mut self, encrypted_msg: Vec) { self.inner.receive_next(encrypted_msg) } - pub fn next_decrypted_message(&mut self) -> Result, Error> { + pub fn next_decrypted_message(&mut self) -> Result, IoError> { self.inner.next_decrypted_message() } - fn get_io(&mut self) -> Result<&mut Box>, Error> { + fn get_io(&mut self) -> Result<&mut Box>, IoError> { if let Some(io) = self.io.as_mut() { return Ok(io); } - Err(Error::NoIoSetError) + Err(IoError::new(ErrorKind::Other, Error::NoIoSetError)) } pub fn set_io(&mut self, io: Box>) { self.io = Some(io); @@ -366,7 +374,7 @@ impl Machine { #[instrument(skip_all, err)] /// Encrypt outgoing messages, and decrypt encomming messages. /// This also processes messages to complete the handshake. - fn poll_encrypt_decrypt(&mut self) -> Result, Error> { + fn poll_encrypt_decrypt(&mut self) -> Result, IoError> { trace!( state =? self.inner.state, plain_tx = self.inner.plain_tx.len(), @@ -389,20 +397,20 @@ impl Machine { } #[instrument(skip_all)] - fn poll_outgoing_encrypted(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_outgoing_encrypted(&mut self, cx: &mut Context<'_>) -> Poll> { while let Some(msg) = self.inner.encrypted_tx.pop_front() { match Pin::new(&mut self.get_io().unwrap()).poll_ready(cx) { Poll::Ready(Ok(())) => { if let Err(_e) = Pin::new(&mut self.get_io().expect("Missing IO")).start_send(msg) { - return Poll::Ready(Err(Error::IoError( + return Poll::Ready(Err(err( "Send failed: TODO Error should have fmt::Debug here".into(), ))); } } Poll::Ready(Err(_e)) => { - return Poll::Ready(Err(Error::IoError( + return Poll::Ready(Err(err( "IO error: TODO Error should have fmt::Debug here".into(), ))); } @@ -415,7 +423,7 @@ impl Machine { match Pin::new(&mut self.get_io().expect("Missing IO")).poll_flush(cx) { Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(_e)) => Poll::Ready(Err(Error::IoError( + Poll::Ready(Err(_e)) => Poll::Ready(Err(err( "Flush failed: TODO Error should have fmt::Debug here".into(), ))), Poll::Pending => Poll::Pending, @@ -477,7 +485,7 @@ impl Stream for Machine { } impl Sink> for Machine { - type Error = Error; + type Error = IoError; #[instrument(skip_all)] fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -555,11 +563,7 @@ impl Sink> for Machine { // Now close the underlying IO Pin::new(&mut self.get_io().expect("Missing IO")) .poll_close(cx) - .map_err(|_e| { - Error::IoError(format!( - "Close failed TODO Error should have fmt::debug here" - )) - }) + .map_err(|_e| err("Close failed TODO Error should have fmt::debug here")) } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, @@ -578,13 +582,16 @@ mod tests { // Mock IO that implements Stream + Sink for testing #[derive(Debug)] - struct MockIo, std::io::Error>>> { + struct MockIo + where + S: Stream, std::io::Error>>, + { receiver: S, sender: mpsc::UnboundedSender>, } - impl, Error>> + Unpin> Stream for MockIo { - type Item = Result, Error>; + impl, IoError>> + Unpin> Stream for MockIo { + type Item = Result, IoError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.receiver).poll_next(cx) @@ -603,10 +610,7 @@ mod tests { fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { self.sender.unbounded_send(item).map_err(|_| { - std::io::Error::new( - std::io::ErrorKind::Other, - Error::InvalidState("Send failed".into()), - ) + IoError::new(ErrorKind::Other, Error::InvalidState("Send failed".into())) }) } From d6ac1aa73eecb545570df99db8c8c384affec56f Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 19 Sep 2025 00:13:36 -0400 Subject: [PATCH 196/206] rm logs --- src/framing.rs | 5 +---- src/message.rs | 4 ---- src/mqueue.rs | 2 -- 3 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/framing.rs b/src/framing.rs index fcb0e48..1be1611 100644 --- a/src/framing.rs +++ b/src/framing.rs @@ -197,7 +197,7 @@ where } #[cfg(test)] pub(crate) mod test { - use crate::{Duplex, test_utils::log}; + use crate::Duplex; use super::*; use futures::{SinkExt, StreamExt}; @@ -224,7 +224,6 @@ pub(crate) mod test { #[tokio::test] async fn input() -> Result<()> { - log(); let (left, mut right) = duplex(64); let mut lp = Uint24LELengthPrefixedFraming::new(left); let input = b"yelp"; @@ -238,7 +237,6 @@ pub(crate) mod test { } #[tokio::test] async fn stream_many() -> Result<()> { - log(); let (left, mut right) = duplex(64); let mut lp = Uint24LELengthPrefixedFraming::new(left); let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; @@ -256,7 +254,6 @@ pub(crate) mod test { } #[tokio::test] async fn sink_many() -> Result<()> { - log(); let (left, mut right) = duplex(64); let mut lp = Uint24LELengthPrefixedFraming::new(left); let data: &[&[u8]] = &[b"yolo", b"squalor", b"idle", b"hello", b"stuff"]; diff --git a/src/message.rs b/src/message.rs index 7665df4..ffb8852 100644 --- a/src/message.rs +++ b/src/message.rs @@ -520,9 +520,6 @@ impl VecEncodable for ChannelMessage { #[cfg(test)] mod tests { - - use crate::test_utils::log; - use super::*; use hypercore::{ DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade, @@ -647,7 +644,6 @@ mod tests { length: 4, }); let msgs = vec![ChannelMessage::new(1, one), ChannelMessage::new(1, two)]; - log(); let buff = msgs.to_encoded_bytes()?; let (result, rest) = as CompactEncoding>::decode(&buff)?; assert!(rest.is_empty()); diff --git a/src/mqueue.rs b/src/mqueue.rs index 9a2d91a..7ea43da 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -177,8 +177,6 @@ mod test { #[tokio::test] async fn mqueue() -> Result<()> { - log(); - let rtolm = new_msg(38); let ltorm = new_msg(42); From a5c573e13bba47e8d5afffcdb6a6b87817667e0e Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 19 Sep 2025 00:15:59 -0400 Subject: [PATCH 197/206] renames --- src/crypto/handshake.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 8fdc27d..8eb2169 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -280,10 +280,10 @@ fn new_handshake_state( NoiseParams, }; - let hs_name = name_from_pattern(&config.pattern)?; + let hs_pattern = name_from_pattern(&config.pattern)?; let params = NoiseParams::new( - hs_name.to_string(), + hs_pattern.to_string(), BaseChoice::Noise, HandshakeChoice { pattern: config.pattern, From 831f9c0b84251a74d00a5fe55bfb1b51d5d077a8 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 19 Sep 2025 00:17:14 -0400 Subject: [PATCH 198/206] rmunused & lint --- src/lib.rs | 12 ------------ src/message.rs | 2 +- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 15527da..168400e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -158,15 +158,3 @@ pub use protocol::{Command, CommandTx, DiscoveryKey, Event, Key, Protocol}; pub use util::discovery_key; // Export DHT-related crypto functionality pub use crypto::{DecryptCipher, EncryptCipher, Handshake, HandshakeConfig, handshake_constants}; - -pub mod ext_test_utils { - use crate::sstream::PARAMS; - use snow::{Builder, Keypair}; - - pub fn new_key_pair() -> Keypair { - let kp = Builder::new(PARAMS.parse().unwrap()) - .generate_keypair() - .unwrap(); - kp - } -} diff --git a/src/message.rs b/src/message.rs index ffb8852..3523778 100644 --- a/src/message.rs +++ b/src/message.rs @@ -237,7 +237,7 @@ impl CompactEncoding for Message { return Err(EncodingError::new( EncodingErrorKind::InvalidData, &format!("Invalid message type to decode: {typ}"), - )) + )); } }) } From 69e91628c515239b8e1824620145e8f691d8a2b3 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 19 Sep 2025 00:19:03 -0400 Subject: [PATCH 199/206] fixes related to handshake change --- src/crypto/mod.rs | 2 +- src/sstream/sm2.rs | 29 +++++++++++++++++------------ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 38fd25e..c7f522b 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -1,5 +1,5 @@ mod cipher; -mod curve; +pub(crate) mod curve; mod handshake; pub(crate) use cipher::write_stream_id; pub use cipher::{DecryptCipher, EncryptCipher}; diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index fe1b466..57808c2 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -573,11 +573,10 @@ impl Sink> for Machine { #[cfg(test)] mod tests { - use crate::{sstream::PARAMS, test_utils::log}; + use crate::sstream::hc_specific; use super::*; use futures::{SinkExt, StreamExt, channel::mpsc, join}; - use snow::Builder; use tracing::instrument; // Mock IO that implements Stream + Sink for testing @@ -646,9 +645,7 @@ mod tests { } fn new_connected_secret_stream() -> (SecStream>, SecStream>) { - let kp = Builder::new(PARAMS.parse().unwrap()) - .generate_keypair() - .unwrap(); + let kp = hc_specific::generate_keypair().unwrap(); ( SecStream::new_initiator(&kp.public.try_into().unwrap(), &[]).unwrap(), SecStream::new_responder(&kp.private).unwrap(), @@ -691,13 +688,19 @@ mod tests { let (lss, rss) = new_connected_secret_stream(); let (mut l, mut r) = (SansIoMachine::new_init(lss), SansIoMachine::new_resp(rss)); - r.receive_next_messages(l.get_sendable_messages()?); - l.receive_next_messages(r.get_sendable_messages()?); - r.receive_next_messages(l.get_sendable_messages()?); + let lx = l.get_sendable_messages()?; + r.receive_next_messages(lx); + + let rx = r.get_sendable_messages()?; // <-- here. r is responder + l.receive_next_messages(rx); + + let lx = l.get_sendable_messages()?; + r.receive_next_messages(lx); + assert!(l.ready()); - l.receive_next_messages(r.get_sendable_messages()?); + let rx = r.get_sendable_messages()?; + l.receive_next_messages(rx); assert!(r.ready()); - //assert!(dbg!(l.next_decrypted_message()?).is_none()); Ok(()) } @@ -810,8 +813,10 @@ mod tests { #[tokio::test] async fn test_machine_handshake_start() -> Result<(), Error> { - let remote_key = [4u8; 32]; - let initiator_state = SecStream::new_initiator(&remote_key, &[])?; + //let remote_key = [4u8; 32]; + let kp = hc_specific::generate_keypair().unwrap(); + let public = kp.public.try_into().unwrap(); + let initiator_state = SecStream::new_initiator(&public, &[])?; let (mock_io, _io_tx, mut out_rx) = create_mock_io_pair(); let mut machine = Machine::new_init(Box::new(mock_io), initiator_state); From b6bbe544b7c9035d19c6f1f76df3aed582882904 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 19 Sep 2025 00:20:43 -0400 Subject: [PATCH 200/206] changes from handshake change --- src/sstream/mod.rs | 78 ++++++++++++++++++++++++++++--------- src/sstream/sm2.rs | 1 - src/sstream/statemachine.rs | 6 +-- 3 files changed, 62 insertions(+), 23 deletions(-) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index c5d48af..bba1dc2 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -4,15 +4,13 @@ ``` // Excessive typing to demonstrate flow through typestates use hypercore_protocol::sstream::{ - EncryptorReady, HsDone, HsMsgSent, Initiator, Ready, Responder, SecStream, - Start, PARAMS, + EncryptorReady, HsDone, HsMsgSent, Initiator, Ready, Responder, SecStream, Start, + hc_specific::generate_keypair, }; -use crypto_secretstream::Tag; -use snow::{Builder, params::NoiseParams}; -let params: NoiseParams = PARAMS.parse().expect("known to work"); -let kp = Builder::new(params.clone()).generate_keypair()?; +let kp: snow::Keypair = generate_keypair()?; // Create an initiator and responder -let init: SecStream> = SecStream::new_initiator(&kp.public.try_into().unwrap(), &[])?; +let init: SecStream> = + SecStream::new_initiator(&kp.public.try_into().unwrap(), &[])?; let resp: SecStream> = SecStream::new_responder(&kp.private)?; // initiator sends the first handshake message, a payload can be included to send extra data to the @@ -25,7 +23,8 @@ assert_eq!(payload, b"one"); // responder sends a handshake message, which can include a payload. As well as a second // message which contains the symmetric key needed to set up the decryptor -let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = resp.write_msg(Some(b"two"))?; +let (resp, [msg1, msg2]): (SecStream, [Vec; 2]) = + resp.write_msg(Some(b"two"))?; // Initiator receives last handshake message, use handshake to create the extract payload. let (init, payload_recv): (SecStream>, Vec) = init.read_msg(&msg1)?; @@ -40,7 +39,7 @@ let mut resp: SecStream = resp.read_msg(&to_resp_final)?; // Now both sides can send and receive messages let mut msg = b"three".to_vec(); -init.push(&mut msg, &[], Tag::Message)?; +init.push(&mut msg, &[], crypto_secretstream::Tag::Message)?; let tag = resp.pull(&mut msg, &[])?; assert_eq!(msg, b"three"); Ok::<(), Box>(()) @@ -52,13 +51,13 @@ mod streamsink; use crypto_secretstream::{Header, Key, PullStream, PushStream, Tag}; use rand::rngs::OsRng; -use snow::{Builder, HandshakeState, params::NoiseParams}; +use snow::HandshakeState; use std::{fmt::Debug, marker::PhantomData}; use crate::{Error, crypto::write_stream_id}; -/// Default pattern -pub const PARAMS: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2b"; +/// NB: This is what the params SHOULD be, but hypercore uses "..Ed25519.." +//pub const PARAMS: &str = "Noise_IK_25519_ChaChaPoly_BLAKE2b"; const STREAM_ID_LENGTH: usize = 32; const RAW_HEADER_MSG_LEN: usize = STREAM_ID_LENGTH + Header::BYTES; const SNOW_CIPHERKEYLEN: usize = 32; @@ -148,7 +147,51 @@ impl Debug for Ready { } } -pub const DHT_PARAM_STR: &str = "Noise_IK_Ed25519_ChaChaPoly_BLAKE2b"; +pub mod hc_specific { + use std::sync::LazyLock; + + use snow::{ + Builder, Keypair, + params::{BaseChoice, HandshakeChoice, HandshakePattern, NoiseParams}, + resolvers::{DefaultResolver, FallbackResolver}, + }; + + use crate::Error; + + pub const PARAM_STR: &str = "Noise_IK_Ed25519_ChaChaPoly_BLAKE2b"; + static NOISE_PARAMS: LazyLock = LazyLock::new(|| { + NoiseParams::new( + PARAM_STR.to_string(), + //PARAMS.to_string(), + BaseChoice::Noise, + HandshakeChoice { + pattern: HandshakePattern::IK, + modifiers: snow::params::HandshakeModifierList { list: vec![] }, + }, + snow::params::DHChoice::Curve25519, + snow::params::CipherChoice::ChaChaPoly, + snow::params::HashChoice::Blake2b, + ) + }); + pub fn noise_params() -> &'static NoiseParams { + &*NOISE_PARAMS + } + pub(super) fn builder() -> Builder<'static> { + let params = noise_params(); + Builder::with_resolver( + params.clone(), + //Box::new(DefaultResolver::default()), + Box::new(FallbackResolver::new( + Box::::default(), + Box::::default(), + )), + ) + } + + pub fn generate_keypair() -> Result { + Ok(builder().generate_keypair()?) + } +} impl SecStream> { /// Create an initiator of a secret stream @@ -156,9 +199,9 @@ impl SecStream> { remote_public_key: &[u8; PUBLIC_KEYLEN], prologue: &[u8], ) -> Result { - let params: NoiseParams = PARAMS.parse().expect("known to work"); - let key_pair = Builder::new(params.clone()).generate_keypair()?; - let state = Builder::new(params.clone()) + let key_pair = hc_specific::generate_keypair()?; + + let state = hc_specific::builder() .prologue(prologue)? .local_private_key(&key_pair.private)? .remote_public_key(remote_public_key.as_slice())? @@ -204,8 +247,7 @@ impl SecStream> { impl SecStream> { /// Create a responder of a secret stream pub fn new_responder(private: &[u8]) -> Result { - let params: NoiseParams = PARAMS.parse().expect("known to work"); - let state = Builder::new(params.clone()) + let state = hc_specific::builder() .local_private_key(private)? .build_responder()?; Ok(Self { diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 57808c2..fa624ba 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -813,7 +813,6 @@ mod tests { #[tokio::test] async fn test_machine_handshake_start() -> Result<(), Error> { - //let remote_key = [4u8; 32]; let kp = hc_specific::generate_keypair().unwrap(); let public = kp.public.try_into().unwrap(); let initiator_state = SecStream::new_initiator(&public, &[])?; diff --git a/src/sstream/statemachine.rs b/src/sstream/statemachine.rs index 3bf6b9b..40a99f2 100644 --- a/src/sstream/statemachine.rs +++ b/src/sstream/statemachine.rs @@ -157,14 +157,12 @@ impl Manager { mod test { use snow::Builder; - use crate::sstream::PARAMS; + use crate::sstream::hc_specific; use super::*; fn new_paired() -> (SecStream>, SecStream>) { - let kp = Builder::new(PARAMS.parse().unwrap()) - .generate_keypair() - .unwrap(); + let kp = hc_specific::generate_keypair().unwrap(); ( SecStream::new_initiator(&kp.public.try_into().unwrap(), &[]).unwrap(), SecStream::new_responder(&kp.private).unwrap(), From 7fa4480f592723afc7c47fedbeac957b85b0e62a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 19 Sep 2025 00:26:37 -0400 Subject: [PATCH 201/206] lints --- src/sstream/sm2.rs | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index fa624ba..837ae48 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -1,4 +1,3 @@ -#[allow(unused)] use std::{ collections::VecDeque, fmt::Debug, @@ -19,10 +18,6 @@ use crate::{ }, }; -fn err(msg: &str) -> IoError { - IoError::new(ErrorKind::Other, Error::IoError(msg.to_string())) -} - pub(crate) enum State { InitiatorStart(SecStream>), InitiatorSent(SecStream>), @@ -190,10 +185,7 @@ impl SansIoMachine { self.state = State::InitiatorSent(s2); Ok(Some(())) } - State::Invalid => Err(IoError::new( - ErrorKind::Other, - Error::InvalidState("Invalid state".into()), - )), + State::Invalid => Err(IoError::other("Invalid state")), } } @@ -404,14 +396,14 @@ impl Machine { if let Err(_e) = Pin::new(&mut self.get_io().expect("Missing IO")).start_send(msg) { - return Poll::Ready(Err(err( - "Send failed: TODO Error should have fmt::Debug here".into(), + return Poll::Ready(Err(IoError::other( + "Send failed: TODO Error should have fmt::Debug here", ))); } } Poll::Ready(Err(_e)) => { - return Poll::Ready(Err(err( - "IO error: TODO Error should have fmt::Debug here".into(), + return Poll::Ready(Err(IoError::other( + "IO error: TODO Error should have fmt::Debug here", ))); } Poll::Pending => { @@ -423,8 +415,8 @@ impl Machine { match Pin::new(&mut self.get_io().expect("Missing IO")).poll_flush(cx) { Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(_e)) => Poll::Ready(Err(err( - "Flush failed: TODO Error should have fmt::Debug here".into(), + Poll::Ready(Err(_e)) => Poll::Ready(Err(IoError::other( + "Flush failed: TODO Error should have fmt::Debug here", ))), Poll::Pending => Poll::Pending, } @@ -563,7 +555,9 @@ impl Sink> for Machine { // Now close the underlying IO Pin::new(&mut self.get_io().expect("Missing IO")) .poll_close(cx) - .map_err(|_e| err("Close failed TODO Error should have fmt::debug here")) + .map_err(|_e| { + IoError::other("Close failed TODO Error should have fmt::debug here") + }) } Poll::Ready(Err(e)) => Poll::Ready(Err(e)), Poll::Pending => Poll::Pending, @@ -608,9 +602,9 @@ mod tests { } fn start_send(self: Pin<&mut Self>, item: Vec) -> Result<(), Self::Error> { - self.sender.unbounded_send(item).map_err(|_| { - IoError::new(ErrorKind::Other, Error::InvalidState("Send failed".into())) - }) + self.sender + .unbounded_send(item) + .map_err(|_| IoError::other("Send failed")) } fn poll_flush( @@ -736,7 +730,7 @@ mod tests { //let (lres, rres) = join!(lm.send(b"foo"), rm.next()); // TODO this hangs let (lres, rres) = join!(lm.flush(), rm.next()); assert!(matches!(rres, Some(Event::HandshakePayload(payload)))); - assert_eq!(lres?, ()); + lres?; Ok(()) } From 4727e793f24beaccf069dd823d8e995cda15be84 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 19 Sep 2025 00:28:12 -0400 Subject: [PATCH 202/206] clippy --fix --- src/channels.rs | 5 ++--- src/crypto/curve.rs | 2 +- src/crypto/handshake.rs | 4 ++-- src/error.rs | 2 +- src/mqueue.rs | 2 +- src/sstream/mod.rs | 2 +- src/sstream/sm2.rs | 8 ++++---- 7 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/channels.rs b/src/channels.rs index 64751ab..ae3031d 100644 --- a/src/channels.rs +++ b/src/channels.rs @@ -314,8 +314,8 @@ impl ChannelHandle { &mut self, message: Message, ) -> std::io::Result<()> { - if let Some(inbound_tx) = self.inbound_tx.as_mut() { - if let Err(err) = inbound_tx.try_send(message) { + if let Some(inbound_tx) = self.inbound_tx.as_mut() + && let Err(err) = inbound_tx.try_send(message) { match err { TrySendError::Full(e) => { return Err(error(format!("Sending to channel failed: {e}").as_str())) @@ -323,7 +323,6 @@ impl ChannelHandle { TrySendError::Closed(_) => {} } } - } Ok(()) } } diff --git a/src/crypto/curve.rs b/src/crypto/curve.rs index 4b1f482..7b7db75 100644 --- a/src/crypto/curve.rs +++ b/src/crypto/curve.rs @@ -78,7 +78,7 @@ impl Dh for Ed25519 { } #[derive(Default)] -pub struct CurveResolver; +pub(crate) struct CurveResolver; impl CryptoResolver for CurveResolver { fn resolve_dh(&self, choice: &DHChoice) -> Option> { diff --git a/src/crypto/handshake.rs b/src/crypto/handshake.rs index 8eb2169..7003dd1 100644 --- a/src/crypto/handshake.rs +++ b/src/crypto/handshake.rs @@ -348,6 +348,6 @@ fn replicate_capability(is_initiator: bool, key: &[u8], handshake_hash: &[u8]) - hasher.update(&seed); hasher.update(key); let hash = hasher.finalize_fixed(); - let capability = hash.as_slice().to_vec(); - capability + + hash.as_slice().to_vec() } diff --git a/src/error.rs b/src/error.rs index e12a932..0c1c482 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,6 +27,6 @@ impl From for Error { impl From for std::io::Error { fn from(value: Error) -> Self { - std::io::Error::new(std::io::ErrorKind::Other, value) + std::io::Error::other(value) } } diff --git a/src/mqueue.rs b/src/mqueue.rs index 7ea43da..923a589 100644 --- a/src/mqueue.rs +++ b/src/mqueue.rs @@ -143,7 +143,7 @@ mod test { use crate::{ encrypted_framed_message_channel, framing::test::duplex, message::ChannelMessage, - schema::NoData, test_utils::log, Encrypted, Uint24LELengthPrefixedFraming, + schema::NoData, Encrypted, Uint24LELengthPrefixedFraming, }; use super::{MessageIo, MqueueEvent}; diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index bba1dc2..7753ef9 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -174,7 +174,7 @@ pub mod hc_specific { ) }); pub fn noise_params() -> &'static NoiseParams { - &*NOISE_PARAMS + &NOISE_PARAMS } pub(super) fn builder() -> Builder<'static> { let params = noise_params(); diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 837ae48..6ea6c9d 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -1,7 +1,7 @@ use std::{ collections::VecDeque, fmt::Debug, - io::{Error as IoError, ErrorKind}, + io::Error as IoError, mem::replace, pin::Pin, task::{Context, Poll}, @@ -209,7 +209,7 @@ impl SansIoMachine { fn receive_next_messages(&mut self, encrypted_messages: Vec>) { self.encrypted_rx - .extend(encrypted_messages.into_iter().map(|x| Ok(x))); + .extend(encrypted_messages.into_iter().map(Ok)); } fn receive_next(&mut self, encrypted_msg: Vec) { @@ -357,7 +357,7 @@ impl Machine { if let Some(io) = self.io.as_mut() { return Ok(io); } - Err(IoError::new(ErrorKind::Other, Error::NoIoSetError)) + Err(IoError::other(Error::NoIoSetError)) } pub fn set_io(&mut self, io: Box>) { self.io = Some(io); @@ -571,7 +571,7 @@ mod tests { use super::*; use futures::{SinkExt, StreamExt, channel::mpsc, join}; - use tracing::instrument; + // Mock IO that implements Stream + Sink for testing #[derive(Debug)] From c27a563712b9640db0a128fc677022f6a8b74d7d Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Fri, 19 Sep 2025 01:36:52 -0400 Subject: [PATCH 203/206] lints --- src/sstream/mod.rs | 4 ++++ src/sstream/sm2.rs | 13 ++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index 7753ef9..443a74b 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -158,6 +158,7 @@ pub mod hc_specific { use crate::Error; + /// The Hypercore specific parameter string pub const PARAM_STR: &str = "Noise_IK_Ed25519_ChaChaPoly_BLAKE2b"; static NOISE_PARAMS: LazyLock = LazyLock::new(|| { NoiseParams::new( @@ -173,6 +174,8 @@ pub mod hc_specific { snow::params::HashChoice::Blake2b, ) }); + + /// Get Hypercore Noise parameters. pub fn noise_params() -> &'static NoiseParams { &NOISE_PARAMS } @@ -188,6 +191,7 @@ pub mod hc_specific { ) } + /// Generate Hypercore key pair. pub fn generate_keypair() -> Result { Ok(builder().generate_keypair()?) } diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 6ea6c9d..de1a8e7 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -340,16 +340,22 @@ impl Machine { } #[instrument(skip_all, err)] + /// Start the handshake pub fn handshake_start(&mut self, payload: &[u8]) -> Result<(), IoError> { self.inner.handshake_start(payload) } + /// Try to get the next encrypted message to send. pub fn get_next_sendable_message(&mut self) -> Result>, IoError> { self.inner.get_next_sendable_message() } + + /// Manually add a received encrypted message to be decrypted. pub fn receive_next(&mut self, encrypted_msg: Vec) { self.inner.receive_next(encrypted_msg) } + + /// Try to get the next decrypted message. pub fn next_decrypted_message(&mut self) -> Result, IoError> { self.inner.next_decrypted_message() } @@ -359,6 +365,7 @@ impl Machine { } Err(IoError::other(Error::NoIoSetError)) } + /// Set the IO connection for sending and receiving encrypted messages. pub fn set_io(&mut self, io: Box>) { self.io = Some(io); } @@ -377,7 +384,7 @@ impl Machine { self.inner.poll_encrypt_decrypt() } - /// pull in new incomming encrypted messages + /// pull in new incomming encrypted messages. #[instrument(skip_all)] fn poll_incoming_encrypted(&mut self, cx: &mut Context<'_>) -> Poll<()> { while let Poll::Ready(Some(result)) = @@ -422,6 +429,7 @@ impl Machine { } } + /// `true` when handshake is completed. pub fn ready(&self) -> bool { self.inner.ready() } @@ -482,7 +490,7 @@ impl Sink> for Machine { #[instrument(skip_all)] fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // Process any pending work to make space in queues - self.poll_incoming_encrypted(cx); + let _ = self.poll_incoming_encrypted(cx); match self.poll_outgoing_encrypted(cx) { Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), @@ -571,7 +579,6 @@ mod tests { use super::*; use futures::{SinkExt, StreamExt, channel::mpsc, join}; - // Mock IO that implements Stream + Sink for testing #[derive(Debug)] From ade402b3b20a832e95b3e79f5b2201cb9803e509 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 20 Sep 2025 02:44:45 -0400 Subject: [PATCH 204/206] Fix all the warnings --- src/error.rs | 14 ++++++------ src/sstream/mod.rs | 3 +-- src/sstream/sm2.rs | 56 +++++++++++++++++++++++++++------------------- 3 files changed, 41 insertions(+), 32 deletions(-) diff --git a/src/error.rs b/src/error.rs index 0c1c482..788cd12 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,20 +1,20 @@ +/// Error type for this crate #[derive(Debug, thiserror::Error)] pub enum Error { + /// Error from the [`snow`] crate #[error("Error from `snow`: {0}")] Snow(#[from] snow::Error), + /// Error from [`crypto_secretstream`] crate #[error("Error from `crypto_secretstream`: {0}")] SecretStream(crypto_secretstream::aead::Error), + /// [`crate::sstream::statemachine::Manager`] handshake invalid state. TODO remove with + /// statemachine? #[error("Invalid Handshake State: {0}")] InvalidHandshakeState(String), - // TODO added by claude - #[error("Invalid Encryption Statemachine State: {0}")] - InvalidState(String), - // TODO added by claude - #[error("IoError: {0}")] - IoError(String), - // Missing IO + /// Missing IO in [`crate::sstream::sm2::Machine`] #[error("Machine IO is not set.")] NoIoSetError, + /// Error from [`std::io`] #[error("{0}")] StdIoError(#[from] std::io::Error), } diff --git a/src/sstream/mod.rs b/src/sstream/mod.rs index 443a74b..553bb72 100644 --- a/src/sstream/mod.rs +++ b/src/sstream/mod.rs @@ -146,8 +146,8 @@ impl Debug for Ready { .finish() } } - pub mod hc_specific { + //! Stuff for generating Hypercore specific things like Noise parameters, keys, etc use std::sync::LazyLock; use snow::{ @@ -163,7 +163,6 @@ pub mod hc_specific { static NOISE_PARAMS: LazyLock = LazyLock::new(|| { NoiseParams::new( PARAM_STR.to_string(), - //PARAMS.to_string(), BaseChoice::Noise, HandshakeChoice { pattern: HandshakePattern::IK, diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index de1a8e7..016fe93 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -1,3 +1,4 @@ +//! Interface for an encrypted channel use std::{ collections::VecDeque, fmt::Debug, @@ -198,15 +199,18 @@ impl SansIoMachine { Ok(made_progress.then_some(())) } + #[allow(unused, reason = "vectorized version of 'get_next_sendable_message'")] fn get_sendable_messages(&mut self) -> Result>, IoError> { self.poll_all_enc_dec()?; Ok(self.encrypted_tx.drain(..).collect()) } + fn get_next_sendable_message(&mut self) -> Result>, IoError> { self.poll_all_enc_dec()?; Ok(self.encrypted_tx.pop_front()) } + #[allow(unused, reason = "vectorized version of 'receive_next'")] fn receive_next_messages(&mut self, encrypted_messages: Vec>) { self.encrypted_rx .extend(encrypted_messages.into_iter().map(Ok)); @@ -216,6 +220,7 @@ impl SansIoMachine { self.encrypted_rx.push_back(Ok(encrypted_msg)); } + #[allow(unused, reason = "add a new plaintext message to send")] fn queue_msg(&mut self, msg: Vec) { self.plain_tx.push_back(msg); } @@ -243,12 +248,17 @@ impl Debug for SansIoMachine { } #[derive(Debug)] +/// Encryption event pub enum Event { + /// Data passed through the handshake payload HandshakePayload(Vec), + /// Decrypted message Message(Vec), + /// Error occured in encryption ErrStuff(IoError), } +/// Supertrait for duplex channel required by [`Machine`] pub trait MachineIo: Stream, IoError>> + Sink> + Send + Unpin + 'static { @@ -277,6 +287,12 @@ impl Debug for Machine { } impl Machine { + /// Create a new [`Machine`] + fn new(io: Option>>, inner: SansIoMachine) -> Self { + Self { io, inner } + } + + /// Create a new initiator pub fn new_dht_init( io: Option>>, remote_pub_key: &[u8; PUBLIC_KEYLEN], @@ -285,44 +301,38 @@ impl Machine { let ss = SecStream::new_initiator(remote_pub_key, prologue)?; let state = State::InitiatorStart(ss); let inner = SansIoMachine::new(state); - Ok(Self { io, inner }) + Ok(Self::new(io, inner)) + } + + /// Create a new initiator + pub fn new_init( + io: Box>, + state: SecStream>, + ) -> Self { + Self::new(Some(io), SansIoMachine::new_init(state)) } - fn new_dht_resp( + /// Create a new responder from a private key + pub fn resp_from_private( io: Option>>, private: &[u8], ) -> Result { let ss = SecStream::new_responder(private)?; let state = State::RespStart(ss); let inner = SansIoMachine::new(state); - Ok(Self { io, inner }) + Ok(Self::new(io, inner)) } - fn new(io: Option>>, inner: SansIoMachine) -> Self { - Self { io, inner } - } - - fn new_init( - io: Box>, - state: SecStream>, - ) -> Self { - Self { - io: Some(io), - inner: SansIoMachine::new_init(state), - } - } - - fn new_resp( + /// Create a new responder + pub fn new_resp( io: Box>, state: SecStream>, ) -> Self { - Self { - io: Some(io), - inner: SansIoMachine::new_resp(state), - } + Self::new(Some(io), SansIoMachine::new_resp(state)) } - async fn complete_handshake(&mut self) -> Result<(), IoError> { + /// Wait for handshake to complete + pub async fn complete_handshake(&mut self) -> Result<(), IoError> { use futures::SinkExt; use futures::StreamExt; From 540b42764e8f6e55ac94c9bd1722e746f3c0148b Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 20 Sep 2025 03:00:14 -0400 Subject: [PATCH 205/206] lints --- src/sstream/sm2.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/sstream/sm2.rs b/src/sstream/sm2.rs index 016fe93..bec5760 100644 --- a/src/sstream/sm2.rs +++ b/src/sstream/sm2.rs @@ -639,6 +639,7 @@ mod tests { } } + #[allow(clippy::type_complexity)] fn create_mock_io_pair() -> ( MockIo, std::io::Error>>>, mpsc::UnboundedSender, std::io::Error>>, @@ -743,10 +744,9 @@ mod tests { let (mut lm, mut rm) = connected_machines(); let payload = b"Hello, World!".to_vec(); - lm.handshake_start(&payload); - //let (lres, rres) = join!(lm.send(b"foo"), rm.next()); // TODO this hangs + lm.handshake_start(&payload)?; let (lres, rres) = join!(lm.flush(), rm.next()); - assert!(matches!(rres, Some(Event::HandshakePayload(payload)))); + assert!(matches!(rres, Some(Event::HandshakePayload(_)))); lres?; Ok(()) } @@ -763,14 +763,14 @@ mod tests { }; let (empty, rtol, ltor): (Vec, _, _) = (vec![], b"rtol".to_vec(), b"ltor".to_vec()); - assert!(matches!(lr, Event::HandshakePayload(empty))); - assert!(matches!(rr, Event::HandshakePayload(empty))); + assert!(matches!(lr, Event::HandshakePayload(x) if x == empty)); + assert!(matches!(rr, Event::HandshakePayload(x) if x == empty)); let (Some(lr), Some(rr)) = join!(lm.next(), rm.next()) else { panic!() }; - assert!(matches!(lr, Event::Message(rtol))); - assert!(matches!(rr, Event::Message(ltor))); + assert!(matches!(lr, Event::Message(x) if x == rtol)); + assert!(matches!(rr, Event::Message(x) if x == ltor)); Ok(()) } From 87570a3f2a3f8adf73599488acbf8bbafb8ce2d1 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Sat, 20 Sep 2025 13:38:36 -0400 Subject: [PATCH 206/206] add .gitignore with yarn.lock, node_modules --- tests/js/.gitignore | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 tests/js/.gitignore diff --git a/tests/js/.gitignore b/tests/js/.gitignore new file mode 100644 index 0000000..cae21dd --- /dev/null +++ b/tests/js/.gitignore @@ -0,0 +1,2 @@ +yarn.lock +node_modules_for_tests