From ef5ed52e7121dc84ce2abd3c494d7ddf4fdbd5b1 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Thu, 11 Apr 2024 19:05:14 -0700 Subject: [PATCH 01/14] preliminary support for wisp v2 --- Cargo.lock | 39 +++- client/Cargo.toml | 3 +- client/src/lib.rs | 4 +- client/src/udp_stream.rs | 2 +- client/src/utils.rs | 30 +-- client/src/websocket.rs | 10 +- client/src/wrappers.rs | 11 +- rustfmt.toml | 1 + server/src/main.rs | 21 +- simple-wisp-client/src/main.rs | 31 ++- wisp/Cargo.toml | 5 +- wisp/src/extensions.rs | 190 ++++++++++++++++ wisp/src/fastwebsockets.rs | 15 +- wisp/src/lib.rs | 383 +++++++++++++++++++++++---------- wisp/src/packet.rs | 166 ++++++++++++-- wisp/src/sink_unfold.rs | 8 +- wisp/src/stream.rs | 11 +- wisp/src/ws.rs | 42 ++-- 18 files changed, 769 insertions(+), 203 deletions(-) create mode 100644 rustfmt.toml create mode 100644 wisp/src/extensions.rs diff --git a/Cargo.lock b/Cargo.lock index 4876e52..fb57868 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -133,9 +133,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.79" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2 1.0.79", "quote 1.0.36", @@ -525,6 +525,7 @@ name = "epoxy-client" version = "1.5.1" dependencies = [ "async-compression", + "async-trait", "async_io_stream", "base64", "bytes", @@ -542,7 +543,7 @@ dependencies = [ "pin-project-lite", "ring", "rustls-pki-types", - "send_wrapper", + "send_wrapper 0.6.0", "tokio", "tokio-rustls", "tokio-util", @@ -744,6 +745,16 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" +dependencies = [ + "gloo-timers", + "send_wrapper 0.4.0", +] + [[package]] name = "futures-util" version = "0.3.30" @@ -791,6 +802,18 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "gloo-timers" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b995a66bb87bebce9a0f4a95aed01daca4872c050bfcb21653361c03bc35e5c" +dependencies = [ + "futures-channel", + "futures-core", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "h2" version = "0.3.26" @@ -1659,6 +1682,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" +[[package]] +name = "send_wrapper" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f638d531eccd6e23b980caf34876660d38e265409d8e99b397ab71eb3612fad0" + [[package]] name = "send_wrapper" version = "0.6.0" @@ -2531,14 +2560,16 @@ checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" [[package]] name = "wisp-mux" -version = "3.0.0" +version = "4.0.0" dependencies = [ + "async-trait", "async_io_stream", "bytes", "dashmap", "event-listener", "fastwebsockets 0.7.1", "futures", + "futures-timer", "futures-util", "pin-project-lite", "tokio", diff --git a/client/Cargo.toml b/client/Cargo.toml index 21a2af1..ee1d108 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -25,7 +25,7 @@ tokio-util = { version = "0.7.10", features = ["io"] } async-compression = { version = "0.4.5", features = ["tokio", "gzip", "brotli"] } fastwebsockets = { version = "0.6.0", features = ["unstable-split"] } base64 = "0.21.7" -wisp-mux = { path = "../wisp", features = ["tokio_io"] } +wisp-mux = { path = "../wisp", features = ["tokio_io", "wasm"] } async_io_stream = { version = "0.3.3", features = ["tokio_io"] } getrandom = { version = "0.2.12", features = ["js"] } hyper-util-wasm = { version = "0.1.3", features = ["client", "client-legacy", "http1", "http2"] } @@ -35,6 +35,7 @@ console_error_panic_hook = "0.1.7" send_wrapper = "0.6.0" event-listener = "5.2.0" wasmtimer = "0.2.0" +async-trait = "0.1.80" [dependencies.ring] features = ["wasm32_unknown_unknown_js"] diff --git a/client/src/lib.rs b/client/src/lib.rs index 6a6217c..0f8f678 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -105,7 +105,7 @@ pub fn certs() -> Result { #[wasm_bindgen(inspectable)] pub struct EpoxyClient { rustls_config: Arc, - mux: Arc>>, + mux: Arc>, hyper_client: Client, #[wasm_bindgen(getter_with_clone)] pub useragent: String, @@ -164,7 +164,7 @@ impl EpoxyClient { async fn get_tls_io(&self, url_host: &str, url_port: u16) -> Result { let channel = self .mux - .read() + .write() .await .client_new_stream(StreamType::Tcp, url_host.to_string(), url_port) .await diff --git a/client/src/udp_stream.rs b/client/src/udp_stream.rs index c026ca7..877bab4 100644 --- a/client/src/udp_stream.rs +++ b/client/src/udp_stream.rs @@ -33,7 +33,7 @@ impl EpxUdpStream { let io = tcp .mux - .read() + .write() .await .client_new_stream(StreamType::Udp, url_host.to_string(), url_port) .await diff --git a/client/src/utils.rs b/client/src/utils.rs index 1fdcf2e..3b05027 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -6,7 +6,10 @@ use wasm_bindgen_futures::JsFuture; use hyper::rt::Executor; use js_sys::ArrayBuffer; use std::future::Future; -use wisp_mux::WispError; +use wisp_mux::{ + extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, + WispError, +}; #[wasm_bindgen] extern "C" { @@ -192,25 +195,25 @@ pub fn get_url_port(url: &Uri) -> Result { pub async fn make_mux( url: &str, -) -> Result< - ( - ClientMux, - impl Future>, - ), - WispError, -> { +) -> Result<(ClientMux, impl Future> + Send), WispError> { let (wtx, wrx) = WebSocketWrapper::connect(url, vec![]) .await .map_err(|_| WispError::WsImplSocketClosed)?; wtx.wait_for_open().await; - let mux = ClientMux::new(wrx, wtx).await?; + let mux = ClientMux::new( + wrx, + wtx, + Some(vec![UdpProtocolExtension().into()]), + Some(&[&UdpProtocolExtensionBuilder()]), + ) + .await?; Ok(mux) } pub fn spawn_mux_fut( - mux: Arc>>, - fut: impl Future> + 'static, + mux: Arc>, + fut: impl Future> + Send + 'static, url: String, ) { wasm_bindgen_futures::spawn_local(async move { @@ -225,10 +228,7 @@ pub fn spawn_mux_fut( }); } -pub async fn replace_mux( - mux: Arc>>, - url: &str, -) -> Result<(), WispError> { +pub async fn replace_mux(mux: Arc>, url: &str) -> Result<(), WispError> { let (mux_replace, fut) = make_mux(url).await?; let mut mux_write = mux.write().await; mux_write.close().await?; diff --git a/client/src/websocket.rs b/client/src/websocket.rs index fff1f44..414e53e 100644 --- a/client/src/websocket.rs +++ b/client/src/websocket.rs @@ -106,7 +106,7 @@ impl EpxWebSocket { break; } // ping/pong/continue - _ => {}, + _ => {} } } }); @@ -115,7 +115,13 @@ impl EpxWebSocket { .call0(&Object::default()) .replace_err("Failed to call onopen")?; - Ok(Self { tx, onerror, origin, protocols, url: url.to_string() }) + Ok(Self { + tx, + onerror, + origin, + protocols, + url: url.to_string(), + }) } .await; if let Err(ret) = ret { diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index 9b16525..e67779e 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -53,7 +53,7 @@ impl Stream for IncomingBody { } #[derive(Clone)] -pub struct ServiceWrapper(pub Arc>>, pub String); +pub struct ServiceWrapper(pub Arc>, pub String); impl tower_service::Service for ServiceWrapper { type Response = TokioIo; @@ -69,7 +69,7 @@ impl tower_service::Service for ServiceWrapper { let mux_url = self.1.clone(); async move { let stream = mux - .read() + .write() .await .client_new_stream( StreamType::Tcp, @@ -193,11 +193,9 @@ pub struct WebSocketReader { close_event: Arc, } +#[async_trait::async_trait] impl WebSocketRead for WebSocketReader { - async fn wisp_read_frame( - &mut self, - _: &LockedWebSocketWrite, - ) -> Result { + async fn wisp_read_frame(&mut self, _: &LockedWebSocketWrite) -> Result { use WebSocketMessage::*; if self.closed.load(Ordering::Acquire) { return Err(WispError::WsImplSocketClosed); @@ -306,6 +304,7 @@ impl WebSocketWrapper { } } +#[async_trait::async_trait] impl WebSocketWrite for WebSocketWrapper { async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError> { use wisp_mux::ws::OpCode::*; diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..c3c8c37 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +imports_granularity = "Crate" diff --git a/server/src/main.rs b/server/src/main.rs index 0b1f6f3..61561b7 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -4,13 +4,12 @@ use std::io::Error; use bytes::Bytes; use clap::Parser; use fastwebsockets::{ - upgrade::{self, UpgradeFut}, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, - WebSocketError, + upgrade::{self, UpgradeFut}, + CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, }; use futures_util::{SinkExt, StreamExt, TryFutureExt}; use hyper::{ - body::Incoming, server::conn::http1, service::service_fn, Request, Response, - StatusCode, + body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket}; @@ -20,7 +19,10 @@ use tokio_util::codec::{BytesCodec, Framed}; #[cfg(unix)] use tokio_util::either::Either; -use wisp_mux::{CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError}; +use wisp_mux::{ + extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, + CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, +}; type HttpBody = http_body_util::Full; @@ -261,7 +263,14 @@ async fn accept_ws( println!("{:?}: connected", addr); - let (mut mux, fut) = ServerMux::new(rx, tx, u32::MAX); + let (mut mux, fut) = ServerMux::new( + rx, + tx, + u32::MAX, + Some(vec![UdpProtocolExtension().into()]), + Some(&[&UdpProtocolExtensionBuilder()]), + ) + .await?; tokio::spawn(async move { if let Err(e) = fut.await { diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index e626b80..4dd329a 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -11,7 +11,14 @@ use hyper::{ }; use simple_moving_average::{SingleSumSMA, SMA}; use std::{ - error::Error, future::Future, io::{stdout, IsTerminal, Write}, net::SocketAddr, process::exit, sync::Arc, time::{Duration, Instant}, usize + error::Error, + future::Future, + io::{stdout, IsTerminal, Write}, + net::SocketAddr, + process::exit, + sync::Arc, + time::{Duration, Instant}, + usize, }; use tokio::{ net::TcpStream, @@ -21,7 +28,10 @@ use tokio::{ }; use tokio_native_tls::{native_tls, TlsConnector}; use tokio_util::either::Either; -use wisp_mux::{ClientMux, StreamType, WispError}; +use wisp_mux::{ + extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, + ClientMux, StreamType, WispError, +}; #[derive(Debug)] enum WispClientError { @@ -71,6 +81,9 @@ struct Cli { /// Duration to run the test for #[arg(short, long)] duration: Option, + /// Ask for UDP + #[arg(short, long)] + udp: bool, } #[tokio::main(flavor = "multi_thread")] @@ -117,7 +130,6 @@ async fn main() -> Result<(), Box> { fastwebsockets::handshake::generate_key(), ) .header("Sec-WebSocket-Version", "13") - .header("Sec-WebSocket-Protocol", "wisp-v1") .body(Empty::::new())?; let (ws, _) = handshake::client(&SpawnExecutor, req, socket).await?; @@ -125,7 +137,18 @@ async fn main() -> Result<(), Box> { let (rx, tx) = ws.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); - let (mux, fut) = ClientMux::new(rx, tx).await?; + let (mut mux, fut) = if opts.udp { + ClientMux::new( + rx, + tx, + Some(vec![UdpProtocolExtension().into()]), + Some(&[&UdpProtocolExtensionBuilder()]), + ) + .await? + } else { + ClientMux::new(rx, tx, Some(vec![]), Some(&[])).await? + }; + let mut threads = Vec::with_capacity(opts.streams * 2 + 3); threads.push(tokio::spawn(fut)); diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 473640d..8cf2cba 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "wisp-mux" -version = "3.0.0" +version = "4.0.0" license = "LGPL-3.0-only" description = "A library for easily creating Wisp servers and clients." homepage = "https://github.com/MercuryWorkshop/epoxy-tls/tree/multiplexed/wisp" @@ -9,12 +9,14 @@ readme = "README.md" edition = "2021" [dependencies] +async-trait = "0.1.79" async_io_stream = "0.3.3" bytes = "1.5.0" dashmap = { version = "5.5.3", features = ["inline"] } event-listener = "5.0.0" fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true } futures = "0.3.30" +futures-timer = "3.0.3" futures-util = "0.3.30" pin-project-lite = "0.2.13" tokio = { version = "1.35.1", optional = true, default-features = false } @@ -22,6 +24,7 @@ tokio = { version = "1.35.1", optional = true, default-features = false } [features] fastwebsockets = ["dep:fastwebsockets", "dep:tokio"] tokio_io = ["async_io_stream/tokio_io"] +wasm = ["futures-timer/wasm-bindgen"] [package.metadata.docs.rs] all-features = true diff --git a/wisp/src/extensions.rs b/wisp/src/extensions.rs new file mode 100644 index 0000000..9358c4a --- /dev/null +++ b/wisp/src/extensions.rs @@ -0,0 +1,190 @@ +//! Wisp protocol extensions. + +use std::ops::{Deref, DerefMut}; + +use async_trait::async_trait; +use bytes::{BufMut, Bytes, BytesMut}; + +use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, +}; + +/// Type-erased protocol extension that implements Clone. +#[derive(Debug)] +pub struct AnyProtocolExtension(Box); + +impl AnyProtocolExtension { + /// Create a new type-erased protocol extension. + pub fn new(extension: T) -> Self { + Self(Box::new(extension)) + } +} + +impl Deref for AnyProtocolExtension { + type Target = dyn ProtocolExtension; + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +impl DerefMut for AnyProtocolExtension { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.deref_mut() + } +} + +impl Clone for AnyProtocolExtension { + fn clone(&self) -> Self { + Self(self.0.box_clone()) + } +} + +impl From for Bytes { + fn from(value: AnyProtocolExtension) -> Self { + let mut bytes = BytesMut::with_capacity(5); + let payload = value.encode(); + bytes.put_u8(value.get_id()); + bytes.put_u32_le(payload.len() as u32); + bytes.extend(payload); + bytes.freeze() + } +} + +/// A Wisp protocol extension. +/// +/// See [the +/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions). +#[async_trait] +pub trait ProtocolExtension: std::fmt::Debug { + /// Get the protocol extension ID. + fn get_id(&self) -> u8; + /// Get the protocol extension's supported packets. + /// + /// Used to decide whether to call the protocol extension's packet handler. + fn get_supported_packets(&self) -> &'static [u8]; + + /// Encode self into Bytes. + fn encode(&self) -> Bytes; + + /// Handle the handshake part of a Wisp connection. + /// + /// This should be used to send or receive data before any streams are created. + async fn handle_handshake( + &mut self, + read: &mut dyn WebSocketRead, + write: &LockedWebSocketWrite, + ) -> Result<(), WispError>; + + /// Handle receiving a packet. + async fn handle_packet( + &mut self, + packet: Bytes, + read: &mut dyn WebSocketRead, + write: &LockedWebSocketWrite, + ) -> Result<(), WispError>; + + /// Clone the protocol extension. + fn box_clone(&self) -> Box; +} + +/// Trait to build a Wisp protocol extension for the client. +pub trait ProtocolExtensionBuilder { + /// Get the protocol extension ID. + /// + /// Used to decide whether this builder should be used. + fn get_id(&self) -> u8; + + /// Build a protocol extension from the extension's metadata. + fn build(&self, bytes: Bytes, role: Role) -> AnyProtocolExtension; +} + +pub mod udp { + //! UDP protocol extension. + //! + //! # Example + //! ``` + //! let (mux, fut) = ServerMux::new( + //! rx, + //! tx, + //! 128, + //! Some(vec![UdpProtocolExtension().into()]), + //! Some(&[&UdpProtocolExtensionBuilder()]) + //! ); + //! ``` + //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) + use async_trait::async_trait; + use bytes::Bytes; + + use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + WispError, + }; + + use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; + + #[derive(Debug)] + /// UDP protocol extension. + pub struct UdpProtocolExtension(); + + impl UdpProtocolExtension { + /// UDP protocol extension ID. + pub const ID: u8 = 0x01; + } + + #[async_trait] + impl ProtocolExtension for UdpProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } + + fn encode(&self) -> Bytes { + Bytes::new() + } + + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + /// Handle receiving a packet. + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(Self()) + } + } + + impl From for AnyProtocolExtension { + fn from(value: UdpProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } + } + + /// UDP protocol extension builder. + pub struct UdpProtocolExtensionBuilder(); + + impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + 0x01 + } + + fn build(&self, _: Bytes, _: crate::Role) -> AnyProtocolExtension { + AnyProtocolExtension(Box::new(UdpProtocolExtension())) + } + } +} diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index 7a66908..548649f 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -1,9 +1,12 @@ +use async_trait::async_trait; use bytes::Bytes; use fastwebsockets::{ FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, }; use tokio::io::{AsyncRead, AsyncWrite}; +use crate::{ws::LockedWebSocketWrite, WispError}; + impl From for crate::ws::OpCode { fn from(opcode: OpCode) -> Self { use OpCode::*; @@ -58,11 +61,12 @@ impl From for crate::WispError { } } -impl crate::ws::WebSocketRead for FragmentCollectorRead { +#[async_trait] +impl crate::ws::WebSocketRead for FragmentCollectorRead { async fn wisp_read_frame( &mut self, - tx: &crate::ws::LockedWebSocketWrite, - ) -> Result { + tx: &LockedWebSocketWrite, + ) -> Result { Ok(self .read_frame(&mut |frame| async { tx.write_frame(frame.into()).await }) .await? @@ -70,8 +74,9 @@ impl crate::ws::WebSocketRead for FragmentCollectorRead } } -impl crate::ws::WebSocketWrite for WebSocketWrite { - async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), crate::WispError> { +#[async_trait] +impl crate::ws::WebSocketWrite for WebSocketWrite { + async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> { self.write_frame(frame.into()).await.map_err(|e| e.into()) } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 152be13..076e10c 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -4,6 +4,7 @@ //! //! [Wisp]: https://github.com/MercuryWorkshop/wisp-protocol +pub mod extensions; #[cfg(feature = "fastwebsockets")] #[cfg_attr(docsrs, doc(cfg(feature = "fastwebsockets")))] mod fastwebsockets; @@ -12,18 +13,28 @@ mod sink_unfold; mod stream; pub mod ws; -pub use crate::packet::*; -pub use crate::stream::*; +pub use crate::{packet::*, stream::*}; use bytes::Bytes; use dashmap::DashMap; use event_listener::Event; -use futures::SinkExt; -use futures::{channel::mpsc, Future, FutureExt, StreamExt}; -use std::sync::{ - atomic::{AtomicBool, AtomicU32, Ordering}, - Arc, +use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; +use futures::{ + channel::{mpsc, oneshot}, + select, Future, FutureExt, SinkExt, StreamExt, }; +use futures_timer::Delay; +use std::{ + sync::{ + atomic::{AtomicBool, AtomicU32, Ordering}, + Arc, + }, + time::Duration, +}; +use ws::AppendingWebSocketRead; + +/// Wisp version supported by this crate. +pub const WISP_VERSION: WispVersion = WispVersion { major: 2, minor: 0 }; /// The role of the multiplexor. #[derive(Debug, PartialEq, Copy, Clone)] @@ -37,9 +48,9 @@ pub enum Role { /// Errors the Wisp implementation can return. #[derive(Debug)] pub enum WispError { - /// The packet recieved did not have enough data. + /// The packet received did not have enough data. PacketTooSmall, - /// The packet recieved had an invalid type. + /// The packet received had an invalid type. InvalidPacketType, /// The stream had an invalid type. InvalidStreamType, @@ -47,19 +58,19 @@ pub enum WispError { InvalidStreamId, /// The close packet had an invalid reason. InvalidCloseReason, - /// The URI recieved was invalid. + /// The URI received was invalid. InvalidUri, - /// The URI recieved had no host. + /// The URI received had no host. UriHasNoHost, - /// The URI recieved had no port. + /// The URI received had no port. UriHasNoPort, /// The max stream count was reached. MaxStreamCountReached, /// The stream had already been closed. StreamAlreadyClosed, - /// The websocket frame recieved had an invalid type. + /// The websocket frame received had an invalid type. WsFrameInvalidType, - /// The websocket frame recieved was not finished. + /// The websocket frame received was not finished. WsFrameNotFinished, /// Error specific to the websocket implementation. WsImplError(Box), @@ -67,17 +78,33 @@ pub enum WispError { WsImplSocketClosed, /// The websocket implementation did not support the action. WsImplNotSupported, + /// Error specific to the protocol extension implementation. + ExtensionImplError(Box), + /// The protocol extension implementation did not support the action. + ExtensionImplNotSupported, + /// The UDP protocol extension is not supported by the server. + UdpExtensionNotSupported, /// The string was invalid UTF-8. Utf8Error(std::str::Utf8Error), + /// The integer failed to convert. + TryFromIntError(std::num::TryFromIntError), /// Other error. Other(Box), /// Failed to send message to multiplexor task. MuxMessageFailedToSend, + /// Failed to receive message from multiplexor task. + MuxMessageFailedToRecv, } impl From for WispError { - fn from(err: std::str::Utf8Error) -> WispError { - WispError::Utf8Error(err) + fn from(err: std::str::Utf8Error) -> Self { + Self::Utf8Error(err) + } +} + +impl From for WispError { + fn from(value: std::num::TryFromIntError) -> Self { + Self::TryFromIntError(value) } } @@ -103,9 +130,21 @@ impl std::fmt::Display for WispError { Self::WsImplNotSupported => { write!(f, "Websocket implementation error: unsupported feature") } + Self::ExtensionImplError(err) => { + write!(f, "Protocol extension implementation error: {}", err) + } + Self::ExtensionImplNotSupported => { + write!( + f, + "Protocol extension implementation error: unsupported feature" + ) + } + Self::UdpExtensionNotSupported => write!(f, "UDP protocol extension not supported"), Self::Utf8Error(err) => write!(f, "UTF-8 error: {}", err), + Self::TryFromIntError(err) => write!(f, "Integer conversion error: {}", err), Self::Other(err) => write!(f, "Other error: {}", err), Self::MuxMessageFailedToSend => write!(f, "Failed to send multiplexor message"), + Self::MuxMessageFailedToRecv => write!(f, "Failed to receive multiplexor message"), } } } @@ -120,29 +159,27 @@ struct MuxMapValue { is_closed: Arc, } -struct MuxInner -where - W: ws::WebSocketWrite, -{ - tx: ws::LockedWebSocketWrite, - stream_map: Arc>, +struct MuxInner { + tx: ws::LockedWebSocketWrite, + stream_map: DashMap, + buffer_size: u32, } -impl MuxInner { +impl MuxInner { pub async fn server_into_future( self, rx: R, close_rx: mpsc::Receiver, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, - buffer_size: u32, close_tx: mpsc::Sender, ) -> Result<(), WispError> where R: ws::WebSocketRead, { - self.into_future( + self.as_future( close_rx, - self.server_loop(rx, muxstream_sender, buffer_size, close_tx), + close_tx.clone(), + self.server_loop(rx, muxstream_sender, close_tx), ) .await } @@ -151,20 +188,23 @@ impl MuxInner { self, rx: R, close_rx: mpsc::Receiver, + close_tx: mpsc::Sender, ) -> Result<(), WispError> where R: ws::WebSocketRead, { - self.into_future(close_rx, self.client_loop(rx)).await + self.as_future(close_rx, close_tx, self.client_loop(rx)) + .await } - async fn into_future( + async fn as_future( &self, close_rx: mpsc::Receiver, + close_tx: mpsc::Sender, wisp_fut: impl Future>, ) -> Result<(), WispError> { let ret = futures::select! { - _ = self.stream_loop(close_rx).fuse() => Ok(()), + _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()), x = wisp_fut.fuse() => x, }; self.stream_map.iter_mut().for_each(|mut x| { @@ -176,7 +216,12 @@ impl MuxInner { ret } - async fn stream_loop(&self, mut stream_rx: mpsc::Receiver) { + async fn stream_loop( + &self, + mut stream_rx: mpsc::Receiver, + stream_tx: mpsc::Sender, + ) { + let mut next_free_stream_id: u32 = 1; while let Some(msg) = stream_rx.next().await { match msg { WsEvent::SendPacket(packet, channel) => { @@ -186,6 +231,53 @@ impl MuxInner { let _ = channel.send(Err(WispError::InvalidStreamId)); } } + WsEvent::CreateStream(stream_type, host, port, channel) => { + let ret: Result = async { + let (ch_tx, ch_rx) = mpsc::unbounded(); + let stream_id = next_free_stream_id; + let next_stream_id = next_free_stream_id + .checked_add(1) + .ok_or(WispError::MaxStreamCountReached)?; + + let flow_control_event: Arc = Event::new().into(); + let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); + + let is_closed: Arc = AtomicBool::new(false).into(); + + self.tx + .write_frame( + Packet::new_connect(stream_id, stream_type, port, host).into(), + ) + .await?; + + next_free_stream_id = next_stream_id; + + self.stream_map.insert( + stream_id, + MuxMapValue { + stream: ch_tx, + stream_type, + flow_control: flow_control.clone(), + flow_control_event: flow_control_event.clone(), + is_closed: is_closed.clone(), + }, + ); + + Ok(MuxStream::new( + stream_id, + Role::Client, + stream_type, + ch_rx, + stream_tx.clone(), + is_closed, + flow_control, + flow_control_event, + 0, + )) + } + .await; + let _ = channel.send(ret); + } WsEvent::Close(packet, channel) => { if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { stream.stream.disconnect(); @@ -204,17 +296,13 @@ impl MuxInner { &self, mut rx: R, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, - buffer_size: u32, close_tx: mpsc::Sender, ) -> Result<(), WispError> where R: ws::WebSocketRead, { // will send continues once flow_control is at 10% of max - let target_buffer_size = ((buffer_size as u64 * 90) / 100) as u32; - self.tx - .write_frame(Packet::new_continue(0, buffer_size).into()) - .await?; + let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32; loop { let frame = rx.wisp_read_frame(&self.tx).await?; @@ -228,7 +316,7 @@ impl MuxInner { Connect(inner_packet) => { let (ch_tx, ch_rx) = mpsc::unbounded(); let stream_type = inner_packet.stream_type; - let flow_control: Arc = AtomicU32::new(buffer_size).into(); + let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); let flow_control_event: Arc = Event::new().into(); let is_closed: Arc = AtomicBool::new(false).into(); @@ -273,7 +361,7 @@ impl MuxInner { } } } - Continue(_) => break Err(WispError::InvalidPacketType), + Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), Close(_) => { if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); @@ -298,7 +386,7 @@ impl MuxInner { use PacketType::*; match packet.packet_type { - Connect(_) => break Err(WispError::InvalidPacketType), + Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { let _ = stream.stream.unbounded_send(data); @@ -332,7 +420,7 @@ impl MuxInner { /// ``` /// use wisp_mux::ServerMux; /// -/// let (mux, fut) = ServerMux::new(rx, tx, 128); +/// let (mux, fut) = ServerMux::new(rx, tx, 128, Some(vec![]), Some([])); /// tokio::spawn(async move { /// if let Err(e) = fut.await { /// println!("error in multiplexor: {:?}", e); @@ -346,34 +434,89 @@ impl MuxInner { /// } /// ``` pub struct ServerMux { + /// Whether the connection was downgraded to Wisp v1. + /// + /// If this variable is true you must assume no extensions are supported. + pub downgraded: bool, + /// Extensions that are supported by both sides. + pub supported_extensions: Arc<[AnyProtocolExtension]>, close_tx: mpsc::Sender, muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, } impl ServerMux { /// Create a new server-side multiplexor. - pub fn new( - read: R, + /// + /// If either extensions or extension_builders are None a Wisp v1 connection is created + /// otherwise a Wisp v2 connection is created. + pub async fn new( + mut read: R, write: W, buffer_size: u32, - ) -> (Self, impl Future>) + extensions: Option>, + extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, + ) -> Result<(Self, impl Future> + Send), WispError> where - R: ws::WebSocketRead, + R: ws::WebSocketRead + Send, + W: ws::WebSocketWrite + Send + 'static, { let (close_tx, close_rx) = mpsc::channel::(256); let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); - let write = ws::LockedWebSocketWrite::new(write); - ( + let write = ws::LockedWebSocketWrite::new(Box::new(write)); + + write + .write_frame(Packet::new_continue(0, buffer_size).into()) + .await?; + + let mut supported_extensions = Vec::new(); + let mut extra_packet = Vec::with_capacity(1); + let mut downgraded = true; + + if let Some(extensions) = extensions { + if let Some(builders) = extension_builders { + let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); + write + .write_frame(Packet::new_info(extensions).into()) + .await?; + if let Some(frame) = select! { + x = read.wisp_read_frame(&write).fuse() => Some(x?), + // TODO change this to correct timeout once draft 2 is out + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + downgraded = false; + } else { + extra_packet.push(packet.into()); + } + } + } + } + + Ok(( Self { muxstream_recv: rx, close_tx: close_tx.clone(), + downgraded, + supported_extensions: supported_extensions.into(), }, MuxInner { tx: write, - stream_map: DashMap::new().into(), + stream_map: DashMap::new(), + buffer_size, } - .server_into_future(read, close_rx, tx, buffer_size, close_tx), - ) + .server_into_future( + AppendingWebSocketRead(extra_packet, read), + close_rx, + tx, + close_tx, + ), + )) } /// Wait for a stream to be created. @@ -398,7 +541,7 @@ impl ServerMux { /// ``` /// use wisp_mux::{ClientMux, StreamType}; /// -/// let (mux, fut) = ClientMux::new(rx, tx).await?; +/// let (mux, fut) = ClientMux::new(rx, tx, Some(vec![]), []).await?; /// tokio::spawn(async move { /// if let Err(e) = fut.await { /// println!("error in multiplexor: {:?}", e); @@ -406,50 +549,88 @@ impl ServerMux { /// }); /// let stream = mux.client_new_stream(StreamType::Tcp, "google.com", 80); /// ``` -pub struct ClientMux -where - W: ws::WebSocketWrite, -{ - tx: ws::LockedWebSocketWrite, - stream_map: Arc>, - next_free_stream_id: AtomicU32, +pub struct ClientMux { + /// Whether the connection was downgraded to Wisp v1. + /// + /// If this variable is true you must assume no extensions are supported. + pub downgraded: bool, + /// Extensions that are supported by both sides. + pub supported_extensions: Arc<[AnyProtocolExtension]>, close_tx: mpsc::Sender, - buf_size: u32, - target_buf_size: u32, } -impl ClientMux { +impl ClientMux { /// Create a new client side multiplexor. - pub async fn new( + /// + /// If either extensions or extension_builders are None a Wisp v1 connection is created + /// otherwise a Wisp v2 connection is created. + pub async fn new( mut read: R, write: W, - ) -> Result<(Self, impl Future>), WispError> + extensions: Option>, + extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, + ) -> Result<(Self, impl Future> + Send), WispError> where - R: ws::WebSocketRead, + R: ws::WebSocketRead + Send, + W: ws::WebSocketWrite + Send + 'static, { - let write = ws::LockedWebSocketWrite::new(write); + let write = ws::LockedWebSocketWrite::new(Box::new(write)); let first_packet = Packet::try_from(read.wisp_read_frame(&write).await?)?; if first_packet.stream_id != 0 { return Err(WispError::InvalidStreamId); } if let PacketType::Continue(packet) = first_packet.packet_type { + let mut supported_extensions = Vec::new(); + let mut extra_packet = Vec::with_capacity(1); + let mut downgraded = true; + + if let Some(extensions) = extensions { + if let Some(builders) = extension_builders { + let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); + if let Some(frame) = select! { + x = read.wisp_read_frame(&write).fuse() => Some(x?), + // TODO change this to correct timeout once draft 2 is out + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + write + .write_frame(Packet::new_info(extensions).into()) + .await?; + downgraded = false; + } else { + extra_packet.push(packet.into()); + } + } + } + } + + for extension in supported_extensions.iter_mut() { + extension.handle_handshake(&mut read, &write).await?; + } + let (tx, rx) = mpsc::channel::(256); - let map = Arc::new(DashMap::new()); Ok(( Self { - tx: write.clone(), - stream_map: map.clone(), - next_free_stream_id: AtomicU32::new(1), close_tx: tx.clone(), - buf_size: packet.buffer_remaining, - // server-only - target_buf_size: 0, + downgraded, + supported_extensions: supported_extensions.into(), }, MuxInner { - tx: write.clone(), - stream_map: map.clone(), + tx: write, + stream_map: DashMap::new(), + buffer_size: packet.buffer_remaining, } - .client_into_future(read, rx), + .client_into_future( + AppendingWebSocketRead(extra_packet, read), + rx, + tx, + ), )) } else { Err(WispError::InvalidPacketType) @@ -458,51 +639,25 @@ impl ClientMux { /// Create a new stream, multiplexed through Wisp. pub async fn client_new_stream( - &self, + &mut self, stream_type: StreamType, host: String, port: u16, ) -> Result { - let (ch_tx, ch_rx) = mpsc::unbounded(); - let stream_id = self.next_free_stream_id.load(Ordering::Acquire); - let next_stream_id = stream_id - .checked_add(1) - .ok_or(WispError::MaxStreamCountReached)?; - - let flow_control_event: Arc = Event::new().into(); - let flow_control: Arc = AtomicU32::new(self.buf_size).into(); - - let is_closed: Arc = AtomicBool::new(false).into(); - - self.tx - .write_frame(Packet::new_connect(stream_id, stream_type, port, host).into()) - .await?; - - self.next_free_stream_id - .store(next_stream_id, Ordering::Release); - - self.stream_map.insert( - stream_id, - MuxMapValue { - stream: ch_tx, - stream_type, - flow_control: flow_control.clone(), - flow_control_event: flow_control_event.clone(), - is_closed: is_closed.clone(), - }, - ); - - Ok(MuxStream::new( - stream_id, - Role::Client, - stream_type, - ch_rx, - self.close_tx.clone(), - is_closed, - flow_control, - flow_control_event, - self.target_buf_size, - )) + if stream_type == StreamType::Udp + && !self + .supported_extensions + .iter() + .any(|x| x.get_id() == UdpProtocolExtension::ID) + { + return Err(WispError::UdpExtensionNotSupported); + } + let (tx, rx) = oneshot::channel(); + self.close_tx + .send(WsEvent::CreateStream(stream_type, host, port, tx)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? } /// Close all streams. diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index d3fb8c7..c2b2459 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -1,4 +1,8 @@ -use crate::{ws, WispError}; +use crate::{ + extensions::{AnyProtocolExtension, ProtocolExtensionBuilder}, + ws::{self, Frame, OpCode}, + Role, WispError, WISP_VERSION, +}; use bytes::{Buf, BufMut, Bytes, BytesMut}; /// Wisp stream type. @@ -34,6 +38,8 @@ pub enum CloseReason { Voluntary = 0x02, /// Unexpected stream closure due to a network error. Unexpected = 0x03, + /// Incompatible extensions. Only used during the handshake. + IncompatibleExtensions = 0x04, /// Stream creation failed due to invalid information. ServerStreamInvalidInfo = 0x41, /// Stream creation failed due to an unreachable destination host. @@ -55,19 +61,20 @@ pub enum CloseReason { impl TryFrom for CloseReason { type Error = WispError; fn try_from(stream_type: u8) -> Result { - use CloseReason::*; + use CloseReason as R; match stream_type { - 0x01 => Ok(Unknown), - 0x02 => Ok(Voluntary), - 0x03 => Ok(Unexpected), - 0x41 => Ok(ServerStreamInvalidInfo), - 0x42 => Ok(ServerStreamUnreachable), - 0x43 => Ok(ServerStreamConnectionTimedOut), - 0x44 => Ok(ServerStreamConnectionRefused), - 0x47 => Ok(ServerStreamTimedOut), - 0x48 => Ok(ServerStreamBlockedAddress), - 0x49 => Ok(ServerStreamThrottled), - 0x81 => Ok(ClientUnexpected), + 0x01 => Ok(R::Unknown), + 0x02 => Ok(R::Voluntary), + 0x03 => Ok(R::Unexpected), + 0x04 => Ok(R::IncompatibleExtensions), + 0x41 => Ok(R::ServerStreamInvalidInfo), + 0x42 => Ok(R::ServerStreamUnreachable), + 0x43 => Ok(R::ServerStreamConnectionTimedOut), + 0x44 => Ok(R::ServerStreamConnectionRefused), + 0x47 => Ok(R::ServerStreamTimedOut), + 0x48 => Ok(R::ServerStreamBlockedAddress), + 0x49 => Ok(R::ServerStreamThrottled), + 0x81 => Ok(R::ClientUnexpected), _ => Err(Self::Error::InvalidStreamType), } } @@ -198,6 +205,38 @@ impl From for Bytes { } } +/// Wisp version sent in the handshake. +#[derive(Debug, Clone)] +pub struct WispVersion { + /// Major Wisp version according to semver. + pub major: u8, + /// Minor Wisp version according to semver. + pub minor: u8, +} + +/// Packet used in the initial handshake. +/// +/// See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x05---info) +#[derive(Debug, Clone)] +pub struct InfoPacket { + /// Wisp version sent in the packet. + pub version: WispVersion, + /// List of protocol extensions sent in the packet. + pub extensions: Vec, +} + +impl From for Bytes { + fn from(value: InfoPacket) -> Self { + let mut bytes = BytesMut::with_capacity(2); + bytes.put_u8(value.version.major); + bytes.put_u8(value.version.minor); + for extension in value.extensions { + bytes.extend(Bytes::from(extension)); + } + bytes.freeze() + } +} + #[derive(Debug, Clone)] /// Type of packet recieved. pub enum PacketType { @@ -209,6 +248,8 @@ pub enum PacketType { Continue(ContinuePacket), /// Close packet. Close(ClosePacket), + /// Info packet. + Info(InfoPacket), } impl PacketType { @@ -220,6 +261,7 @@ impl PacketType { Data(_) => 0x02, Continue(_) => 0x03, Close(_) => 0x04, + Info(_) => 0x05, } } } @@ -232,6 +274,7 @@ impl From for Bytes { Data(x) => x, Continue(x) => x.into(), Close(x) => x.into(), + Info(x) => x.into(), } } } @@ -296,15 +339,18 @@ impl Packet { packet_type: PacketType::Close(ClosePacket::new(reason)), } } -} -impl TryFrom for Packet { - type Error = WispError; - fn try_from(mut bytes: Bytes) -> Result { - if bytes.remaining() < 5 { - return Err(Self::Error::PacketTooSmall); + pub(crate) fn new_info(extensions: Vec) -> Self { + Self { + stream_id: 0, + packet_type: PacketType::Info(InfoPacket { + version: WISP_VERSION, + extensions, + }), } - let packet_type = bytes.get_u8(); + } + + fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result { use PacketType::*; Ok(Self { stream_id: bytes.get_u32_le(), @@ -313,10 +359,88 @@ impl TryFrom for Packet { 0x02 => Data(bytes), 0x03 => Continue(ContinuePacket::try_from(bytes)?), 0x04 => Close(ClosePacket::try_from(bytes)?), - _ => return Err(Self::Error::InvalidPacketType), + // 0x05 is handled seperately + _ => return Err(WispError::InvalidPacketType), }, }) } + + pub(crate) fn maybe_parse_info( + frame: Frame, + role: Role, + extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)], + ) -> Result { + if !frame.finished { + return Err(WispError::WsFrameNotFinished); + } + if frame.opcode != OpCode::Binary { + return Err(WispError::WsFrameInvalidType); + } + let mut bytes = frame.payload; + if bytes.remaining() < 1 { + return Err(WispError::PacketTooSmall); + } + let packet_type = bytes.get_u8(); + if packet_type == 0x05 { + Self::parse_info(bytes, role, extension_builders) + } else { + Self::parse_packet(packet_type, bytes) + } + } + + fn parse_info( + mut bytes: Bytes, + role: Role, + extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)], + ) -> Result { + // packet type is already read by code that calls this + if bytes.remaining() < 4 + 2 { + return Err(WispError::PacketTooSmall); + } + if bytes.get_u32_le() != 0 { + return Err(WispError::InvalidStreamId); + } + + let version = WispVersion { + major: bytes.get_u8(), + minor: bytes.get_u8(), + }; + + let mut extensions = Vec::new(); + + while bytes.remaining() > 4 { + // We have some extensions + let id = bytes.get_u8(); + let length = usize::try_from(bytes.get_u32_le())?; + if bytes.remaining() < length { + return Err(WispError::PacketTooSmall); + } + if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) { + extensions.push(builder.build(bytes.copy_to_bytes(length), role)) + } else { + bytes.advance(length) + } + } + + Ok(Self { + stream_id: 0, + packet_type: PacketType::Info(InfoPacket { + version, + extensions, + }), + }) + } +} + +impl TryFrom for Packet { + type Error = WispError; + fn try_from(mut bytes: Bytes) -> Result { + if bytes.remaining() < 5 { + return Err(Self::Error::PacketTooSmall); + } + let packet_type = bytes.get_u8(); + Self::parse_packet(packet_type, bytes) + } } impl From for Bytes { diff --git a/wisp/src/sink_unfold.rs b/wisp/src/sink_unfold.rs index c82254a..dfb170e 100644 --- a/wisp/src/sink_unfold.rs +++ b/wisp/src/sink_unfold.rs @@ -1,8 +1,10 @@ //! futures sink unfold with a close function use core::{future::Future, pin::Pin}; -use futures::ready; -use futures::task::{Context, Poll}; -use futures::Sink; +use futures::{ + ready, + task::{Context, Poll}, + Sink, +}; use pin_project_lite::pin_project; pin_project! { diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 109b9ab..f579140 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -21,6 +21,12 @@ use std::{ pub(crate) enum WsEvent { SendPacket(Packet, oneshot::Sender>), Close(Packet, oneshot::Sender>), + CreateStream( + StreamType, + String, + u16, + oneshot::Sender>, + ), EndFut, } @@ -317,7 +323,10 @@ impl MuxStreamIo { impl Stream for MuxStreamIo { type Item = Result, std::io::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.project().rx.poll_next(cx).map(|x| x.map(|x| Ok(x.to_vec()))) + self.project() + .rx + .poll_next(cx) + .map(|x| x.map(|x| Ok(x.to_vec()))) } } diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index af57572..7348bb8 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -4,9 +4,10 @@ //! for other WebSocket implementations. //! //! [`fastwebsockets`]: https://github.com/MercuryWorkshop/epoxy-tls/blob/multiplexed/wisp/src/fastwebsockets.rs +use crate::WispError; +use async_trait::async_trait; use bytes::Bytes; use futures::lock::Mutex; -use std::sync::Arc; /// Opcode of the WebSocket frame. #[derive(Debug, PartialEq, Clone, Copy)] @@ -64,30 +65,26 @@ impl Frame { } /// Generic WebSocket read trait. +#[async_trait] pub trait WebSocketRead { /// Read a frame from the socket. - fn wisp_read_frame( - &mut self, - tx: &crate::ws::LockedWebSocketWrite, - ) -> impl std::future::Future>; + async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result; } /// Generic WebSocket write trait. +#[async_trait] pub trait WebSocketWrite { /// Write a frame to the socket. - fn wisp_write_frame( - &mut self, - frame: Frame, - ) -> impl std::future::Future>; + async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>; } -/// Locked WebSocket that can be shared between threads. -pub struct LockedWebSocketWrite(Arc>); +/// Locked WebSocket. +pub struct LockedWebSocketWrite(Mutex>); -impl LockedWebSocketWrite { +impl LockedWebSocketWrite { /// Create a new locked websocket. - pub fn new(ws: S) -> Self { - Self(Arc::new(Mutex::new(ws))) + pub fn new(ws: Box) -> Self { + Self(Mutex::new(ws)) } /// Write a frame to the websocket. @@ -96,8 +93,19 @@ impl LockedWebSocketWrite { } } -impl Clone for LockedWebSocketWrite { - fn clone(&self) -> Self { - Self(self.0.clone()) +pub(crate) struct AppendingWebSocketRead(pub Vec, pub R) +where + R: WebSocketRead + Send; + +#[async_trait] +impl WebSocketRead for AppendingWebSocketRead +where + R: WebSocketRead + Send, +{ + async fn wisp_read_frame(&mut self, tx: &LockedWebSocketWrite) -> Result { + if let Some(x) = self.0.pop() { + return Ok(x); + } + return self.1.wisp_read_frame(tx).await; } } From b0d1038a3c5fe67fc109cba995377106000a3d05 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Fri, 12 Apr 2024 17:18:56 -0700 Subject: [PATCH 02/14] call wisp v2 extension packet handlers --- Cargo.lock | 693 +++++++++++++++++++++++++++++++++++++- Cargo.toml | 2 +- certs-grabber/Cargo.toml | 13 + certs-grabber/src/main.rs | 64 ++++ client/Cargo.toml | 2 +- client/build.sh | 9 + client/demo.js | 9 +- client/src/lib.rs | 42 +-- client/src/utils.rs | 23 +- client/tests/fetch.rs | 32 +- wisp/src/lib.rs | 196 ++++++----- wisp/src/packet.rs | 30 +- 12 files changed, 974 insertions(+), 141 deletions(-) create mode 100644 certs-grabber/Cargo.toml create mode 100644 certs-grabber/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index fb57868..8fa37cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,6 +41,21 @@ dependencies = [ "alloc-no-stdlib", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.13" @@ -95,6 +110,84 @@ version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" +[[package]] +name = "asn1-rs" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f6fd5ddaf0351dff5b8da21b2fb4ff8e08ddd02857f0bf69c47639106c0fff0" +dependencies = [ + "asn1-rs-derive 0.4.0", + "asn1-rs-impl 0.1.0", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror", + "time", +] + +[[package]] +name = "asn1-rs" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ad1373757efa0f70ec53939aabc7152e1591cb485208052993070ac8d2429d" +dependencies = [ + "asn1-rs-derive 0.5.0", + "asn1-rs-impl 0.2.0", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror", + "time", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "726535892e8eae7e70657b4c8ea93d26b8553afb1ce617caee529ef96d7dee6c" +dependencies = [ + "proc-macro2 1.0.79", + "quote 1.0.36", + "syn 1.0.109", + "synstructure 0.12.6", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7378575ff571966e99a744addeff0bff98b8ada0dedf1956d59e634db95eaac1" +dependencies = [ + "proc-macro2 1.0.79", + "quote 1.0.36", + "syn 2.0.58", + "synstructure 0.13.1", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2777730b2039ac0f95f093556e61b6d26cebed5393ca6f152717777cec3a42ed" +dependencies = [ + "proc-macro2 1.0.79", + "quote 1.0.36", + "syn 1.0.109", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" +dependencies = [ + "proc-macro2 1.0.79", + "quote 1.0.36", + "syn 2.0.58", +] + [[package]] name = "async-compression" version = "0.4.8" @@ -231,6 +324,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" + [[package]] name = "bitflags" version = "1.3.2" @@ -297,12 +396,37 @@ version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2678b2e3449475e95b0aa6f9b506a28e61b3dc8996592b983695e8ebb58a8b41" +[[package]] +name = "certs-grabber" +version = "0.1.0" +dependencies = [ + "hex", + "ring", + "rustls-pki-types", + "rustls-webpki 0.102.2", + "tokio", + "webpki-ccadb", + "x509-parser 0.16.0", +] + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "num-traits", + "windows-targets 0.52.4", +] + [[package]] name = "clap" version = "4.5.4" @@ -480,6 +604,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -493,6 +638,12 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "data-encoding" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5" + [[package]] name = "default-env" version = "0.1.1" @@ -504,6 +655,43 @@ dependencies = [ "syn 0.15.44", ] +[[package]] +name = "der-parser" +version = "8.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbd676fbbab537128ef0278adb5576cf363cff6aa22a7b24effe97347cfab61e" +dependencies = [ + "asn1-rs 0.5.2", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs 0.6.1", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + [[package]] name = "digest" version = "0.10.7" @@ -514,12 +702,32 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "displaydoc" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d" +dependencies = [ + "proc-macro2 1.0.79", + "quote 1.0.36", + "syn 2.0.58", +] + [[package]] name = "either" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" +[[package]] +name = "encoding_rs" +version = "0.8.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +dependencies = [ + "cfg-if", +] + [[package]] name = "epoxy-client" version = "1.5.1" @@ -527,7 +735,7 @@ dependencies = [ "async-compression", "async-trait", "async_io_stream", - "base64", + "base64 0.21.7", "bytes", "console_error_panic_hook", "default-env", @@ -545,7 +753,7 @@ dependencies = [ "rustls-pki-types", "send_wrapper 0.6.0", "tokio", - "tokio-rustls", + "tokio-rustls 0.25.0", "tokio-util", "tower-service", "wasm-bindgen", @@ -629,7 +837,7 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ac14c1f19ff7eab47c9ee7263a088296bec2abd1f9345964c863b5134da40a1" dependencies = [ - "base64", + "base64 0.21.7", "bytes", "http-body-util", "hyper 1.2.0", @@ -674,6 +882,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "futures" version = "0.3.30" @@ -870,7 +1087,7 @@ version = "7.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" dependencies = [ - "base64", + "base64 0.21.7", "byteorder", "flate2", "nom", @@ -889,6 +1106,12 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "http" version = "0.2.12" @@ -1008,6 +1231,20 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper 0.14.28", + "rustls 0.21.10", + "tokio", + "tokio-rustls 0.24.1", +] + [[package]] name = "hyper-timeout" version = "0.4.1" @@ -1020,6 +1257,19 @@ dependencies = [ "tokio-io-timeout", ] +[[package]] +name = "hyper-tls" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +dependencies = [ + "bytes", + "hyper 0.14.28", + "native-tls", + "tokio", + "tokio-native-tls", +] + [[package]] name = "hyper-util" version = "0.1.3" @@ -1056,6 +1306,39 @@ dependencies = [ "wasmtimer", ] +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -1076,6 +1359,12 @@ dependencies = [ "hashbrown 0.14.3", ] +[[package]] +name = "ipnet" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" + [[package]] name = "is-terminal" version = "0.4.12" @@ -1226,6 +1515,32 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.18" @@ -1254,6 +1569,24 @@ dependencies = [ "memchr", ] +[[package]] +name = "oid-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9bedf36ffb6ba96c2eb7144ef6270557b52e54b20c0a8e1eb2ff99a6c6959bff" +dependencies = [ + "asn1-rs 0.5.2", +] + +[[package]] +name = "oid-registry" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c958dd45046245b9c3c2547369bb634eb461670b2e7e0de552905801a648d1d" +dependencies = [ + "asn1-rs 0.6.1", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -1377,6 +1710,12 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1389,7 +1728,7 @@ version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759" dependencies = [ - "unicode-xid", + "unicode-xid 0.1.0", ] [[package]] @@ -1534,6 +1873,49 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +[[package]] +name = "reqwest" +version = "0.11.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" +dependencies = [ + "base64 0.21.7", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", + "hyper-rustls", + "hyper-tls", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls 0.21.10", + "rustls-pemfile 1.0.4", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "system-configuration", + "tokio", + "tokio-native-tls", + "tokio-rustls 0.24.1", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "winreg", +] + [[package]] name = "ring" version = "0.17.8" @@ -1564,6 +1946,15 @@ dependencies = [ "semver", ] +[[package]] +name = "rusticata-macros" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" +dependencies = [ + "nom", +] + [[package]] name = "rustix" version = "0.38.32" @@ -1577,6 +1968,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.21.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" +dependencies = [ + "log", + "ring", + "rustls-webpki 0.101.7", + "sct", +] + [[package]] name = "rustls" version = "0.22.3" @@ -1586,11 +1989,30 @@ dependencies = [ "log", "ring", "rustls-pki-types", - "rustls-webpki", + "rustls-webpki 0.102.2", "subtle", "zeroize", ] +[[package]] +name = "rustls-pemfile" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" +dependencies = [ + "base64 0.21.7", +] + +[[package]] +name = "rustls-pemfile" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +dependencies = [ + "base64 0.22.0", + "rustls-pki-types", +] + [[package]] name = "rustls-pki-types" version = "1.4.1" @@ -1600,6 +2022,16 @@ dependencies = [ "web-time", ] +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustls-webpki" version = "0.102.2" @@ -1653,6 +2085,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "security-framework" version = "2.10.0" @@ -1725,6 +2167,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sha1" version = "0.10.6" @@ -1840,7 +2294,18 @@ checksum = "9ca4b3b69a77cbe1ffc9e198781b7acb0c7365a883670e8f1c1bc66fba79a5c5" dependencies = [ "proc-macro2 0.4.30", "quote 0.6.13", - "unicode-xid", + "unicode-xid 0.1.0", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2 1.0.79", + "quote 1.0.36", + "unicode-ident", ] [[package]] @@ -1860,6 +2325,50 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "synstructure" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f" +dependencies = [ + "proc-macro2 1.0.79", + "quote 1.0.36", + "syn 1.0.109", + "unicode-xid 0.2.4", +] + +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2 1.0.79", + "quote 1.0.36", + "syn 2.0.58", +] + +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" version = "3.10.1" @@ -1912,6 +2421,52 @@ dependencies = [ "once_cell", ] +[[package]] +name = "time" +version = "0.3.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.37.0" @@ -1963,13 +2518,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls 0.21.10", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" dependencies = [ - "rustls", + "rustls 0.22.3", "rustls-pki-types", "tokio", ] @@ -2008,7 +2573,7 @@ dependencies = [ "async-stream", "async-trait", "axum", - "base64", + "base64 0.21.7", "bytes", "h2 0.3.26", "http 0.2.12", @@ -2118,24 +2683,56 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-xid" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc" +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + [[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "url" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "utf-8" version = "0.7.6" @@ -2329,6 +2926,25 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-ccadb" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90e7a3db4ffc4795bc5d07067eb1cb40cdb8c4ebd9f2e14a5c67ad7f809408ef" +dependencies = [ + "chrono", + "csv", + "hex", + "num-bigint", + "reqwest", + "rustls-pemfile 2.1.2", + "rustls-pki-types", + "rustls-webpki 0.102.2", + "serde", + "x509-parser 0.15.1", + "yasna", +] + [[package]] name = "webpki-roots" version = "0.26.1" @@ -2369,6 +2985,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.4", +] + [[package]] name = "windows-sys" version = "0.42.0" @@ -2558,6 +3183,16 @@ version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" +[[package]] +name = "winreg" +version = "0.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + [[package]] name = "wisp-mux" version = "4.0.0" @@ -2575,6 +3210,46 @@ dependencies = [ "tokio", ] +[[package]] +name = "x509-parser" +version = "0.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7069fba5b66b9193bd2c5d3d4ff12b839118f6bcbef5328efafafb5395cf63da" +dependencies = [ + "asn1-rs 0.5.2", + "data-encoding", + "der-parser 8.2.0", + "lazy_static", + "nom", + "oid-registry 0.6.1", + "rusticata-macros", + "thiserror", + "time", +] + +[[package]] +name = "x509-parser" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" +dependencies = [ + "asn1-rs 0.6.1", + "data-encoding", + "der-parser 9.0.0", + "lazy_static", + "nom", + "oid-registry 0.7.0", + "rusticata-macros", + "thiserror", + "time", +] + +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" + [[package]] name = "zeroize" version = "1.7.0" diff --git a/Cargo.toml b/Cargo.toml index 9c822e0..6ad022d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["server", "client", "wisp", "simple-wisp-client"] +members = ["server", "client", "wisp", "simple-wisp-client", "certs-grabber"] default-members = ["server"] [profile.release] diff --git a/certs-grabber/Cargo.toml b/certs-grabber/Cargo.toml new file mode 100644 index 0000000..f68fda9 --- /dev/null +++ b/certs-grabber/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "certs-grabber" +version = "0.1.0" +edition = "2021" + +[dependencies] +hex = "0.4.3" +ring = "0.17.8" +rustls-pki-types = "1.4.1" +rustls-webpki = "0.102.2" +tokio = { version = "1.37.0", features = ["full"] } +webpki-ccadb = "0.1.0" +x509-parser = "0.16.0" diff --git a/certs-grabber/src/main.rs b/certs-grabber/src/main.rs new file mode 100644 index 0000000..11574b8 --- /dev/null +++ b/certs-grabber/src/main.rs @@ -0,0 +1,64 @@ +use std::fmt::Write; + +use ring::digest::{digest, SHA256}; +use rustls_pki_types::{CertificateDer, TrustAnchor}; +use webpki::anchor_from_trusted_cert; +use webpki_ccadb::fetch_ccadb_roots; + +#[tokio::main] +async fn main() { + let tls_roots_map = fetch_ccadb_roots().await; + let mut code = String::with_capacity(256 * 1_024); + code.push_str("const ROOTS = ["); + for (_, root) in tls_roots_map { + // Verify the DER FP matches the metadata FP. + let der = root.der(); + let calculated_fp = digest(&SHA256, &der); + let metadata_fp = hex::decode(&root.sha256_fingerprint).expect("malformed fingerprint"); + assert_eq!(calculated_fp.as_ref(), metadata_fp.as_slice()); + + let ta_der = CertificateDer::from(der.as_ref()); + let TrustAnchor { + subject, + subject_public_key_info, + name_constraints, + } = anchor_from_trusted_cert(&ta_der).expect("malformed trust anchor der"); + + /* + let (_, parsed_cert) = + x509_parser::parse_x509_certificate(&der).expect("malformed x509 der"); + let issuer = name_to_string(parsed_cert.issuer()); + let subject_str = name_to_string(parsed_cert.subject()); + let label = root.common_name_or_certificate_name.clone(); + let serial = root.serial().to_string(); + let sha256_fp = root.sha256_fp(); + */ + + code.write_fmt(format_args!( + "{{subject:new Uint8Array([{}]),subject_public_key_info:new Uint8Array([{}]),name_constraints:{}}},", + subject + .as_ref() + .iter() + .map(|x| x.to_string()) + .collect::>().join(","), + subject_public_key_info + .as_ref() + .iter() + .map(|x| x.to_string()) + .collect::>().join(","), + if let Some(constraints) = name_constraints { + format!("new Uint8Array([{}])",constraints + .as_ref() + .iter() + .map(|x| x.to_string()) + .collect::>().join(",")) + } else { + "null".into() + } + )) + .unwrap(); + } + code.pop(); + code.push_str("];"); + println!("{}",code); +} diff --git a/client/Cargo.toml b/client/Cargo.toml index ee1d108..8489e2f 100644 --- a/client/Cargo.toml +++ b/client/Cargo.toml @@ -17,7 +17,6 @@ wasm-bindgen = { version = "0.2.91", features = ["enable-interning"] } wasm-bindgen-futures = "0.4.39" futures-util = "0.3.30" js-sys = "0.3.66" -webpki-roots = "0.26.0" tokio-rustls = "0.25.0" web-sys = { version = "0.3.66", features = ["Request", "RequestInit", "Headers", "Response", "ResponseInit", "WebSocket", "BinaryType", "MessageEvent"] } wasm-streams = "0.4.0" @@ -47,3 +46,4 @@ features = ["web"] default-env = "0.1.1" wasm-bindgen-test = "0.3.42" web-sys = { version = "0.3.69", features = ["FormData", "UrlSearchParams"] } +webpki-roots = "0.26.0" diff --git a/client/build.sh b/client/build.sh index 2d77294..9219e00 100755 --- a/client/build.sh +++ b/client/build.sh @@ -41,5 +41,14 @@ echo "}\ndeclare function epoxy(maybe_memory?: WebAssembly.Memory): Promise pkg/certs.js +cat pkg/certs.js > pkg/certs-module.js +echo "export default ROOTS;" >> pkg/certs-module.js +echo "[epx] fetching certs finished" + rm -r out/ echo "[epx] done!" diff --git a/client/demo.js b/client/demo.js index 67549d0..a5b35a2 100644 --- a/client/demo.js +++ b/client/demo.js @@ -1,4 +1,5 @@ import epoxy from "./pkg/epoxy-module-bundled.js"; +import CERTS from "./pkg/certs-module.js"; onmessage = async (msg) => { console.debug("recieved demo:", msg); @@ -29,13 +30,13 @@ onmessage = async (msg) => { postMessage(JSON.stringify(str, null, 4)); } - const { EpoxyClient, certs } = await epoxy(); + const { EpoxyClient } = await epoxy(); - console.log("certs:", certs()); + console.log("certs:", CERTS); const tconn0 = performance.now(); - // args: websocket url, user agent, redirect limit - let epoxy_client = await new EpoxyClient("ws://localhost:4000", navigator.userAgent, 10); + // args: websocket url, user agent, redirect limit, certs + let epoxy_client = await new EpoxyClient("ws://localhost:4000", navigator.userAgent, 10, CERTS); const tconn1 = performance.now(); log(`conn establish took ${tconn1 - tconn0} ms or ${(tconn1 - tconn0) / 1000} s`); diff --git a/client/src/lib.rs b/client/src/lib.rs index 0f8f678..5f5946d 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -10,6 +10,7 @@ mod wrappers; use tls_stream::EpxTlsStream; use tokioio::TokioIo; use udp_stream::EpxUdpStream; +use utils::object_to_trustanchor; pub use utils::{Boolinator, ReplaceErr, UriExt}; use websocket::EpxWebSocket; use wrappers::{IncomingBody, ServiceWrapper, TlsWispService, WebSocketWrapper}; @@ -70,38 +71,6 @@ fn init() { intern("Content-Type"); } -fn cert_to_jval(cert: &TrustAnchor) -> Result { - let val = Object::new(); - Reflect::set( - &val, - &jval!("subject"), - &Uint8Array::from(cert.subject.as_ref()), - )?; - Reflect::set( - &val, - &jval!("subject_public_key_info"), - &Uint8Array::from(cert.subject_public_key_info.as_ref()), - )?; - Reflect::set( - &val, - &jval!("name_constraints"), - &jval!(cert - .name_constraints - .as_ref() - .map(|x| Uint8Array::from(x.as_ref()))), - )?; - Ok(val.into()) -} - -#[wasm_bindgen] -pub fn certs() -> Result { - Ok(webpki_roots::TLS_SERVER_ROOTS - .iter() - .map(cert_to_jval) - .collect::>()? - .into()) -} - #[wasm_bindgen(inspectable)] pub struct EpoxyClient { rustls_config: Arc, @@ -120,6 +89,7 @@ impl EpoxyClient { ws_url: String, useragent: String, redirect_limit: usize, + certs: Array, ) -> Result { let ws_uri = ws_url .parse::() @@ -137,7 +107,13 @@ impl EpoxyClient { utils::spawn_mux_fut(mux.clone(), fut, ws_url.clone()); let mut certstore = RootCertStore::empty(); - certstore.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let certs: Result, JsValue> = + certs.iter().map(object_to_trustanchor).collect(); + certstore.extend( + certs + .replace_err("Failed to get certificates from cert store")? + .into_iter(), + ); let rustls_config = Arc::new( rustls::ClientConfig::builder() diff --git a/client/src/utils.rs b/client/src/utils.rs index 3b05027..b956d02 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -1,5 +1,6 @@ use crate::*; +use rustls_pki_types::Der; use wasm_bindgen::prelude::*; use wasm_bindgen_futures::JsFuture; @@ -195,7 +196,13 @@ pub fn get_url_port(url: &Uri) -> Result { pub async fn make_mux( url: &str, -) -> Result<(ClientMux, impl Future> + Send), WispError> { +) -> Result< + ( + ClientMux, + impl Future> + Send, + ), + WispError, +> { let (wtx, wrx) = WebSocketWrapper::connect(url, vec![]) .await .map_err(|_| WispError::WsImplSocketClosed)?; @@ -264,3 +271,17 @@ pub async fn jval_to_u8_array_req(val: JsValue) -> Result<(Uint8Array, web_sys:: req, )) } + +pub fn object_to_trustanchor(obj: JsValue) -> Result, JsValue> { + let subject: Uint8Array = Reflect::get(&obj, &jval!("subject"))?.dyn_into()?; + let pub_key_info: Uint8Array = + Reflect::get(&obj, &jval!("subject_public_key_info"))?.dyn_into()?; + let name_constraints: Option = Reflect::get(&obj, &jval!("name_constraints")) + .and_then(|x| x.dyn_into()) + .ok(); + Ok(TrustAnchor { + subject: Der::from(subject.to_vec()), + subject_public_key_info: Der::from(pub_key_info.to_vec()), + name_constraints: name_constraints.map(|x| Der::from(x.to_vec())), + }) +} diff --git a/client/tests/fetch.rs b/client/tests/fetch.rs index 02eabb2..c7307fb 100644 --- a/client/tests/fetch.rs +++ b/client/tests/fetch.rs @@ -1,6 +1,7 @@ use default_env::default_env; use epoxy_client::EpoxyClient; -use js_sys::{JsString, Object, Reflect, Uint8Array, JSON}; +use js_sys::{Array, JsString, Object, Reflect, Uint8Array, JSON}; +use rustls_pki_types::TrustAnchor; use tokio::sync::OnceCell; use wasm_bindgen::JsValue; use wasm_bindgen_futures::JsFuture; @@ -12,11 +13,40 @@ wasm_bindgen_test_configure!(run_in_dedicated_worker); static USER_AGENT: &str = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36"; static EPOXY_CLIENT: OnceCell = OnceCell::const_new(); +pub fn trustanchor_to_object(cert: &TrustAnchor) -> Result { + let val = Object::new(); + Reflect::set( + &val, + &JsValue::from("subject"), + &Uint8Array::from(cert.subject.as_ref()), + )?; + Reflect::set( + &val, + &JsValue::from("subject_public_key_info"), + &Uint8Array::from(cert.subject_public_key_info.as_ref()), + )?; + Reflect::set( + &val, + &JsValue::from("name_constraints"), + &JsValue::from( + cert.name_constraints + .as_ref() + .map(|x| Uint8Array::from(x.as_ref())), + ), + )?; + Ok(val.into()) +} + async fn get_client_w_ua(useragent: &str, redirect_limit: usize) -> EpoxyClient { EpoxyClient::new( "ws://localhost:4000".into(), useragent.into(), redirect_limit, + webpki_roots::TLS_SERVER_ROOTS + .iter() + .map(trustanchor_to_object) + .collect::>() + .expect("Failed to create certs"), ) .await .ok() diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 076e10c..7458bf4 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -169,17 +169,18 @@ impl MuxInner { pub async fn server_into_future( self, rx: R, + extensions: Vec, close_rx: mpsc::Receiver, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, close_tx: mpsc::Sender, ) -> Result<(), WispError> where - R: ws::WebSocketRead, + R: ws::WebSocketRead + Send, { self.as_future( close_rx, close_tx.clone(), - self.server_loop(rx, muxstream_sender, close_tx), + self.server_loop(rx, extensions, muxstream_sender, close_tx), ) .await } @@ -187,13 +188,14 @@ impl MuxInner { pub async fn client_into_future( self, rx: R, + extensions: Vec, close_rx: mpsc::Receiver, close_tx: mpsc::Sender, ) -> Result<(), WispError> where - R: ws::WebSocketRead, + R: ws::WebSocketRead + Send, { - self.as_future(close_rx, close_tx, self.client_loop(rx)) + self.as_future(close_rx, close_tx, self.client_loop(rx, extensions)) .await } @@ -295,11 +297,12 @@ impl MuxInner { async fn server_loop( &self, mut rx: R, + mut extensions: Vec, muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, close_tx: mpsc::Sender, ) -> Result<(), WispError> where - R: ws::WebSocketRead, + R: ws::WebSocketRead + Send, { // will send continues once flow_control is at 10% of max let target_buffer_size = ((self.buffer_size as u64 * 90) / 100) as u32; @@ -309,104 +312,112 @@ impl MuxInner { if frame.opcode == ws::OpCode::Close { break Ok(()); } - let packet = Packet::try_from(frame)?; - - use PacketType::*; - match packet.packet_type { - Connect(inner_packet) => { - let (ch_tx, ch_rx) = mpsc::unbounded(); - let stream_type = inner_packet.stream_type; - let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); - let flow_control_event: Arc = Event::new().into(); - let is_closed: Arc = AtomicBool::new(false).into(); - - self.stream_map.insert( - packet.stream_id, - MuxMapValue { - stream: ch_tx, - stream_type, - flow_control: flow_control.clone(), - flow_control_event: flow_control_event.clone(), - is_closed: is_closed.clone(), - }, - ); - muxstream_sender - .unbounded_send(( - inner_packet, - MuxStream::new( - packet.stream_id, - Role::Server, + if let Some(packet) = + Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await? + { + use PacketType::*; + match packet.packet_type { + Connect(inner_packet) => { + let (ch_tx, ch_rx) = mpsc::unbounded(); + let stream_type = inner_packet.stream_type; + let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); + let flow_control_event: Arc = Event::new().into(); + let is_closed: Arc = AtomicBool::new(false).into(); + + self.stream_map.insert( + packet.stream_id, + MuxMapValue { + stream: ch_tx, stream_type, - ch_rx, - close_tx.clone(), - is_closed, - flow_control, - flow_control_event, - target_buffer_size, - ), - )) - .map_err(|x| WispError::Other(Box::new(x)))?; - } - Data(data) => { - if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.unbounded_send(data); - if stream.stream_type == StreamType::Tcp { - stream.flow_control.store( - stream - .flow_control - .load(Ordering::Acquire) - .saturating_sub(1), - Ordering::Release, - ); + flow_control: flow_control.clone(), + flow_control_event: flow_control_event.clone(), + is_closed: is_closed.clone(), + }, + ); + muxstream_sender + .unbounded_send(( + inner_packet, + MuxStream::new( + packet.stream_id, + Role::Server, + stream_type, + ch_rx, + close_tx.clone(), + is_closed, + flow_control, + flow_control_event, + target_buffer_size, + ), + )) + .map_err(|x| WispError::Other(Box::new(x)))?; + } + Data(data) => { + if let Some(stream) = self.stream_map.get(&packet.stream_id) { + let _ = stream.stream.unbounded_send(data); + if stream.stream_type == StreamType::Tcp { + stream.flow_control.store( + stream + .flow_control + .load(Ordering::Acquire) + .saturating_sub(1), + Ordering::Release, + ); + } } } - } - Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), - Close(_) => { - if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { - stream.is_closed.store(true, Ordering::Release); - stream.stream.disconnect(); - stream.stream.close_channel(); + Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), + Close(_) => { + if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { + stream.is_closed.store(true, Ordering::Release); + stream.stream.disconnect(); + stream.stream.close_channel(); + } } } } } } - async fn client_loop(&self, mut rx: R) -> Result<(), WispError> + async fn client_loop( + &self, + mut rx: R, + mut extensions: Vec, + ) -> Result<(), WispError> where - R: ws::WebSocketRead, + R: ws::WebSocketRead + Send, { loop { let frame = rx.wisp_read_frame(&self.tx).await?; if frame.opcode == ws::OpCode::Close { break Ok(()); } - let packet = Packet::try_from(frame)?; - - use PacketType::*; - match packet.packet_type { - Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), - Data(data) => { - if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.unbounded_send(data); + if let Some(packet) = + Packet::maybe_handle_extension(frame, &mut extensions, &mut rx, &self.tx).await? + { + use PacketType::*; + match packet.packet_type { + Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), + Data(data) => { + if let Some(stream) = self.stream_map.get(&packet.stream_id) { + let _ = stream.stream.unbounded_send(data); + } } - } - Continue(inner_packet) => { - if let Some(stream) = self.stream_map.get(&packet.stream_id) { - if stream.stream_type == StreamType::Tcp { - stream - .flow_control - .store(inner_packet.buffer_remaining, Ordering::Release); - let _ = stream.flow_control_event.notify(u32::MAX); + Continue(inner_packet) => { + if let Some(stream) = self.stream_map.get(&packet.stream_id) { + if stream.stream_type == StreamType::Tcp { + stream + .flow_control + .store(inner_packet.buffer_remaining, Ordering::Release); + let _ = stream.flow_control_event.notify(u32::MAX); + } } } - } - Close(_) => { - if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { - stream.is_closed.store(true, Ordering::Release); - stream.stream.disconnect(); - stream.stream.close_channel(); + Close(_) => { + if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { + stream.is_closed.store(true, Ordering::Release); + stream.stream.disconnect(); + stream.stream.close_channel(); + } } } } @@ -439,7 +450,7 @@ pub struct ServerMux { /// If this variable is true you must assume no extensions are supported. pub downgraded: bool, /// Extensions that are supported by both sides. - pub supported_extensions: Arc<[AnyProtocolExtension]>, + pub supported_extension_ids: Vec, close_tx: mpsc::Sender, muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, } @@ -503,7 +514,7 @@ impl ServerMux { muxstream_recv: rx, close_tx: close_tx.clone(), downgraded, - supported_extensions: supported_extensions.into(), + supported_extension_ids: supported_extensions.iter().map(|x| x.get_id()).collect(), }, MuxInner { tx: write, @@ -512,6 +523,7 @@ impl ServerMux { } .server_into_future( AppendingWebSocketRead(extra_packet, read), + supported_extensions, close_rx, tx, close_tx, @@ -555,7 +567,7 @@ pub struct ClientMux { /// If this variable is true you must assume no extensions are supported. pub downgraded: bool, /// Extensions that are supported by both sides. - pub supported_extensions: Arc<[AnyProtocolExtension]>, + pub supported_extension_ids: Vec, close_tx: mpsc::Sender, } @@ -619,7 +631,10 @@ impl ClientMux { Self { close_tx: tx.clone(), downgraded, - supported_extensions: supported_extensions.into(), + supported_extension_ids: supported_extensions + .iter() + .map(|x| x.get_id()) + .collect(), }, MuxInner { tx: write, @@ -628,6 +643,7 @@ impl ClientMux { } .client_into_future( AppendingWebSocketRead(extra_packet, read), + supported_extensions, rx, tx, ), @@ -646,9 +662,9 @@ impl ClientMux { ) -> Result { if stream_type == StreamType::Udp && !self - .supported_extensions + .supported_extension_ids .iter() - .any(|x| x.get_id() == UdpProtocolExtension::ID) + .any(|x| *x == UdpProtocolExtension::ID) { return Err(WispError::UdpExtensionNotSupported); } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index c2b2459..388fae7 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -1,6 +1,6 @@ use crate::{ extensions::{AnyProtocolExtension, ProtocolExtensionBuilder}, - ws::{self, Frame, OpCode}, + ws::{self, Frame, LockedWebSocketWrite, OpCode, WebSocketRead}, Role, WispError, WISP_VERSION, }; use bytes::{Buf, BufMut, Bytes, BytesMut}; @@ -388,6 +388,34 @@ impl Packet { } } + pub(crate) async fn maybe_handle_extension( + frame: Frame, + extensions: &mut [AnyProtocolExtension], + read: &mut (dyn WebSocketRead + Send), + write: &LockedWebSocketWrite, + ) -> Result, WispError> { + if !frame.finished { + return Err(WispError::WsFrameNotFinished); + } + if frame.opcode != OpCode::Binary { + return Err(WispError::WsFrameInvalidType); + } + let mut bytes = frame.payload; + if bytes.remaining() < 1 { + return Err(WispError::PacketTooSmall); + } + let packet_type = bytes.get_u8(); + if let Some(extension) = extensions + .iter_mut() + .find(|x| x.get_supported_packets().iter().any(|x| *x == packet_type)) + { + extension.handle_packet(bytes, read, write).await?; + Ok(None) + } else { + Ok(Some(Self::parse_packet(packet_type, bytes)?)) + } + } + fn parse_info( mut bytes: Bytes, role: Role, From 481128e4f5e02c1ec00332f25bbe9e9b6d51afb7 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 13 Apr 2024 16:29:20 -0700 Subject: [PATCH 03/14] add password protocol extension, simplify protocol extension api --- wisp/src/extensions.rs | 299 ++++++++++++++++++++++++++++++++++++++++- wisp/src/lib.rs | 98 +++++++------- wisp/src/packet.rs | 4 +- 3 files changed, 343 insertions(+), 58 deletions(-) diff --git a/wisp/src/extensions.rs b/wisp/src/extensions.rs index 9358c4a..dfefae6 100644 --- a/wisp/src/extensions.rs +++ b/wisp/src/extensions.rs @@ -88,7 +88,7 @@ pub trait ProtocolExtension: std::fmt::Debug { fn box_clone(&self) -> Box; } -/// Trait to build a Wisp protocol extension for the client. +/// Trait to build a Wisp protocol extension from a payload. pub trait ProtocolExtensionBuilder { /// Get the protocol extension ID. /// @@ -96,7 +96,11 @@ pub trait ProtocolExtensionBuilder { fn get_id(&self) -> u8; /// Build a protocol extension from the extension's metadata. - fn build(&self, bytes: Bytes, role: Role) -> AnyProtocolExtension; + fn build_from_bytes(&self, bytes: Bytes, role: Role) + -> Result; + + /// Build a protocol extension to send to the other side. + fn build_to_extension(&self, role: Role) -> AnyProtocolExtension; } pub mod udp { @@ -108,7 +112,6 @@ pub mod udp { //! rx, //! tx, //! 128, - //! Some(vec![UdpProtocolExtension().into()]), //! Some(&[&UdpProtocolExtensionBuilder()]) //! ); //! ``` @@ -154,7 +157,6 @@ pub mod udp { Ok(()) } - /// Handle receiving a packet. async fn handle_packet( &mut self, _: Bytes, @@ -180,11 +182,294 @@ pub mod udp { impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { fn get_id(&self) -> u8 { - 0x01 + UdpProtocolExtension::ID + } + + fn build_from_bytes( + &self, + _: Bytes, + _: crate::Role, + ) -> Result { + Ok(UdpProtocolExtension().into()) + } + + fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { + UdpProtocolExtension().into() + } + } +} + +pub mod password { + //! Password protocol extension. + //! + //! # Example + //! Server: + //! ``` + //! let mut passwords = HashMap::new(); + //! passwords.insert("user1".to_string(), "pw".to_string()); + //! let (mux, fut) = ServerMux::new( + //! rx, + //! tx, + //! 128, + //! Some(&[&PasswordProtocolExtensionBuilder::new_server(passwords)]) + //! ); + //! ``` + //! + //! Client: + //! ``` + //! let (mux, fut) = ClientMux::new( + //! rx, + //! tx, + //! 128, + //! Some(&[ + //! &PasswordProtocolExtensionBuilder::new_client( + //! "user1".to_string(), + //! "pw".to_string() + //! ) + //! ]) + //! ); + //! ``` + //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication) + + use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; + + use async_trait::async_trait; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, + }; + + use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; + + #[derive(Debug, Clone)] + /// Password protocol extension. + /// + /// **This extension will panic when encoding if the username's length does not fit within a u8 + /// or the password's length does not fit within a u16.** + pub struct PasswordProtocolExtension { + /// The username to log in with. + /// + /// This string's length must fit within a u8. + pub username: String, + /// The password to log in with. + /// + /// This string's length must fit within a u16. + pub password: String, + role: Role, + } + + impl PasswordProtocolExtension { + /// Password protocol extension ID. + pub const ID: u8 = 0x02; + + /// Create a new password protocol extension for the server. + /// + /// This signifies that the server requires a password. + pub fn new_server() -> Self { + Self { + username: String::new(), + password: String::new(), + role: Role::Server, + } + } + + /// Create a new password protocol extension for the client, with a username and password. + /// + /// The username's length must fit within a u8. The password's length must fit within a + /// u16. + pub fn new_client(username: String, password: String) -> Self { + Self { + username, + password, + role: Role::Client, + } + } + } + + #[async_trait] + impl ProtocolExtension for PasswordProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } + + fn encode(&self) -> Bytes { + match self.role { + Role::Server => Bytes::new(), + Role::Client => { + let username = Bytes::from(self.username.clone().into_bytes()); + let password = Bytes::from(self.password.clone().into_bytes()); + let username_len = u8::try_from(username.len()).expect("username was too long"); + let password_len = + u16::try_from(username.len()).expect("password was too long"); + + let mut bytes = + BytesMut::with_capacity(3 + username_len as usize + password_len as usize); + bytes.put_u8(username_len); + bytes.put_u16_le(password_len); + bytes.extend(username); + bytes.extend(password); + bytes.freeze() + } + } + } + + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } + } + + #[derive(Debug)] + enum PasswordProtocolExtensionError { + Utf8Error(FromUtf8Error), + InvalidUsername, + InvalidPassword, + } + + impl Display for PasswordProtocolExtensionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use PasswordProtocolExtensionError as E; + match self { + E::Utf8Error(e) => write!(f, "{}", e), + E::InvalidUsername => write!(f, "Invalid username"), + E::InvalidPassword => write!(f, "Invalid password"), + } + } + } + + impl Error for PasswordProtocolExtensionError {} + + impl From for WispError { + fn from(value: PasswordProtocolExtensionError) -> Self { + WispError::ExtensionImplError(Box::new(value)) + } + } + + impl From for PasswordProtocolExtensionError { + fn from(value: FromUtf8Error) -> Self { + PasswordProtocolExtensionError::Utf8Error(value) + } + } + + impl From for AnyProtocolExtension { + fn from(value: PasswordProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } + } + + /// Password protocol extension builder. + pub struct PasswordProtocolExtensionBuilder { + /// Map of users and their passwords to allow. Only used on server. + pub users: HashMap, + /// Username to authenticate with. Only used on client. + pub username: String, + /// Password to authenticate with. Only used on client. + pub password: String, + } + + impl PasswordProtocolExtensionBuilder { + /// Create a new password protocol extension builder for the server, with a map of users + /// and passwords to allow. + pub fn new_server(users: HashMap) -> Self { + Self { + users, + username: String::new(), + password: String::new(), + } + } + + /// Create a new password protocol extension builder for the client, with a username and + /// password to authenticate with. + pub fn new_client(username: String, password: String) -> Self { + Self { + users: HashMap::new(), + username, + password, + } + } + } + + impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + PasswordProtocolExtension::ID + } + + fn build_from_bytes( + &self, + mut payload: Bytes, + role: crate::Role, + ) -> Result { + match role { + Role::Server => { + if payload.remaining() < 3 { + return Err(WispError::PacketTooSmall); + } + + let username_len = payload.get_u8(); + let password_len = payload.get_u16_le(); + if payload.remaining() < (password_len + username_len as u16) as usize { + return Err(WispError::PacketTooSmall); + } + + use PasswordProtocolExtensionError as EError; + let username = + String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) + .map_err(|x| WispError::from(EError::from(x)))?; + let password = + String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) + .map_err(|x| WispError::from(EError::from(x)))?; + + let Some(user) = self.users.iter().find(|x| *x.0 == username) else { + return Err(EError::InvalidUsername.into()); + }; + + if *user.1 != password { + return Err(EError::InvalidPassword.into()); + } + Ok(PasswordProtocolExtension { + username, + password, + role, + } + .into()) + } + Role::Client => { + Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into()) + } + } } - fn build(&self, _: Bytes, _: crate::Role) -> AnyProtocolExtension { - AnyProtocolExtension(Box::new(UdpProtocolExtension())) + fn build_to_extension(&self, role: Role) -> AnyProtocolExtension { + match role { + Role::Server => PasswordProtocolExtension::new_server(), + Role::Client => PasswordProtocolExtension::new_client( + self.username.clone(), + self.password.clone(), + ), + } + .into() } } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 7458bf4..40b21f7 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -458,13 +458,11 @@ pub struct ServerMux { impl ServerMux { /// Create a new server-side multiplexor. /// - /// If either extensions or extension_builders are None a Wisp v1 connection is created - /// otherwise a Wisp v2 connection is created. + /// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. pub async fn new( mut read: R, write: W, buffer_size: u32, - extensions: Option>, extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, ) -> Result<(Self, impl Future> + Send), WispError> where @@ -483,28 +481,29 @@ impl ServerMux { let mut extra_packet = Vec::with_capacity(1); let mut downgraded = true; - if let Some(extensions) = extensions { - if let Some(builders) = extension_builders { - let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); - write - .write_frame(Packet::new_info(extensions).into()) - .await?; - if let Some(frame) = select! { - x = read.wisp_read_frame(&write).fuse() => Some(x?), - // TODO change this to correct timeout once draft 2 is out - _ = Delay::new(Duration::from_secs(5)).fuse() => None - } { - let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; - if let PacketType::Info(info) = packet.packet_type { - supported_extensions = info - .extensions - .into_iter() - .filter(|x| extension_ids.contains(&x.get_id())) - .collect(); - downgraded = false; - } else { - extra_packet.push(packet.into()); - } + if let Some(builders) = extension_builders { + let extensions: Vec<_> = builders + .iter() + .map(|x| x.build_to_extension(Role::Server)) + .collect(); + let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); + write + .write_frame(Packet::new_info(extensions).into()) + .await?; + if let Some(frame) = select! { + x = read.wisp_read_frame(&write).fuse() => Some(x?), + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + downgraded = false; + } else { + extra_packet.push(packet.into()); } } } @@ -574,12 +573,10 @@ pub struct ClientMux { impl ClientMux { /// Create a new client side multiplexor. /// - /// If either extensions or extension_builders are None a Wisp v1 connection is created - /// otherwise a Wisp v2 connection is created. + /// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. pub async fn new( mut read: R, write: W, - extensions: Option>, extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, ) -> Result<(Self, impl Future> + Send), WispError> where @@ -596,28 +593,29 @@ impl ClientMux { let mut extra_packet = Vec::with_capacity(1); let mut downgraded = true; - if let Some(extensions) = extensions { - if let Some(builders) = extension_builders { - let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); - if let Some(frame) = select! { - x = read.wisp_read_frame(&write).fuse() => Some(x?), - // TODO change this to correct timeout once draft 2 is out - _ = Delay::new(Duration::from_secs(5)).fuse() => None - } { - let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; - if let PacketType::Info(info) = packet.packet_type { - supported_extensions = info - .extensions - .into_iter() - .filter(|x| extension_ids.contains(&x.get_id())) - .collect(); - write - .write_frame(Packet::new_info(extensions).into()) - .await?; - downgraded = false; - } else { - extra_packet.push(packet.into()); - } + if let Some(builders) = extension_builders { + let extensions: Vec<_> = builders + .iter() + .map(|x| x.build_to_extension(Role::Client)) + .collect(); + let extension_ids: Vec<_> = extensions.iter().map(|x| x.get_id()).collect(); + if let Some(frame) = select! { + x = read.wisp_read_frame(&write).fuse() => Some(x?), + _ = Delay::new(Duration::from_secs(5)).fuse() => None + } { + let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; + if let PacketType::Info(info) = packet.packet_type { + supported_extensions = info + .extensions + .into_iter() + .filter(|x| extension_ids.contains(&x.get_id())) + .collect(); + write + .write_frame(Packet::new_info(extensions).into()) + .await?; + downgraded = false; + } else { + extra_packet.push(packet.into()); } } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 388fae7..0017307 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -444,7 +444,9 @@ impl Packet { return Err(WispError::PacketTooSmall); } if let Some(builder) = extension_builders.iter().find(|x| x.get_id() == id) { - extensions.push(builder.build(bytes.copy_to_bytes(length), role)) + if let Ok(extension) = builder.build_from_bytes(bytes.copy_to_bytes(length), role) { + extensions.push(extension) + } } else { bytes.advance(length) } From b8eb13903b7f298f4a3297bbace13b3c3eb7ec88 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 13 Apr 2024 16:39:05 -0700 Subject: [PATCH 04/14] update server,client,simple-wisp-client for new api --- client/src/utils.rs | 13 ++----------- server/src/main.rs | 14 ++++---------- simple-wisp-client/src/main.rs | 16 +++------------- 3 files changed, 9 insertions(+), 34 deletions(-) diff --git a/client/src/utils.rs b/client/src/utils.rs index b956d02..5ed12b7 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -7,10 +7,7 @@ use wasm_bindgen_futures::JsFuture; use hyper::rt::Executor; use js_sys::ArrayBuffer; use std::future::Future; -use wisp_mux::{ - extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, - WispError, -}; +use wisp_mux::{extensions::udp::UdpProtocolExtensionBuilder, WispError}; #[wasm_bindgen] extern "C" { @@ -207,13 +204,7 @@ pub async fn make_mux( .await .map_err(|_| WispError::WsImplSocketClosed)?; wtx.wait_for_open().await; - let mux = ClientMux::new( - wrx, - wtx, - Some(vec![UdpProtocolExtension().into()]), - Some(&[&UdpProtocolExtensionBuilder()]), - ) - .await?; + let mux = ClientMux::new(wrx, wtx, Some(&[&UdpProtocolExtensionBuilder()])).await?; Ok(mux) } diff --git a/server/src/main.rs b/server/src/main.rs index 61561b7..9644a22 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -20,8 +20,8 @@ use tokio_util::codec::{BytesCodec, Framed}; use tokio_util::either::Either; use wisp_mux::{ - extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, - CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, + extensions::udp::UdpProtocolExtensionBuilder, CloseReason, ConnectPacket, MuxStream, ServerMux, + StreamType, WispError, }; type HttpBody = http_body_util::Full; @@ -263,14 +263,8 @@ async fn accept_ws( println!("{:?}: connected", addr); - let (mut mux, fut) = ServerMux::new( - rx, - tx, - u32::MAX, - Some(vec![UdpProtocolExtension().into()]), - Some(&[&UdpProtocolExtensionBuilder()]), - ) - .await?; + let (mut mux, fut) = + ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?; tokio::spawn(async move { if let Err(e) = fut.await { diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 4dd329a..fe29fe6 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -18,7 +18,6 @@ use std::{ process::exit, sync::Arc, time::{Duration, Instant}, - usize, }; use tokio::{ net::TcpStream, @@ -28,10 +27,7 @@ use tokio::{ }; use tokio_native_tls::{native_tls, TlsConnector}; use tokio_util::either::Either; -use wisp_mux::{ - extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, - ClientMux, StreamType, WispError, -}; +use wisp_mux::{extensions::udp::UdpProtocolExtensionBuilder, ClientMux, StreamType, WispError}; #[derive(Debug)] enum WispClientError { @@ -138,15 +134,9 @@ async fn main() -> Result<(), Box> { let rx = FragmentCollectorRead::new(rx); let (mut mux, fut) = if opts.udp { - ClientMux::new( - rx, - tx, - Some(vec![UdpProtocolExtension().into()]), - Some(&[&UdpProtocolExtensionBuilder()]), - ) - .await? + ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await? } else { - ClientMux::new(rx, tx, Some(vec![]), Some(&[])).await? + ClientMux::new(rx, tx, Some(&[])).await? }; let mut threads = Vec::with_capacity(opts.streams * 2 + 3); From 397fd43dc57ef7da8a7b78b884a3109a7dec6354 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 13 Apr 2024 16:49:07 -0700 Subject: [PATCH 05/14] remove invalidstreamtype to allow for custom protocol extension streams --- server/src/main.rs | 4 +++ wisp/src/lib.rs | 3 -- wisp/src/packet.rs | 78 ++++++++++++++++++++++++++-------------------- 3 files changed, 49 insertions(+), 36 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index 9644a22..a191a01 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -247,6 +247,10 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result { + stream.close(CloseReason::ServerStreamInvalidInfo).await?; + return Ok(false); + } } Ok(true) } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 40b21f7..3c77584 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -52,8 +52,6 @@ pub enum WispError { PacketTooSmall, /// The packet received had an invalid type. InvalidPacketType, - /// The stream had an invalid type. - InvalidStreamType, /// The stream had an invalid ID. InvalidStreamId, /// The close packet had an invalid reason. @@ -113,7 +111,6 @@ impl std::fmt::Display for WispError { match self { Self::PacketTooSmall => write!(f, "Packet too small"), Self::InvalidPacketType => write!(f, "Invalid packet type"), - Self::InvalidStreamType => write!(f, "Invalid stream type"), Self::InvalidStreamId => write!(f, "Invalid stream id"), Self::InvalidCloseReason => write!(f, "Invalid close reason"), Self::InvalidUri => write!(f, "Invalid URI"), diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 0017307..9ff6a3c 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -9,19 +9,31 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; #[derive(Debug, PartialEq, Copy, Clone)] pub enum StreamType { /// TCP Wisp stream. - Tcp = 0x01, + Tcp, /// UDP Wisp stream. - Udp = 0x02, + Udp, + /// Unknown Wisp stream type used for custom streams by protocol extensions. + Unknown(u8), } -impl TryFrom for StreamType { - type Error = WispError; - fn try_from(stream_type: u8) -> Result { - use StreamType::*; - match stream_type { - 0x01 => Ok(Tcp), - 0x02 => Ok(Udp), - _ => Err(Self::Error::InvalidStreamType), +impl From for StreamType { + fn from(value: u8) -> Self { + use StreamType as S; + match value { + 0x01 => S::Tcp, + 0x02 => S::Udp, + x => S::Unknown(x), + } + } +} + +impl From for u8 { + fn from(value: StreamType) -> Self { + use StreamType as S; + match value { + S::Tcp => 0x01, + S::Udp => 0x02, + S::Unknown(x) => x, } } } @@ -60,9 +72,9 @@ pub enum CloseReason { impl TryFrom for CloseReason { type Error = WispError; - fn try_from(stream_type: u8) -> Result { + fn try_from(close_reason: u8) -> Result { use CloseReason as R; - match stream_type { + match close_reason { 0x01 => Ok(R::Unknown), 0x02 => Ok(R::Voluntary), 0x03 => Ok(R::Unexpected), @@ -75,7 +87,7 @@ impl TryFrom for CloseReason { 0x48 => Ok(R::ServerStreamBlockedAddress), 0x49 => Ok(R::ServerStreamThrottled), 0x81 => Ok(R::ClientUnexpected), - _ => Err(Self::Error::InvalidStreamType), + _ => Err(Self::Error::InvalidCloseReason), } } } @@ -115,7 +127,7 @@ impl TryFrom for ConnectPacket { return Err(Self::Error::PacketTooSmall); } Ok(Self { - stream_type: bytes.get_u8().try_into()?, + stream_type: bytes.get_u8().into(), destination_port: bytes.get_u16_le(), destination_hostname: std::str::from_utf8(&bytes)?.to_string(), }) @@ -125,7 +137,7 @@ impl TryFrom for ConnectPacket { impl From for Bytes { fn from(packet: ConnectPacket) -> Self { let mut encoded = BytesMut::with_capacity(1 + 2 + packet.destination_hostname.len()); - encoded.put_u8(packet.stream_type as u8); + encoded.put_u8(packet.stream_type.into()); encoded.put_u16_le(packet.destination_port); encoded.extend(packet.destination_hostname.bytes()); encoded.freeze() @@ -255,26 +267,26 @@ pub enum PacketType { impl PacketType { /// Get the packet type used in the protocol. pub fn as_u8(&self) -> u8 { - use PacketType::*; + use PacketType as P; match self { - Connect(_) => 0x01, - Data(_) => 0x02, - Continue(_) => 0x03, - Close(_) => 0x04, - Info(_) => 0x05, + P::Connect(_) => 0x01, + P::Data(_) => 0x02, + P::Continue(_) => 0x03, + P::Close(_) => 0x04, + P::Info(_) => 0x05, } } } impl From for Bytes { fn from(packet: PacketType) -> Self { - use PacketType::*; + use PacketType as P; match packet { - Connect(x) => x.into(), - Data(x) => x, - Continue(x) => x.into(), - Close(x) => x.into(), - Info(x) => x.into(), + P::Connect(x) => x.into(), + P::Data(x) => x, + P::Continue(x) => x.into(), + P::Close(x) => x.into(), + P::Info(x) => x.into(), } } } @@ -351,14 +363,14 @@ impl Packet { } fn parse_packet(packet_type: u8, mut bytes: Bytes) -> Result { - use PacketType::*; + use PacketType as P; Ok(Self { stream_id: bytes.get_u32_le(), packet_type: match packet_type { - 0x01 => Connect(ConnectPacket::try_from(bytes)?), - 0x02 => Data(bytes), - 0x03 => Continue(ContinuePacket::try_from(bytes)?), - 0x04 => Close(ClosePacket::try_from(bytes)?), + 0x01 => P::Connect(ConnectPacket::try_from(bytes)?), + 0x02 => P::Data(bytes), + 0x03 => P::Continue(ContinuePacket::try_from(bytes)?), + 0x04 => P::Close(ClosePacket::try_from(bytes)?), // 0x05 is handled seperately _ => return Err(WispError::InvalidPacketType), }, @@ -465,7 +477,7 @@ impl Packet { impl TryFrom for Packet { type Error = WispError; fn try_from(mut bytes: Bytes) -> Result { - if bytes.remaining() < 5 { + if bytes.remaining() < 1 { return Err(Self::Error::PacketTooSmall); } let packet_type = bytes.get_u8(); From 4d433b60c408f5b29a1ca2da9157c93e22d3104b Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 13 Apr 2024 20:32:21 -0700 Subject: [PATCH 06/14] enforce UdpProtocolExtension if requested --- server/src/main.rs | 2 ++ simple-wisp-client/src/main.rs | 11 +++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index a191a01..61d1897 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -270,6 +270,8 @@ async fn accept_ws( let (mut mux, fut) = ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?; + println!("{:?}: downgraded: {} extensions supported: {:?}", addr, mux.downgraded, mux.supported_extension_ids); + tokio::spawn(async move { if let Err(e) = fut.await { println!("err in mux: {:?}", e); diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index fe29fe6..9c38cad 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -27,7 +27,7 @@ use tokio::{ }; use tokio_native_tls::{native_tls, TlsConnector}; use tokio_util::either::Either; -use wisp_mux::{extensions::udp::UdpProtocolExtensionBuilder, ClientMux, StreamType, WispError}; +use wisp_mux::{extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, ClientMux, StreamType, WispError}; #[derive(Debug)] enum WispClientError { @@ -134,11 +134,18 @@ async fn main() -> Result<(), Box> { let rx = FragmentCollectorRead::new(rx); let (mut mux, fut) = if opts.udp { - ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await? + let (mux, fut) = ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await?; + if !mux.supported_extension_ids.iter().any(|x| *x == UdpProtocolExtension::ID) { + println!("server did not support udp, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids); + exit(1); + } + (mux, fut) } else { ClientMux::new(rx, tx, Some(&[])).await? }; + println!("connected and created ClientMux, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids); + let mut threads = Vec::with_capacity(opts.streams * 2 + 3); threads.push(tokio::spawn(fut)); From d10b7691e4eff00a4295eb87e7bd4eb484a078e8 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 13 Apr 2024 22:34:26 -0700 Subject: [PATCH 07/14] fix password protocol extension, respect stream id 0 close packets, allow sending stream id 0 close packets --- server/src/main.rs | 123 +++++++++++++++++++++++++-------- simple-wisp-client/src/main.rs | 74 ++++++++++++++++---- wisp/src/extensions.rs | 8 ++- wisp/src/lib.rs | 59 +++++++++++++--- wisp/src/packet.rs | 4 ++ wisp/src/stream.rs | 2 +- 6 files changed, 220 insertions(+), 50 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index 61d1897..12cc030 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,5 +1,5 @@ #![feature(let_chains, ip)] -use std::io::Error; +use std::{collections::HashMap, io::Error, path::PathBuf, sync::Arc}; use bytes::Bytes; use clap::Parser; @@ -20,8 +20,11 @@ use tokio_util::codec::{BytesCodec, Framed}; use tokio_util::either::Either; use wisp_mux::{ - extensions::udp::UdpProtocolExtensionBuilder, CloseReason, ConnectPacket, MuxStream, ServerMux, - StreamType, WispError, + extensions::{ + password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, + udp::UdpProtocolExtensionBuilder, + }, + CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, }; type HttpBody = http_body_util::Full; @@ -56,6 +59,20 @@ struct Cli { /// Whether the server should block ports other than 80 or 443 #[arg(long)] block_non_http: bool, + /// Path to a file containing `user:password` separated by newlines. This is plaintext!!! + /// + /// `user` cannot contain `:`. Whitespace will be trimmed. + #[arg(long)] + auth: Option, +} + +#[derive(Clone)] +struct MuxOptions { + pub block_local: bool, + pub block_udp: bool, + pub block_non_http: bool, + pub enforce_auth: bool, + pub auth: Arc, } #[cfg(not(unix))] @@ -138,19 +155,44 @@ async fn main() -> Result<(), Error> { "/".to_string() }; + let mut auth = HashMap::new(); + let enforce_auth = opt.auth.is_some(); + if let Some(file) = opt.auth { + let file = std::fs::read_to_string(file)?; + for entry in file.split('\n').filter_map(|x| { + if x.contains(':') { + Some(x.trim()) + } else { + None + } + }) { + let split: Vec<_> = entry.split(':').collect(); + let username = split[0]; + let password = split[1..].join(":"); + println!( + "adding username {:?} password {:?} to allowed auth", + username, password + ); + auth.insert(username.to_string(), password.to_string()); + } + } + let pw_ext = Arc::new(PasswordProtocolExtensionBuilder::new_server(auth)); + + let mux_options = MuxOptions { + block_local: opt.block_local, + block_non_http: opt.block_non_http, + block_udp: opt.block_udp, + auth: pw_ext, + enforce_auth, + }; + println!("listening on `{}` with prefix `{}`", addr, prefix); while let Ok((stream, addr)) = socket.accept().await { let prefix = prefix.clone(); + let mux_options = mux_options.clone(); tokio::spawn(async move { let service = service_fn(move |res| { - accept_http( - res, - addr.clone(), - prefix.clone(), - opt.block_local, - opt.block_udp, - opt.block_non_http, - ) + accept_http(res, addr.clone(), prefix.clone(), mux_options.clone()) }); let conn = http1::Builder::new() .serve_connection(TokioIo::new(stream), service) @@ -168,9 +210,7 @@ async fn accept_http( mut req: Request, addr: String, prefix: String, - block_local: bool, - block_udp: bool, - block_non_http: bool, + mux_options: MuxOptions, ) -> Result, WebSocketError> { let uri = req.uri().path().to_string(); if upgrade::is_upgrade_request(&req) @@ -179,12 +219,17 @@ async fn accept_http( let (res, fut) = upgrade::upgrade(&mut req)?; if uri.is_empty() { - tokio::spawn(async move { - accept_ws(fut, addr.clone(), block_local, block_udp, block_non_http).await - }); + tokio::spawn(async move { accept_ws(fut, addr.clone(), mux_options).await }); } else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) { tokio::spawn(async move { - accept_wsproxy(fut, uri, addr.clone(), block_local, block_non_http).await + accept_wsproxy( + fut, + uri, + addr.clone(), + mux_options.block_local, + mux_options.block_non_http, + ) + .await }); } @@ -258,19 +303,41 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result Result<(), Box> { let (rx, tx) = ws.await?.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); println!("{:?}: connected", addr); + let (mut mux, fut) = if mux_options.enforce_auth { + let (mut mux, fut) = ServerMux::new( + rx, + tx, + u32::MAX, + Some(&[&UdpProtocolExtensionBuilder(), mux_options.auth.as_ref()]), + ) + .await?; + if !mux + .supported_extension_ids + .iter() + .any(|x| *x == PasswordProtocolExtension::ID) + { + println!( + "{:?}: client did not support auth or password was invalid", + addr + ); + mux.close_extension_incompat().await?; + return Ok(()); + } + (mux, fut) + } else { + ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await? + }; - let (mut mux, fut) = - ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await?; - - println!("{:?}: downgraded: {} extensions supported: {:?}", addr, mux.downgraded, mux.supported_extension_ids); + println!( + "{:?}: downgraded: {} extensions supported: {:?}", + addr, mux.downgraded, mux.supported_extension_ids + ); tokio::spawn(async move { if let Err(e) = fut.await { @@ -280,14 +347,14 @@ async fn accept_ws( while let Some((packet, mut stream)) = mux.server_new_stream().await { tokio::spawn(async move { - if (block_non_http + if (mux_options.block_non_http && !(packet.destination_port == 80 || packet.destination_port == 443)) - || (block_udp && packet.stream_type == StreamType::Udp) + || (mux_options.block_udp && packet.stream_type == StreamType::Udp) { let _ = stream.close(CloseReason::ServerStreamBlockedAddress).await; return; } - if block_local { + if mux_options.block_local { match lookup_host(format!( "{}:{}", packet.destination_hostname, packet.destination_port diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 9c38cad..1f5802f 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -27,7 +27,14 @@ use tokio::{ }; use tokio_native_tls::{native_tls, TlsConnector}; use tokio_util::either::Either; -use wisp_mux::{extensions::udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, ClientMux, StreamType, WispError}; +use wisp_mux::{ + extensions::{ + password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, + udp::{UdpProtocolExtension, UdpProtocolExtensionBuilder}, + ProtocolExtensionBuilder, + }, + ClientMux, StreamType, WispError, +}; #[derive(Debug)] enum WispClientError { @@ -80,6 +87,11 @@ struct Cli { /// Ask for UDP #[arg(short, long)] udp: bool, + /// Enable auth: format is `username:password` + /// + /// Usernames and passwords are sent in plaintext!! + #[arg(long)] + auth: Option, } #[tokio::main(flavor = "multi_thread")] @@ -103,6 +115,13 @@ async fn main() -> Result<(), Box> { let addr_dest = opts.tcp.ip().to_string(); let addr_dest_port = opts.tcp.port(); + let auth = opts.auth.map(|auth| { + let split: Vec<_> = auth.split(':').collect(); + let username = split[0].to_string(); + let password = split[1..].join(":"); + PasswordProtocolExtensionBuilder::new_client(username, password) + }); + println!( "connecting to {} and sending &[0; 1024 * {}] to {} with threads {}", opts.wisp, opts.packet_size, opts.tcp, opts.streams, @@ -133,18 +152,49 @@ async fn main() -> Result<(), Box> { let (rx, tx) = ws.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); - let (mut mux, fut) = if opts.udp { - let (mux, fut) = ClientMux::new(rx, tx, Some(&[&UdpProtocolExtensionBuilder()])).await?; - if !mux.supported_extension_ids.iter().any(|x| *x == UdpProtocolExtension::ID) { - println!("server did not support udp, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids); - exit(1); - } - (mux, fut) - } else { - ClientMux::new(rx, tx, Some(&[])).await? - }; + let mut extensions: Vec> = Vec::new(); + if opts.udp { + extensions.push(Box::new(UdpProtocolExtensionBuilder())); + } + let enforce_auth = auth.is_some(); + if let Some(auth) = auth { + extensions.push(Box::new(auth)); + } + let extensions_mapped: Vec<&(dyn ProtocolExtensionBuilder + Sync)> = + extensions.iter().map(|x| x.as_ref()).collect(); + + let (mut mux, fut) = ClientMux::new(rx, tx, Some(&extensions_mapped)).await?; + if opts.udp + && !mux + .supported_extension_ids + .iter() + .any(|x| *x == UdpProtocolExtension::ID) + { + println!( + "server did not support udp, was downgraded {}, extensions supported {:?}", + mux.downgraded, mux.supported_extension_ids + ); + mux.close_extension_incompat().await?; + exit(1); + } + if enforce_auth + && !mux + .supported_extension_ids + .iter() + .any(|x| *x == PasswordProtocolExtension::ID) + { + println!( + "server did not support passwords or password was incorrect, was downgraded {}, extensions supported {:?}", + mux.downgraded, mux.supported_extension_ids + ); + mux.close_extension_incompat().await?; + exit(1); + } - println!("connected and created ClientMux, was downgraded {}, extensions supported {:?}", mux.downgraded, mux.supported_extension_ids); + println!( + "connected and created ClientMux, was downgraded {}, extensions supported {:?}", + mux.downgraded, mux.supported_extension_ids + ); let mut threads = Vec::with_capacity(opts.streams * 2 + 3); diff --git a/wisp/src/extensions.rs b/wisp/src/extensions.rs index dfefae6..661439c 100644 --- a/wisp/src/extensions.rs +++ b/wisp/src/extensions.rs @@ -202,6 +202,8 @@ pub mod udp { pub mod password { //! Password protocol extension. //! + //! Passwords are sent in plain text!! + //! //! # Example //! Server: //! ``` @@ -246,6 +248,7 @@ pub mod password { #[derive(Debug, Clone)] /// Password protocol extension. /// + /// **Passwords are sent in plain text!!** /// **This extension will panic when encoding if the username's length does not fit within a u8 /// or the password's length does not fit within a u16.** pub struct PasswordProtocolExtension { @@ -306,7 +309,7 @@ pub mod password { let password = Bytes::from(self.password.clone().into_bytes()); let username_len = u8::try_from(username.len()).expect("username was too long"); let password_len = - u16::try_from(username.len()).expect("password was too long"); + u16::try_from(password.len()).expect("password was too long"); let mut bytes = BytesMut::with_capacity(3 + username_len as usize + password_len as usize); @@ -380,6 +383,8 @@ pub mod password { } /// Password protocol extension builder. + /// + /// **Passwords are sent in plain text!!** pub struct PasswordProtocolExtensionBuilder { /// Map of users and their passwords to allow. Only used on server. pub users: HashMap, @@ -448,6 +453,7 @@ pub mod password { if *user.1 != password { return Err(EError::InvalidPassword.into()); } + Ok(PasswordProtocolExtension { username, password, diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 3c77584..d8f7dd0 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -64,6 +64,8 @@ pub enum WispError { UriHasNoPort, /// The max stream count was reached. MaxStreamCountReached, + /// The Wisp protocol version was incompatible. + IncompatibleProtocolVersion, /// The stream had already been closed. StreamAlreadyClosed, /// The websocket frame received had an invalid type. @@ -117,6 +119,7 @@ impl std::fmt::Display for WispError { Self::UriHasNoHost => write!(f, "URI has no host"), Self::UriHasNoPort => write!(f, "URI has no port"), Self::MaxStreamCountReached => write!(f, "Maximum stream count reached"), + Self::IncompatibleProtocolVersion => write!(f, "Incompatible Wisp protocol version"), Self::StreamAlreadyClosed => write!(f, "Stream already closed"), Self::WsFrameInvalidType => write!(f, "Invalid websocket frame type"), Self::WsFrameNotFinished => write!(f, "Unfinished websocket frame"), @@ -286,7 +289,15 @@ impl MuxInner { let _ = channel.send(Err(WispError::InvalidStreamId)); } } - WsEvent::EndFut => break, + WsEvent::EndFut(x) => { + if let Some(reason) = x { + let _ = self + .tx + .write_frame(Packet::new_close(0, reason).into()) + .await; + } + break; + } } } } @@ -364,6 +375,9 @@ impl MuxInner { } Continue(_) | Info(_) => break Err(WispError::InvalidPacketType), Close(_) => { + if packet.stream_id == 0 { + break Ok(()); + } if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); stream.stream.disconnect(); @@ -410,6 +424,9 @@ impl MuxInner { } } Close(_) => { + if packet.stream_id == 0 { + break Ok(()); + } if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); stream.stream.disconnect(); @@ -532,15 +549,28 @@ impl ServerMux { self.muxstream_recv.next().await } + async fn close_internal(&mut self, reason: Option) -> Result<(), WispError> { + self.close_tx + .send(WsEvent::EndFut(reason)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) + } + /// Close all streams. /// /// Also terminates the multiplexor future. Waiting for a new stream will never succeed after /// this function is called. pub async fn close(&mut self) -> Result<(), WispError> { - self.close_tx - .send(WsEvent::EndFut) + self.close_internal(None).await + } + + /// Close all streams and send an extension incompatibility error to the client. + /// + /// Also terminates the multiplexor future. Waiting for a new stream will never succed after + /// this function is called. + pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> { + self.close_internal(Some(CloseReason::IncompatibleExtensions)) .await - .map_err(|_| WispError::MuxMessageFailedToSend) } } /// Client side multiplexor. @@ -600,7 +630,7 @@ impl ClientMux { x = read.wisp_read_frame(&write).fuse() => Some(x?), _ = Delay::new(Duration::from_secs(5)).fuse() => None } { - let packet = Packet::maybe_parse_info(frame, Role::Server, builders)?; + let packet = Packet::maybe_parse_info(frame, Role::Client, builders)?; if let PacketType::Info(info) = packet.packet_type { supported_extensions = info .extensions @@ -671,14 +701,27 @@ impl ClientMux { rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? } + async fn close_internal(&mut self, reason: Option) -> Result<(), WispError> { + self.close_tx + .send(WsEvent::EndFut(reason)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend) + } + /// Close all streams. /// /// Also terminates the multiplexor future. Creating a stream is UB after calling this /// function. pub async fn close(&mut self) -> Result<(), WispError> { - self.close_tx - .send(WsEvent::EndFut) + self.close_internal(None).await + } + + /// Close all streams and send an extension incompatibility error to the client. + /// + /// Also terminates the multiplexor future. Creating a stream is UB after calling this + /// function. + pub async fn close_extension_incompat(&mut self) -> Result<(), WispError> { + self.close_internal(Some(CloseReason::IncompatibleExtensions)) .await - .map_err(|_| WispError::MuxMessageFailedToSend) } } diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 9ff6a3c..41554a3 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -446,6 +446,10 @@ impl Packet { minor: bytes.get_u8(), }; + if version.major != WISP_VERSION.major { + return Err(WispError::IncompatibleProtocolVersion); + } + let mut extensions = Vec::new(); while bytes.remaining() > 4 { diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index f579140..69c711b 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -27,7 +27,7 @@ pub(crate) enum WsEvent { u16, oneshot::Sender>, ), - EndFut, + EndFut(Option), } /// Read side of a multiplexor stream. From 76da9fd6192042d395ecd6a0d5e2acec08dc352f Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 13 Apr 2024 22:57:27 -0700 Subject: [PATCH 08/14] add notice about protocol extension availability --- wisp/src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index d8f7dd0..b25f01b 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -473,6 +473,8 @@ impl ServerMux { /// Create a new server-side multiplexor. /// /// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// **It is not guaranteed that all extensions you specify are available.** You must manually check + /// if the extensions you need are available after the multiplexor has been created. pub async fn new( mut read: R, write: W, @@ -601,6 +603,8 @@ impl ClientMux { /// Create a new client side multiplexor. /// /// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// **It is not guaranteed that all extensions you specify are available.** You must manually check + /// if the extensions you need are available after the multiplexor has been created. pub async fn new( mut read: R, write: W, From ace9bf380dfc95429420ea26d97cc1f69708e536 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sat, 13 Apr 2024 23:45:40 -0700 Subject: [PATCH 09/14] make extensions owned --- client/src/utils.rs | 2 +- server/src/main.rs | 12 ++++++------ simple-wisp-client/src/main.rs | 6 ++---- wisp/src/extensions.rs | 8 ++++---- wisp/src/lib.rs | 12 ++++++------ wisp/src/packet.rs | 4 ++-- 6 files changed, 21 insertions(+), 23 deletions(-) diff --git a/client/src/utils.rs b/client/src/utils.rs index 5ed12b7..3f6dc9b 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -204,7 +204,7 @@ pub async fn make_mux( .await .map_err(|_| WispError::WsImplSocketClosed)?; wtx.wait_for_open().await; - let mux = ClientMux::new(wrx, wtx, Some(&[&UdpProtocolExtensionBuilder()])).await?; + let mux = ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await?; Ok(mux) } diff --git a/server/src/main.rs b/server/src/main.rs index 12cc030..a4d269a 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -22,7 +22,7 @@ use tokio_util::either::Either; use wisp_mux::{ extensions::{ password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, - udp::UdpProtocolExtensionBuilder, + udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder, }, CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, }; @@ -72,7 +72,7 @@ struct MuxOptions { pub block_udp: bool, pub block_non_http: bool, pub enforce_auth: bool, - pub auth: Arc, + pub auth: Arc>>, } #[cfg(not(unix))] @@ -176,13 +176,13 @@ async fn main() -> Result<(), Error> { auth.insert(username.to_string(), password.to_string()); } } - let pw_ext = Arc::new(PasswordProtocolExtensionBuilder::new_server(auth)); + let pw_ext = PasswordProtocolExtensionBuilder::new_server(auth); let mux_options = MuxOptions { block_local: opt.block_local, block_non_http: opt.block_non_http, block_udp: opt.block_udp, - auth: pw_ext, + auth: Arc::new(vec![Box::new(UdpProtocolExtensionBuilder()), Box::new(pw_ext)]), enforce_auth, }; @@ -314,7 +314,7 @@ async fn accept_ws( rx, tx, u32::MAX, - Some(&[&UdpProtocolExtensionBuilder(), mux_options.auth.as_ref()]), + Some(mux_options.auth.as_slice()), ) .await?; if !mux @@ -331,7 +331,7 @@ async fn accept_ws( } (mux, fut) } else { - ServerMux::new(rx, tx, u32::MAX, Some(&[&UdpProtocolExtensionBuilder()])).await? + ServerMux::new(rx, tx, u32::MAX, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await? }; println!( diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 1f5802f..97a7b15 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -152,7 +152,7 @@ async fn main() -> Result<(), Box> { let (rx, tx) = ws.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); - let mut extensions: Vec> = Vec::new(); + let mut extensions: Vec> = Vec::new(); if opts.udp { extensions.push(Box::new(UdpProtocolExtensionBuilder())); } @@ -160,10 +160,8 @@ async fn main() -> Result<(), Box> { if let Some(auth) = auth { extensions.push(Box::new(auth)); } - let extensions_mapped: Vec<&(dyn ProtocolExtensionBuilder + Sync)> = - extensions.iter().map(|x| x.as_ref()).collect(); - let (mut mux, fut) = ClientMux::new(rx, tx, Some(&extensions_mapped)).await?; + let (mut mux, fut) = ClientMux::new(rx, tx, Some(extensions.as_slice())).await?; if opts.udp && !mux .supported_extension_ids diff --git a/wisp/src/extensions.rs b/wisp/src/extensions.rs index 661439c..de472bd 100644 --- a/wisp/src/extensions.rs +++ b/wisp/src/extensions.rs @@ -112,7 +112,7 @@ pub mod udp { //! rx, //! tx, //! 128, - //! Some(&[&UdpProtocolExtensionBuilder()]) + //! Some(&[Box::new(UdpProtocolExtensionBuilder())]) //! ); //! ``` //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) @@ -213,7 +213,7 @@ pub mod password { //! rx, //! tx, //! 128, - //! Some(&[&PasswordProtocolExtensionBuilder::new_server(passwords)]) + //! Some(&[Box::new(PasswordProtocolExtensionBuilder::new_server(passwords))]) //! ); //! ``` //! @@ -224,10 +224,10 @@ pub mod password { //! tx, //! 128, //! Some(&[ - //! &PasswordProtocolExtensionBuilder::new_client( + //! Box::new(PasswordProtocolExtensionBuilder::new_client( //! "user1".to_string(), //! "pw".to_string() - //! ) + //! )) //! ]) //! ); //! ``` diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index b25f01b..ff732af 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -445,7 +445,7 @@ impl MuxInner { /// ``` /// use wisp_mux::ServerMux; /// -/// let (mux, fut) = ServerMux::new(rx, tx, 128, Some(vec![]), Some([])); +/// let (mux, fut) = ServerMux::new(rx, tx, 128, Some([])); /// tokio::spawn(async move { /// if let Err(e) = fut.await { /// println!("error in multiplexor: {:?}", e); @@ -472,14 +472,14 @@ pub struct ServerMux { impl ServerMux { /// Create a new server-side multiplexor. /// - /// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. pub async fn new( mut read: R, write: W, buffer_size: u32, - extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, + extension_builders: Option<&[Box]>, ) -> Result<(Self, impl Future> + Send), WispError> where R: ws::WebSocketRead + Send, @@ -581,7 +581,7 @@ impl ServerMux { /// ``` /// use wisp_mux::{ClientMux, StreamType}; /// -/// let (mux, fut) = ClientMux::new(rx, tx, Some(vec![]), []).await?; +/// let (mux, fut) = ClientMux::new(rx, tx, Some([])).await?; /// tokio::spawn(async move { /// if let Err(e) = fut.await { /// println!("error in multiplexor: {:?}", e); @@ -602,13 +602,13 @@ pub struct ClientMux { impl ClientMux { /// Create a new client side multiplexor. /// - /// If extension_builders is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. + /// If `extension_builders` is None a Wisp v1 connection is created otherwise a Wisp v2 connection is created. /// **It is not guaranteed that all extensions you specify are available.** You must manually check /// if the extensions you need are available after the multiplexor has been created. pub async fn new( mut read: R, write: W, - extension_builders: Option<&[&(dyn ProtocolExtensionBuilder + Sync)]>, + extension_builders: Option<&[Box]>, ) -> Result<(Self, impl Future> + Send), WispError> where R: ws::WebSocketRead + Send, diff --git a/wisp/src/packet.rs b/wisp/src/packet.rs index 41554a3..ad713f1 100644 --- a/wisp/src/packet.rs +++ b/wisp/src/packet.rs @@ -380,7 +380,7 @@ impl Packet { pub(crate) fn maybe_parse_info( frame: Frame, role: Role, - extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)], + extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], ) -> Result { if !frame.finished { return Err(WispError::WsFrameNotFinished); @@ -431,7 +431,7 @@ impl Packet { fn parse_info( mut bytes: Bytes, role: Role, - extension_builders: &[&(dyn ProtocolExtensionBuilder + Sync)], + extension_builders: &[Box<(dyn ProtocolExtensionBuilder + Send + Sync)>], ) -> Result { // packet type is already read by code that calls this if bytes.remaining() < 4 + 2 { From f2021e23829f5dc08c2a8cbce59350318f9758e9 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 14 Apr 2024 14:59:15 -0700 Subject: [PATCH 10/14] improve performance --- server/src/main.rs | 6 ++++-- wisp/src/fastwebsockets.rs | 26 ++++++++++++-------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index a4d269a..a529872 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -309,11 +309,13 @@ async fn accept_ws( let rx = FragmentCollectorRead::new(rx); println!("{:?}: connected", addr); + // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer + // size is set to 32 let (mut mux, fut) = if mux_options.enforce_auth { let (mut mux, fut) = ServerMux::new( rx, tx, - u32::MAX, + 32, Some(mux_options.auth.as_slice()), ) .await?; @@ -331,7 +333,7 @@ async fn accept_ws( } (mux, fut) } else { - ServerMux::new(rx, tx, u32::MAX, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await? + ServerMux::new(rx, tx, 32, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await? }; println!( diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index 548649f..cd0199c 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -1,5 +1,7 @@ +use std::ops::Deref; + use async_trait::async_trait; -use bytes::Bytes; +use bytes::BytesMut; use fastwebsockets::{ FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, }; @@ -24,29 +26,25 @@ impl From for crate::ws::OpCode { } impl From> for crate::ws::Frame { - fn from(mut frame: Frame) -> Self { + fn from(frame: Frame) -> Self { Self { finished: frame.fin, opcode: frame.opcode.into(), - payload: Bytes::copy_from_slice(frame.payload.to_mut()), + payload: BytesMut::from(frame.payload.deref()).freeze(), } } } -impl From for Frame<'_> { +impl<'a> From for Frame<'a> { fn from(frame: crate::ws::Frame) -> Self { use crate::ws::OpCode::*; + let payload = Payload::Owned(frame.payload.into()); match frame.opcode { - Text => Self::text(Payload::Owned(frame.payload.to_vec())), - Binary => Self::binary(Payload::Owned(frame.payload.to_vec())), - Close => Self::close_raw(Payload::Owned(frame.payload.to_vec())), - Ping => Self::new( - true, - OpCode::Ping, - None, - Payload::Owned(frame.payload.to_vec()), - ), - Pong => Self::pong(Payload::Owned(frame.payload.to_vec())), + Text => Self::text(payload), + Binary => Self::binary(payload), + Close => Self::close_raw(payload), + Ping => Self::new(true, OpCode::Ping, None, payload), + Pong => Self::pong(payload), } } } From 5af56fe58261c3fd07f0e38d1e0fce416ebefb41 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Sun, 14 Apr 2024 17:59:24 -0700 Subject: [PATCH 11/14] force a bounded channel --- server/src/main.rs | 6 ++--- simple-wisp-client/src/main.rs | 10 +++++++- wisp/src/lib.rs | 43 +++++++++++++++++----------------- wisp/src/stream.rs | 4 ++-- 4 files changed, 35 insertions(+), 28 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index a529872..7e0e581 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -310,12 +310,12 @@ async fn accept_ws( println!("{:?}: connected", addr); // to prevent memory ""leaks"" because users are sending in packets way too fast the buffer - // size is set to 32 + // size is set to 128 let (mut mux, fut) = if mux_options.enforce_auth { let (mut mux, fut) = ServerMux::new( rx, tx, - 32, + 128, Some(mux_options.auth.as_slice()), ) .await?; @@ -333,7 +333,7 @@ async fn accept_ws( } (mux, fut) } else { - ServerMux::new(rx, tx, 32, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await? + ServerMux::new(rx, tx, 128, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await? }; println!( diff --git a/simple-wisp-client/src/main.rs b/simple-wisp-client/src/main.rs index 97a7b15..198d915 100644 --- a/simple-wisp-client/src/main.rs +++ b/simple-wisp-client/src/main.rs @@ -92,6 +92,9 @@ struct Cli { /// Usernames and passwords are sent in plaintext!! #[arg(long)] auth: Option, + /// Make a Wisp V1 connection + #[arg(long)] + wisp_v1: bool, } #[tokio::main(flavor = "multi_thread")] @@ -161,7 +164,12 @@ async fn main() -> Result<(), Box> { extensions.push(Box::new(auth)); } - let (mut mux, fut) = ClientMux::new(rx, tx, Some(extensions.as_slice())).await?; + let (mut mux, fut) = if opts.wisp_v1 { + ClientMux::new(rx, tx, None).await? + } else { + ClientMux::new(rx, tx, Some(extensions.as_slice())).await? + }; + if opts.udp && !mux .supported_extension_ids diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index ff732af..1454145 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -20,8 +20,7 @@ use dashmap::DashMap; use event_listener::Event; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; use futures::{ - channel::{mpsc, oneshot}, - select, Future, FutureExt, SinkExt, StreamExt, + channel::{mpsc, oneshot}, lock::Mutex, select, Future, FutureExt, SinkExt, StreamExt }; use futures_timer::Delay; use std::{ @@ -152,7 +151,7 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} struct MuxMapValue { - stream: mpsc::UnboundedSender, + stream: Mutex>, stream_type: StreamType, flow_control: Arc, flow_control_event: Arc, @@ -209,11 +208,11 @@ impl MuxInner { _ = self.stream_loop(close_rx, close_tx).fuse() => Ok(()), x = wisp_fut.fuse() => x, }; - self.stream_map.iter_mut().for_each(|mut x| { + for x in self.stream_map.iter_mut() { x.is_closed.store(true, Ordering::Release); - x.stream.disconnect(); - x.stream.close_channel(); - }); + x.stream.lock().await.disconnect(); + x.stream.lock().await.close_channel(); + } self.stream_map.clear(); ret } @@ -235,7 +234,7 @@ impl MuxInner { } WsEvent::CreateStream(stream_type, host, port, channel) => { let ret: Result = async { - let (ch_tx, ch_rx) = mpsc::unbounded(); + let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize); let stream_id = next_free_stream_id; let next_stream_id = next_free_stream_id .checked_add(1) @@ -257,7 +256,7 @@ impl MuxInner { self.stream_map.insert( stream_id, MuxMapValue { - stream: ch_tx, + stream: ch_tx.into(), stream_type, flow_control: flow_control.clone(), flow_control_event: flow_control_event.clone(), @@ -281,9 +280,9 @@ impl MuxInner { let _ = channel.send(ret); } WsEvent::Close(packet, channel) => { - if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { - stream.stream.disconnect(); - stream.stream.close_channel(); + if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { + stream.stream.lock().await.disconnect(); + stream.stream.lock().await.close_channel(); let _ = channel.send(self.tx.write_frame(packet.into()).await); } else { let _ = channel.send(Err(WispError::InvalidStreamId)); @@ -326,7 +325,7 @@ impl MuxInner { use PacketType::*; match packet.packet_type { Connect(inner_packet) => { - let (ch_tx, ch_rx) = mpsc::unbounded(); + let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize); let stream_type = inner_packet.stream_type; let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); let flow_control_event: Arc = Event::new().into(); @@ -335,7 +334,7 @@ impl MuxInner { self.stream_map.insert( packet.stream_id, MuxMapValue { - stream: ch_tx, + stream: ch_tx.into(), stream_type, flow_control: flow_control.clone(), flow_control_event: flow_control_event.clone(), @@ -361,7 +360,7 @@ impl MuxInner { } Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.unbounded_send(data); + let _ = stream.stream.lock().await.send(data).await; if stream.stream_type == StreamType::Tcp { stream.flow_control.store( stream @@ -378,10 +377,10 @@ impl MuxInner { if packet.stream_id == 0 { break Ok(()); } - if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { + if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); - stream.stream.disconnect(); - stream.stream.close_channel(); + stream.stream.lock().await.disconnect(); + stream.stream.lock().await.close_channel(); } } } @@ -410,7 +409,7 @@ impl MuxInner { Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.unbounded_send(data); + let _ = stream.stream.lock().await.send(data).await; } } Continue(inner_packet) => { @@ -427,10 +426,10 @@ impl MuxInner { if packet.stream_id == 0 { break Ok(()); } - if let Some((_, mut stream)) = self.stream_map.remove(&packet.stream_id) { + if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); - stream.stream.disconnect(); - stream.stream.close_channel(); + stream.stream.lock().await.disconnect(); + stream.stream.lock().await.close_channel(); } } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 69c711b..1a8c2da 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -38,7 +38,7 @@ pub struct MuxStreamRead { pub stream_type: StreamType, role: Role, tx: mpsc::Sender, - rx: mpsc::UnboundedReceiver, + rx: mpsc::Receiver, is_closed: Arc, flow_control: Arc, flow_control_read: AtomicU32, @@ -193,7 +193,7 @@ impl MuxStream { stream_id: u32, role: Role, stream_type: StreamType, - rx: mpsc::UnboundedReceiver, + rx: mpsc::Receiver, tx: mpsc::Sender, is_closed: Arc, flow_control: Arc, From 5e741d380876b2aa1966b64d64a96a83d39f567b Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Mon, 15 Apr 2024 17:42:49 -0700 Subject: [PATCH 12/14] use blazingly fast flume channels :rocket: --- Cargo.lock | 25 +++++ client/demo.js | 4 +- client/src/utils.rs | 12 +- client/src/wrappers.rs | 11 +- server/src/main.rs | 39 ++++--- simple-wisp-client/src/main.rs | 8 +- wisp/Cargo.toml | 1 + wisp/src/fastwebsockets.rs | 6 +- wisp/src/lib.rs | 197 ++++++++++++++++++--------------- wisp/src/stream.rs | 47 +++++--- wisp/src/ws.rs | 10 +- 11 files changed, 225 insertions(+), 135 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8fa37cd..33b2300 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -861,6 +861,18 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "spin", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1487,6 +1499,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + [[package]] name = "native-tls" version = "0.2.11" @@ -2273,6 +2294,9 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] [[package]] name = "strsim" @@ -3203,6 +3227,7 @@ dependencies = [ "dashmap", "event-listener", "fastwebsockets 0.7.1", + "flume", "futures", "futures-timer", "futures-util", diff --git a/client/demo.js b/client/demo.js index a5b35a2..298e2f2 100644 --- a/client/demo.js +++ b/client/demo.js @@ -238,9 +238,9 @@ onmessage = async (msg) => { log(`total avg mux (${num_outer_tests} tests of ${num_inner_tests} reqs): ${total_mux_multi} ms or ${total_mux_multi / 1000} s`); } else { - let resp = await epoxy_client.fetch("https://httpbin.org/get"); + let resp = await epoxy_client.fetch("https://www.example.com/"); console.log(resp, Object.fromEntries(resp.headers)); - plog(await resp.json()); + log(await resp.text()); } log("done"); }; diff --git a/client/src/utils.rs b/client/src/utils.rs index 3f6dc9b..98717a5 100644 --- a/client/src/utils.rs +++ b/client/src/utils.rs @@ -200,13 +200,10 @@ pub async fn make_mux( ), WispError, > { - let (wtx, wrx) = WebSocketWrapper::connect(url, vec![]) - .await - .map_err(|_| WispError::WsImplSocketClosed)?; + let (wtx, wrx) = + WebSocketWrapper::connect(url, vec![]).map_err(|_| WispError::WsImplSocketClosed)?; wtx.wait_for_open().await; - let mux = ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await?; - - Ok(mux) + ClientMux::new(wrx, wtx, Some(&[Box::new(UdpProtocolExtensionBuilder())])).await } pub fn spawn_mux_fut( @@ -215,6 +212,7 @@ pub fn spawn_mux_fut( url: String, ) { wasm_bindgen_futures::spawn_local(async move { + debug!("epoxy: mux future started"); if let Err(e) = fut.await { log!("epoxy: error in mux future, restarting: {:?}", e); while let Err(e) = replace_mux(mux.clone(), &url).await { @@ -229,7 +227,7 @@ pub fn spawn_mux_fut( pub async fn replace_mux(mux: Arc>, url: &str) -> Result<(), WispError> { let (mux_replace, fut) = make_mux(url).await?; let mut mux_write = mux.write().await; - mux_write.close().await?; + let _ = mux_write.close().await; *mux_write = mux_replace; drop(mux_write); spawn_mux_fut(mux, fut, url.into()); diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index e67779e..5746ac2 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -123,6 +123,7 @@ impl tower_service::Service for TlsWispService { let stream = service.call(uri_parsed).await?.into_inner(); if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? { let connector = TlsConnector::from(rustls_config); + log!("got stream"); Ok(TokioIo::new(Either::Left( connector .connect( @@ -143,6 +144,7 @@ impl tower_service::Service for TlsWispService { pub enum WebSocketError { Unknown, SendFailed, + CloseFailed, } impl std::fmt::Display for WebSocketError { @@ -151,6 +153,7 @@ impl std::fmt::Display for WebSocketError { match self { Unknown => write!(f, "Unknown error"), SendFailed => write!(f, "Send failed"), + CloseFailed => write!(f, "Close failed"), } } } @@ -213,7 +216,7 @@ impl WebSocketRead for WebSocketReader { } impl WebSocketWrapper { - pub async fn connect( + pub fn connect( url: &str, protocols: Vec, ) -> Result<(Self, WebSocketReader), JsValue> { @@ -327,6 +330,12 @@ impl WebSocketWrite for WebSocketWrapper { _ => Err(WispError::WsImplNotSupported), } } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + self.inner + .close() + .map_err(|_| WebSocketError::CloseFailed.into()) + } } impl Drop for WebSocketWrapper { diff --git a/server/src/main.rs b/server/src/main.rs index 7e0e581..a6c8b8c 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -12,9 +12,13 @@ use hyper::{ body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; -use tokio::net::{lookup_host, TcpListener, TcpStream, UdpSocket}; #[cfg(unix)] use tokio::net::{UnixListener, UnixStream}; +use tokio::{ + io::{copy_bidirectional, split, BufReader, BufWriter}, + net::{lookup_host, TcpListener, TcpStream, UdpSocket}, + select, +}; use tokio_util::codec::{BytesCodec, Framed}; #[cfg(unix)] use tokio_util::either::Either; @@ -22,9 +26,10 @@ use tokio_util::either::Either; use wisp_mux::{ extensions::{ password::{PasswordProtocolExtension, PasswordProtocolExtensionBuilder}, - udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder, + udp::UdpProtocolExtensionBuilder, + ProtocolExtensionBuilder, }, - CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, + CloseReason, ConnectPacket, IoStream, MuxStream, MuxStreamIo, ServerMux, StreamType, WispError, }; type HttpBody = http_body_util::Full; @@ -182,7 +187,10 @@ async fn main() -> Result<(), Error> { block_local: opt.block_local, block_non_http: opt.block_non_http, block_udp: opt.block_udp, - auth: Arc::new(vec![Box::new(UdpProtocolExtensionBuilder()), Box::new(pw_ext)]), + auth: Arc::new(vec![ + Box::new(UdpProtocolExtensionBuilder()), + Box::new(pw_ext), + ]), enforce_auth, }; @@ -257,7 +265,7 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result Result<(), Box> { avg.get_average() * opts.packet_size, ); if is_term { - print!("\x1b[2K{}\r", stat); + println!("\x1b[1A\x1b[2K{}\r", stat); } else { println!("{}", stat); } @@ -284,6 +284,8 @@ async fn main() -> Result<(), Box> { let out = select_all(threads.into_iter()).await; + let duration_since = Instant::now().duration_since(start_time); + if let Err(err) = out.0? { println!("\n\nerr: {:?}", err); exit(1); @@ -291,10 +293,10 @@ async fn main() -> Result<(), Box> { out.2.into_iter().for_each(|x| x.abort()); - let duration_since = Instant::now().duration_since(start_time); + mux.close().await?; println!( - "\n\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", + "\nresults: {} packets of &[0; 1024 * {}] ({} KiB) sent in {} ({} KiB/s)", cnt.get(), opts.packet_size, cnt.get() * opts.packet_size, diff --git a/wisp/Cargo.toml b/wisp/Cargo.toml index 8cf2cba..795a1c6 100644 --- a/wisp/Cargo.toml +++ b/wisp/Cargo.toml @@ -15,6 +15,7 @@ bytes = "1.5.0" dashmap = { version = "5.5.3", features = ["inline"] } event-listener = "5.0.0" fastwebsockets = { version = "0.7.1", features = ["unstable-split"], optional = true } +flume = "0.11.0" futures = "0.3.30" futures-timer = "3.0.3" futures-util = "0.3.30" diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index cd0199c..a2c2f7d 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use async_trait::async_trait; use bytes::BytesMut; use fastwebsockets::{ - FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, + CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite }; use tokio::io::{AsyncRead, AsyncWrite}; @@ -77,4 +77,8 @@ impl crate::ws::WebSocketWrite for WebSocketWrite< async fn wisp_write_frame(&mut self, frame: crate::ws::Frame) -> Result<(), WispError> { self.write_frame(frame.into()).await.map_err(|e| e.into()) } + + async fn wisp_close(&mut self) -> Result<(), WispError> { + self.write_frame(Frame::close(CloseCode::Normal.into(), b"")).await.map_err(|e| e.into()) + } } diff --git a/wisp/src/lib.rs b/wisp/src/lib.rs index 1454145..d68edf0 100644 --- a/wisp/src/lib.rs +++ b/wisp/src/lib.rs @@ -1,4 +1,4 @@ -#![deny(missing_docs)] +#![deny(missing_docs, warnings)] #![cfg_attr(docsrs, feature(doc_cfg))] //! A library for easily creating [Wisp] clients and servers. //! @@ -19,9 +19,8 @@ use bytes::Bytes; use dashmap::DashMap; use event_listener::Event; use extensions::{udp::UdpProtocolExtension, AnyProtocolExtension, ProtocolExtensionBuilder}; -use futures::{ - channel::{mpsc, oneshot}, lock::Mutex, select, Future, FutureExt, SinkExt, StreamExt -}; +use flume as mpsc; +use futures::{channel::oneshot, select, Future, FutureExt}; use futures_timer::Delay; use std::{ sync::{ @@ -151,11 +150,12 @@ impl std::fmt::Display for WispError { impl std::error::Error for WispError {} struct MuxMapValue { - stream: Mutex>, + stream: mpsc::Sender, stream_type: StreamType, flow_control: Arc, flow_control_event: Arc, is_closed: Arc, + is_closed_event: Arc, } struct MuxInner { @@ -170,7 +170,7 @@ impl MuxInner { rx: R, extensions: Vec, close_rx: mpsc::Receiver, - muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, + muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>, close_tx: mpsc::Sender, ) -> Result<(), WispError> where @@ -210,20 +210,60 @@ impl MuxInner { }; for x in self.stream_map.iter_mut() { x.is_closed.store(true, Ordering::Release); - x.stream.lock().await.disconnect(); - x.stream.lock().await.close_channel(); + x.is_closed_event.notify(usize::MAX); } self.stream_map.clear(); + let _ = self.tx.close().await; ret } + async fn create_new_stream( + &self, + stream_id: u32, + stream_type: StreamType, + role: Role, + stream_tx: mpsc::Sender, + target_buffer_size: u32, + ) -> Result<(MuxMapValue, MuxStream), WispError> { + let (ch_tx, ch_rx) = mpsc::bounded(self.buffer_size as usize); + + let flow_control_event: Arc = Event::new().into(); + let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); + + let is_closed: Arc = AtomicBool::new(false).into(); + let is_closed_event: Arc = Event::new().into(); + + Ok(( + MuxMapValue { + stream: ch_tx, + stream_type, + flow_control: flow_control.clone(), + flow_control_event: flow_control_event.clone(), + is_closed: is_closed.clone(), + is_closed_event: is_closed_event.clone(), + }, + MuxStream::new( + stream_id, + role, + stream_type, + ch_rx, + stream_tx.clone(), + is_closed, + is_closed_event, + flow_control, + flow_control_event, + target_buffer_size, + ), + )) + } + async fn stream_loop( &self, - mut stream_rx: mpsc::Receiver, + stream_rx: mpsc::Receiver, stream_tx: mpsc::Sender, ) { let mut next_free_stream_id: u32 = 1; - while let Some(msg) = stream_rx.next().await { + while let Ok(msg) = stream_rx.recv_async().await { match msg { WsEvent::SendPacket(packet, channel) => { if self.stream_map.get(&packet.stream_id).is_some() { @@ -234,16 +274,20 @@ impl MuxInner { } WsEvent::CreateStream(stream_type, host, port, channel) => { let ret: Result = async { - let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize); let stream_id = next_free_stream_id; let next_stream_id = next_free_stream_id .checked_add(1) .ok_or(WispError::MaxStreamCountReached)?; - let flow_control_event: Arc = Event::new().into(); - let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); - - let is_closed: Arc = AtomicBool::new(false).into(); + let (map_value, stream) = self + .create_new_stream( + stream_id, + stream_type, + Role::Client, + stream_tx.clone(), + 0, + ) + .await?; self.tx .write_frame( @@ -251,39 +295,19 @@ impl MuxInner { ) .await?; + self.stream_map.insert(stream_id, map_value); + next_free_stream_id = next_stream_id; - self.stream_map.insert( - stream_id, - MuxMapValue { - stream: ch_tx.into(), - stream_type, - flow_control: flow_control.clone(), - flow_control_event: flow_control_event.clone(), - is_closed: is_closed.clone(), - }, - ); - - Ok(MuxStream::new( - stream_id, - Role::Client, - stream_type, - ch_rx, - stream_tx.clone(), - is_closed, - flow_control, - flow_control_event, - 0, - )) + Ok(stream) } .await; let _ = channel.send(ret); } WsEvent::Close(packet, channel) => { if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { - stream.stream.lock().await.disconnect(); - stream.stream.lock().await.close_channel(); let _ = channel.send(self.tx.write_frame(packet.into()).await); + drop(stream.stream) } else { let _ = channel.send(Err(WispError::InvalidStreamId)); } @@ -305,8 +329,8 @@ impl MuxInner { &self, mut rx: R, mut extensions: Vec, - muxstream_sender: mpsc::UnboundedSender<(ConnectPacket, MuxStream)>, - close_tx: mpsc::Sender, + muxstream_sender: mpsc::Sender<(ConnectPacket, MuxStream)>, + stream_tx: mpsc::Sender, ) -> Result<(), WispError> where R: ws::WebSocketRead + Send, @@ -325,42 +349,24 @@ impl MuxInner { use PacketType::*; match packet.packet_type { Connect(inner_packet) => { - let (ch_tx, ch_rx) = mpsc::channel(self.buffer_size as usize); - let stream_type = inner_packet.stream_type; - let flow_control: Arc = AtomicU32::new(self.buffer_size).into(); - let flow_control_event: Arc = Event::new().into(); - let is_closed: Arc = AtomicBool::new(false).into(); - - self.stream_map.insert( - packet.stream_id, - MuxMapValue { - stream: ch_tx.into(), - stream_type, - flow_control: flow_control.clone(), - flow_control_event: flow_control_event.clone(), - is_closed: is_closed.clone(), - }, - ); + let (map_value, stream) = self + .create_new_stream( + packet.stream_id, + inner_packet.stream_type, + Role::Server, + stream_tx.clone(), + target_buffer_size, + ) + .await?; muxstream_sender - .unbounded_send(( - inner_packet, - MuxStream::new( - packet.stream_id, - Role::Server, - stream_type, - ch_rx, - close_tx.clone(), - is_closed, - flow_control, - flow_control_event, - target_buffer_size, - ), - )) - .map_err(|x| WispError::Other(Box::new(x)))?; + .send_async((inner_packet, stream)) + .await + .map_err(|_| WispError::MuxMessageFailedToSend)?; + self.stream_map.insert(packet.stream_id, map_value); } Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.lock().await.send(data).await; + let _ = stream.stream.send_async(data).await; if stream.stream_type == StreamType::Tcp { stream.flow_control.store( stream @@ -379,8 +385,8 @@ impl MuxInner { } if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); - stream.stream.lock().await.disconnect(); - stream.stream.lock().await.close_channel(); + stream.is_closed_event.notify(usize::MAX); + drop(stream.stream) } } } @@ -409,7 +415,7 @@ impl MuxInner { Connect(_) | Info(_) => break Err(WispError::InvalidPacketType), Data(data) => { if let Some(stream) = self.stream_map.get(&packet.stream_id) { - let _ = stream.stream.lock().await.send(data).await; + let _ = stream.stream.send_async(data).await; } } Continue(inner_packet) => { @@ -428,8 +434,8 @@ impl MuxInner { } if let Some((_, stream)) = self.stream_map.remove(&packet.stream_id) { stream.is_closed.store(true, Ordering::Release); - stream.stream.lock().await.disconnect(); - stream.stream.lock().await.close_channel(); + stream.is_closed_event.notify(usize::MAX); + drop(stream.stream) } } } @@ -465,7 +471,7 @@ pub struct ServerMux { /// Extensions that are supported by both sides. pub supported_extension_ids: Vec, close_tx: mpsc::Sender, - muxstream_recv: mpsc::UnboundedReceiver<(ConnectPacket, MuxStream)>, + muxstream_recv: mpsc::Receiver<(ConnectPacket, MuxStream)>, } impl ServerMux { @@ -484,7 +490,7 @@ impl ServerMux { R: ws::WebSocketRead + Send, W: ws::WebSocketWrite + Send + 'static, { - let (close_tx, close_rx) = mpsc::channel::(256); + let (close_tx, close_rx) = mpsc::bounded::(256); let (tx, rx) = mpsc::unbounded::<(ConnectPacket, MuxStream)>(); let write = ws::LockedWebSocketWrite::new(Box::new(write)); @@ -547,12 +553,12 @@ impl ServerMux { /// Wait for a stream to be created. pub async fn server_new_stream(&mut self) -> Option<(ConnectPacket, MuxStream)> { - self.muxstream_recv.next().await + self.muxstream_recv.recv_async().await.ok() } async fn close_internal(&mut self, reason: Option) -> Result<(), WispError> { self.close_tx - .send(WsEvent::EndFut(reason)) + .send_async(WsEvent::EndFut(reason)) .await .map_err(|_| WispError::MuxMessageFailedToSend) } @@ -574,6 +580,13 @@ impl ServerMux { .await } } + +impl Drop for ServerMux { + fn drop(&mut self) { + let _ = self.close_tx.send(WsEvent::EndFut(None)); + } +} + /// Client side multiplexor. /// /// # Example @@ -595,7 +608,7 @@ pub struct ClientMux { pub downgraded: bool, /// Extensions that are supported by both sides. pub supported_extension_ids: Vec, - close_tx: mpsc::Sender, + stream_tx: mpsc::Sender, } impl ClientMux { @@ -654,10 +667,10 @@ impl ClientMux { extension.handle_handshake(&mut read, &write).await?; } - let (tx, rx) = mpsc::channel::(256); + let (tx, rx) = mpsc::bounded::(256); Ok(( Self { - close_tx: tx.clone(), + stream_tx: tx.clone(), downgraded, supported_extension_ids: supported_extensions .iter() @@ -697,16 +710,16 @@ impl ClientMux { return Err(WispError::UdpExtensionNotSupported); } let (tx, rx) = oneshot::channel(); - self.close_tx - .send(WsEvent::CreateStream(stream_type, host, port, tx)) + self.stream_tx + .send_async(WsEvent::CreateStream(stream_type, host, port, tx)) .await .map_err(|_| WispError::MuxMessageFailedToSend)?; rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)? } async fn close_internal(&mut self, reason: Option) -> Result<(), WispError> { - self.close_tx - .send(WsEvent::EndFut(reason)) + self.stream_tx + .send_async(WsEvent::EndFut(reason)) .await .map_err(|_| WispError::MuxMessageFailedToSend) } @@ -728,3 +741,9 @@ impl ClientMux { .await } } + +impl Drop for ClientMux { + fn drop(&mut self) { + let _ = self.stream_tx.send(WsEvent::EndFut(None)); + } +} diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 1a8c2da..8a074a7 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -1,13 +1,14 @@ use crate::{sink_unfold, CloseReason, Packet, Role, StreamType, WispError}; -use async_io_stream::IoStream; +pub use async_io_stream::IoStream; use bytes::Bytes; use event_listener::Event; +use flume as mpsc; use futures::{ - channel::{mpsc, oneshot}, - stream, + channel::oneshot, + select, stream, task::{Context, Poll}, - Sink, SinkExt, Stream, StreamExt, + FutureExt, Sink, Stream, }; use pin_project_lite::pin_project; use std::{ @@ -40,6 +41,7 @@ pub struct MuxStreamRead { tx: mpsc::Sender, rx: mpsc::Receiver, is_closed: Arc, + is_closed_event: Arc, flow_control: Arc, flow_control_read: AtomicU32, target_flow_control: u32, @@ -51,13 +53,16 @@ impl MuxStreamRead { if self.is_closed.load(Ordering::Acquire) { return None; } - let bytes = self.rx.next().await?; + let bytes = select! { + x = self.rx.recv_async() => x.ok()?, + _ = self.is_closed_event.listen().fuse() => return None + }; if self.role == Role::Server && self.stream_type == StreamType::Tcp { let val = self.flow_control_read.fetch_add(1, Ordering::AcqRel) + 1; if val > self.target_flow_control { let (tx, rx) = oneshot::channel::>(); self.tx - .send(WsEvent::SendPacket( + .send_async(WsEvent::SendPacket( Packet::new_continue( self.stream_id, self.flow_control.fetch_add(val, Ordering::AcqRel) + val, @@ -107,13 +112,13 @@ impl MuxStreamWrite { } let (tx, rx) = oneshot::channel::>(); self.tx - .send(WsEvent::SendPacket( + .send_async(WsEvent::SendPacket( Packet::new_data(self.stream_id, data), tx, )) .await - .map_err(|x| WispError::Other(Box::new(x)))?; - rx.await.map_err(|x| WispError::Other(Box::new(x)))??; + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; if self.role == Role::Client && self.stream_type == StreamType::Tcp { self.flow_control.store( self.flow_control.load(Ordering::Acquire).saturating_sub(1), @@ -151,13 +156,13 @@ impl MuxStreamWrite { let (tx, rx) = oneshot::channel::>(); self.tx - .send(WsEvent::Close( + .send_async(WsEvent::Close( Packet::new_close(self.stream_id, reason), tx, )) .await - .map_err(|x| WispError::Other(Box::new(x)))?; - rx.await.map_err(|x| WispError::Other(Box::new(x)))??; + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; Ok(()) } @@ -179,6 +184,16 @@ impl MuxStreamWrite { } } +impl Drop for MuxStreamWrite { + fn drop(&mut self) { + if !self.is_closed.load(Ordering::Acquire) { + self.is_closed.store(true, Ordering::Release); + let (tx, _) = oneshot::channel(); + let _ = self.tx.send(WsEvent::Close(Packet::new_close(self.stream_id, CloseReason::Unknown), tx)); + } + } +} + /// Multiplexor stream. pub struct MuxStream { /// ID of the stream. @@ -196,6 +211,7 @@ impl MuxStream { rx: mpsc::Receiver, tx: mpsc::Sender, is_closed: Arc, + is_closed_event: Arc, flow_control: Arc, continue_recieved: Arc, target_flow_control: u32, @@ -209,6 +225,7 @@ impl MuxStream { tx: tx.clone(), rx, is_closed: is_closed.clone(), + is_closed_event: is_closed_event.clone(), flow_control: flow_control.clone(), flow_control_read: AtomicU32::new(0), target_flow_control, @@ -288,13 +305,13 @@ impl MuxStreamCloser { let (tx, rx) = oneshot::channel::>(); self.close_channel - .send(WsEvent::Close( + .send_async(WsEvent::Close( Packet::new_close(self.stream_id, reason), tx, )) .await - .map_err(|x| WispError::Other(Box::new(x)))?; - rx.await.map_err(|x| WispError::Other(Box::new(x)))??; + .map_err(|_| WispError::MuxMessageFailedToSend)?; + rx.await.map_err(|_| WispError::MuxMessageFailedToRecv)??; Ok(()) } diff --git a/wisp/src/ws.rs b/wisp/src/ws.rs index 7348bb8..258a5d1 100644 --- a/wisp/src/ws.rs +++ b/wisp/src/ws.rs @@ -76,6 +76,9 @@ pub trait WebSocketRead { pub trait WebSocketWrite { /// Write a frame to the socket. async fn wisp_write_frame(&mut self, frame: Frame) -> Result<(), WispError>; + + /// Close the socket. + async fn wisp_close(&mut self) -> Result<(), WispError>; } /// Locked WebSocket. @@ -88,9 +91,14 @@ impl LockedWebSocketWrite { } /// Write a frame to the websocket. - pub async fn write_frame(&self, frame: Frame) -> Result<(), crate::WispError> { + pub async fn write_frame(&self, frame: Frame) -> Result<(), WispError> { self.0.lock().await.wisp_write_frame(frame).await } + + /// Close the websocket. + pub async fn close(&self) -> Result<(), WispError> { + self.0.lock().await.wisp_close().await + } } pub(crate) struct AppendingWebSocketRead(pub Vec, pub R) From 359bf7107a7c7ef05baa25db4d5bb37778702e27 Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Mon, 15 Apr 2024 17:45:46 -0700 Subject: [PATCH 13/14] remove unused imports --- server/src/main.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index a6c8b8c..623ff09 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -15,9 +15,8 @@ use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::{UnixListener, UnixStream}; use tokio::{ - io::{copy_bidirectional, split, BufReader, BufWriter}, + io::copy_bidirectional, net::{lookup_host, TcpListener, TcpStream, UdpSocket}, - select, }; use tokio_util::codec::{BytesCodec, Framed}; #[cfg(unix)] @@ -29,7 +28,7 @@ use wisp_mux::{ udp::UdpProtocolExtensionBuilder, ProtocolExtensionBuilder, }, - CloseReason, ConnectPacket, IoStream, MuxStream, MuxStreamIo, ServerMux, StreamType, WispError, + CloseReason, ConnectPacket, MuxStream, ServerMux, StreamType, WispError, }; type HttpBody = http_body_util::Full; From 4fd9e02879283c98727751bf4ab9bfefe9806fcb Mon Sep 17 00:00:00 2001 From: Toshit Chawda Date: Mon, 15 Apr 2024 22:58:31 -0700 Subject: [PATCH 14/14] separate out protocol extensions into new files; `cargo fmt` --- certs-grabber/src/main.rs | 2 +- client/src/wrappers.rs | 6 +- wisp/src/extensions.rs | 481 -------------------------------- wisp/src/extensions/mod.rs | 106 +++++++ wisp/src/extensions/password.rs | 276 ++++++++++++++++++ wisp/src/extensions/udp.rs | 93 ++++++ wisp/src/fastwebsockets.rs | 6 +- wisp/src/stream.rs | 5 +- 8 files changed, 485 insertions(+), 490 deletions(-) delete mode 100644 wisp/src/extensions.rs create mode 100644 wisp/src/extensions/mod.rs create mode 100644 wisp/src/extensions/password.rs create mode 100644 wisp/src/extensions/udp.rs diff --git a/certs-grabber/src/main.rs b/certs-grabber/src/main.rs index 11574b8..c6dacf8 100644 --- a/certs-grabber/src/main.rs +++ b/certs-grabber/src/main.rs @@ -60,5 +60,5 @@ async fn main() { } code.pop(); code.push_str("];"); - println!("{}",code); + println!("{}", code); } diff --git a/client/src/wrappers.rs b/client/src/wrappers.rs index 5746ac2..47ff2c3 100644 --- a/client/src/wrappers.rs +++ b/client/src/wrappers.rs @@ -123,7 +123,6 @@ impl tower_service::Service for TlsWispService { let stream = service.call(uri_parsed).await?.into_inner(); if utils::get_is_secure(&req).map_err(|_| WispError::InvalidUri)? { let connector = TlsConnector::from(rustls_config); - log!("got stream"); Ok(TokioIo::new(Either::Left( connector .connect( @@ -216,10 +215,7 @@ impl WebSocketRead for WebSocketReader { } impl WebSocketWrapper { - pub fn connect( - url: &str, - protocols: Vec, - ) -> Result<(Self, WebSocketReader), JsValue> { + pub fn connect(url: &str, protocols: Vec) -> Result<(Self, WebSocketReader), JsValue> { let (read_tx, read_rx) = mpsc::unbounded_channel(); let closed = Arc::new(AtomicBool::new(false)); diff --git a/wisp/src/extensions.rs b/wisp/src/extensions.rs deleted file mode 100644 index de472bd..0000000 --- a/wisp/src/extensions.rs +++ /dev/null @@ -1,481 +0,0 @@ -//! Wisp protocol extensions. - -use std::ops::{Deref, DerefMut}; - -use async_trait::async_trait; -use bytes::{BufMut, Bytes, BytesMut}; - -use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, - Role, WispError, -}; - -/// Type-erased protocol extension that implements Clone. -#[derive(Debug)] -pub struct AnyProtocolExtension(Box); - -impl AnyProtocolExtension { - /// Create a new type-erased protocol extension. - pub fn new(extension: T) -> Self { - Self(Box::new(extension)) - } -} - -impl Deref for AnyProtocolExtension { - type Target = dyn ProtocolExtension; - fn deref(&self) -> &Self::Target { - self.0.deref() - } -} - -impl DerefMut for AnyProtocolExtension { - fn deref_mut(&mut self) -> &mut Self::Target { - self.0.deref_mut() - } -} - -impl Clone for AnyProtocolExtension { - fn clone(&self) -> Self { - Self(self.0.box_clone()) - } -} - -impl From for Bytes { - fn from(value: AnyProtocolExtension) -> Self { - let mut bytes = BytesMut::with_capacity(5); - let payload = value.encode(); - bytes.put_u8(value.get_id()); - bytes.put_u32_le(payload.len() as u32); - bytes.extend(payload); - bytes.freeze() - } -} - -/// A Wisp protocol extension. -/// -/// See [the -/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions). -#[async_trait] -pub trait ProtocolExtension: std::fmt::Debug { - /// Get the protocol extension ID. - fn get_id(&self) -> u8; - /// Get the protocol extension's supported packets. - /// - /// Used to decide whether to call the protocol extension's packet handler. - fn get_supported_packets(&self) -> &'static [u8]; - - /// Encode self into Bytes. - fn encode(&self) -> Bytes; - - /// Handle the handshake part of a Wisp connection. - /// - /// This should be used to send or receive data before any streams are created. - async fn handle_handshake( - &mut self, - read: &mut dyn WebSocketRead, - write: &LockedWebSocketWrite, - ) -> Result<(), WispError>; - - /// Handle receiving a packet. - async fn handle_packet( - &mut self, - packet: Bytes, - read: &mut dyn WebSocketRead, - write: &LockedWebSocketWrite, - ) -> Result<(), WispError>; - - /// Clone the protocol extension. - fn box_clone(&self) -> Box; -} - -/// Trait to build a Wisp protocol extension from a payload. -pub trait ProtocolExtensionBuilder { - /// Get the protocol extension ID. - /// - /// Used to decide whether this builder should be used. - fn get_id(&self) -> u8; - - /// Build a protocol extension from the extension's metadata. - fn build_from_bytes(&self, bytes: Bytes, role: Role) - -> Result; - - /// Build a protocol extension to send to the other side. - fn build_to_extension(&self, role: Role) -> AnyProtocolExtension; -} - -pub mod udp { - //! UDP protocol extension. - //! - //! # Example - //! ``` - //! let (mux, fut) = ServerMux::new( - //! rx, - //! tx, - //! 128, - //! Some(&[Box::new(UdpProtocolExtensionBuilder())]) - //! ); - //! ``` - //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) - use async_trait::async_trait; - use bytes::Bytes; - - use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, - WispError, - }; - - use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; - - #[derive(Debug)] - /// UDP protocol extension. - pub struct UdpProtocolExtension(); - - impl UdpProtocolExtension { - /// UDP protocol extension ID. - pub const ID: u8 = 0x01; - } - - #[async_trait] - impl ProtocolExtension for UdpProtocolExtension { - fn get_id(&self) -> u8 { - Self::ID - } - - fn get_supported_packets(&self) -> &'static [u8] { - &[] - } - - fn encode(&self) -> Bytes { - Bytes::new() - } - - async fn handle_handshake( - &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - async fn handle_packet( - &mut self, - _: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - fn box_clone(&self) -> Box { - Box::new(Self()) - } - } - - impl From for AnyProtocolExtension { - fn from(value: UdpProtocolExtension) -> Self { - AnyProtocolExtension(Box::new(value)) - } - } - - /// UDP protocol extension builder. - pub struct UdpProtocolExtensionBuilder(); - - impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { - fn get_id(&self) -> u8 { - UdpProtocolExtension::ID - } - - fn build_from_bytes( - &self, - _: Bytes, - _: crate::Role, - ) -> Result { - Ok(UdpProtocolExtension().into()) - } - - fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { - UdpProtocolExtension().into() - } - } -} - -pub mod password { - //! Password protocol extension. - //! - //! Passwords are sent in plain text!! - //! - //! # Example - //! Server: - //! ``` - //! let mut passwords = HashMap::new(); - //! passwords.insert("user1".to_string(), "pw".to_string()); - //! let (mux, fut) = ServerMux::new( - //! rx, - //! tx, - //! 128, - //! Some(&[Box::new(PasswordProtocolExtensionBuilder::new_server(passwords))]) - //! ); - //! ``` - //! - //! Client: - //! ``` - //! let (mux, fut) = ClientMux::new( - //! rx, - //! tx, - //! 128, - //! Some(&[ - //! Box::new(PasswordProtocolExtensionBuilder::new_client( - //! "user1".to_string(), - //! "pw".to_string() - //! )) - //! ]) - //! ); - //! ``` - //! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication) - - use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; - - use async_trait::async_trait; - use bytes::{Buf, BufMut, Bytes, BytesMut}; - - use crate::{ - ws::{LockedWebSocketWrite, WebSocketRead}, - Role, WispError, - }; - - use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; - - #[derive(Debug, Clone)] - /// Password protocol extension. - /// - /// **Passwords are sent in plain text!!** - /// **This extension will panic when encoding if the username's length does not fit within a u8 - /// or the password's length does not fit within a u16.** - pub struct PasswordProtocolExtension { - /// The username to log in with. - /// - /// This string's length must fit within a u8. - pub username: String, - /// The password to log in with. - /// - /// This string's length must fit within a u16. - pub password: String, - role: Role, - } - - impl PasswordProtocolExtension { - /// Password protocol extension ID. - pub const ID: u8 = 0x02; - - /// Create a new password protocol extension for the server. - /// - /// This signifies that the server requires a password. - pub fn new_server() -> Self { - Self { - username: String::new(), - password: String::new(), - role: Role::Server, - } - } - - /// Create a new password protocol extension for the client, with a username and password. - /// - /// The username's length must fit within a u8. The password's length must fit within a - /// u16. - pub fn new_client(username: String, password: String) -> Self { - Self { - username, - password, - role: Role::Client, - } - } - } - - #[async_trait] - impl ProtocolExtension for PasswordProtocolExtension { - fn get_id(&self) -> u8 { - Self::ID - } - - fn get_supported_packets(&self) -> &'static [u8] { - &[] - } - - fn encode(&self) -> Bytes { - match self.role { - Role::Server => Bytes::new(), - Role::Client => { - let username = Bytes::from(self.username.clone().into_bytes()); - let password = Bytes::from(self.password.clone().into_bytes()); - let username_len = u8::try_from(username.len()).expect("username was too long"); - let password_len = - u16::try_from(password.len()).expect("password was too long"); - - let mut bytes = - BytesMut::with_capacity(3 + username_len as usize + password_len as usize); - bytes.put_u8(username_len); - bytes.put_u16_le(password_len); - bytes.extend(username); - bytes.extend(password); - bytes.freeze() - } - } - } - - async fn handle_handshake( - &mut self, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - async fn handle_packet( - &mut self, - _: Bytes, - _: &mut dyn WebSocketRead, - _: &LockedWebSocketWrite, - ) -> Result<(), WispError> { - Ok(()) - } - - fn box_clone(&self) -> Box { - Box::new(self.clone()) - } - } - - #[derive(Debug)] - enum PasswordProtocolExtensionError { - Utf8Error(FromUtf8Error), - InvalidUsername, - InvalidPassword, - } - - impl Display for PasswordProtocolExtensionError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use PasswordProtocolExtensionError as E; - match self { - E::Utf8Error(e) => write!(f, "{}", e), - E::InvalidUsername => write!(f, "Invalid username"), - E::InvalidPassword => write!(f, "Invalid password"), - } - } - } - - impl Error for PasswordProtocolExtensionError {} - - impl From for WispError { - fn from(value: PasswordProtocolExtensionError) -> Self { - WispError::ExtensionImplError(Box::new(value)) - } - } - - impl From for PasswordProtocolExtensionError { - fn from(value: FromUtf8Error) -> Self { - PasswordProtocolExtensionError::Utf8Error(value) - } - } - - impl From for AnyProtocolExtension { - fn from(value: PasswordProtocolExtension) -> Self { - AnyProtocolExtension(Box::new(value)) - } - } - - /// Password protocol extension builder. - /// - /// **Passwords are sent in plain text!!** - pub struct PasswordProtocolExtensionBuilder { - /// Map of users and their passwords to allow. Only used on server. - pub users: HashMap, - /// Username to authenticate with. Only used on client. - pub username: String, - /// Password to authenticate with. Only used on client. - pub password: String, - } - - impl PasswordProtocolExtensionBuilder { - /// Create a new password protocol extension builder for the server, with a map of users - /// and passwords to allow. - pub fn new_server(users: HashMap) -> Self { - Self { - users, - username: String::new(), - password: String::new(), - } - } - - /// Create a new password protocol extension builder for the client, with a username and - /// password to authenticate with. - pub fn new_client(username: String, password: String) -> Self { - Self { - users: HashMap::new(), - username, - password, - } - } - } - - impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { - fn get_id(&self) -> u8 { - PasswordProtocolExtension::ID - } - - fn build_from_bytes( - &self, - mut payload: Bytes, - role: crate::Role, - ) -> Result { - match role { - Role::Server => { - if payload.remaining() < 3 { - return Err(WispError::PacketTooSmall); - } - - let username_len = payload.get_u8(); - let password_len = payload.get_u16_le(); - if payload.remaining() < (password_len + username_len as u16) as usize { - return Err(WispError::PacketTooSmall); - } - - use PasswordProtocolExtensionError as EError; - let username = - String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) - .map_err(|x| WispError::from(EError::from(x)))?; - let password = - String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) - .map_err(|x| WispError::from(EError::from(x)))?; - - let Some(user) = self.users.iter().find(|x| *x.0 == username) else { - return Err(EError::InvalidUsername.into()); - }; - - if *user.1 != password { - return Err(EError::InvalidPassword.into()); - } - - Ok(PasswordProtocolExtension { - username, - password, - role, - } - .into()) - } - Role::Client => { - Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into()) - } - } - } - - fn build_to_extension(&self, role: Role) -> AnyProtocolExtension { - match role { - Role::Server => PasswordProtocolExtension::new_server(), - Role::Client => PasswordProtocolExtension::new_client( - self.username.clone(), - self.password.clone(), - ), - } - .into() - } - } -} diff --git a/wisp/src/extensions/mod.rs b/wisp/src/extensions/mod.rs new file mode 100644 index 0000000..8c3ec12 --- /dev/null +++ b/wisp/src/extensions/mod.rs @@ -0,0 +1,106 @@ +//! Wisp protocol extensions. +pub mod password; +pub mod udp; + +use std::ops::{Deref, DerefMut}; + +use async_trait::async_trait; +use bytes::{BufMut, Bytes, BytesMut}; + +use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, +}; + +/// Type-erased protocol extension that implements Clone. +#[derive(Debug)] +pub struct AnyProtocolExtension(Box); + +impl AnyProtocolExtension { + /// Create a new type-erased protocol extension. + pub fn new(extension: T) -> Self { + Self(Box::new(extension)) + } +} + +impl Deref for AnyProtocolExtension { + type Target = dyn ProtocolExtension; + fn deref(&self) -> &Self::Target { + self.0.deref() + } +} + +impl DerefMut for AnyProtocolExtension { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.deref_mut() + } +} + +impl Clone for AnyProtocolExtension { + fn clone(&self) -> Self { + Self(self.0.box_clone()) + } +} + +impl From for Bytes { + fn from(value: AnyProtocolExtension) -> Self { + let mut bytes = BytesMut::with_capacity(5); + let payload = value.encode(); + bytes.put_u8(value.get_id()); + bytes.put_u32_le(payload.len() as u32); + bytes.extend(payload); + bytes.freeze() + } +} + +/// A Wisp protocol extension. +/// +/// See [the +/// docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#protocol-extensions). +#[async_trait] +pub trait ProtocolExtension: std::fmt::Debug { + /// Get the protocol extension ID. + fn get_id(&self) -> u8; + /// Get the protocol extension's supported packets. + /// + /// Used to decide whether to call the protocol extension's packet handler. + fn get_supported_packets(&self) -> &'static [u8]; + + /// Encode self into Bytes. + fn encode(&self) -> Bytes; + + /// Handle the handshake part of a Wisp connection. + /// + /// This should be used to send or receive data before any streams are created. + async fn handle_handshake( + &mut self, + read: &mut dyn WebSocketRead, + write: &LockedWebSocketWrite, + ) -> Result<(), WispError>; + + /// Handle receiving a packet. + async fn handle_packet( + &mut self, + packet: Bytes, + read: &mut dyn WebSocketRead, + write: &LockedWebSocketWrite, + ) -> Result<(), WispError>; + + /// Clone the protocol extension. + fn box_clone(&self) -> Box; +} + +/// Trait to build a Wisp protocol extension from a payload. +pub trait ProtocolExtensionBuilder { + /// Get the protocol extension ID. + /// + /// Used to decide whether this builder should be used. + fn get_id(&self) -> u8; + + /// Build a protocol extension from the extension's metadata. + fn build_from_bytes(&self, bytes: Bytes, role: Role) + -> Result; + + /// Build a protocol extension to send to the other side. + fn build_to_extension(&self, role: Role) -> AnyProtocolExtension; +} diff --git a/wisp/src/extensions/password.rs b/wisp/src/extensions/password.rs new file mode 100644 index 0000000..3fe15b3 --- /dev/null +++ b/wisp/src/extensions/password.rs @@ -0,0 +1,276 @@ +//! Password protocol extension. +//! +//! Passwords are sent in plain text!! +//! +//! # Example +//! Server: +//! ``` +//! let mut passwords = HashMap::new(); +//! passwords.insert("user1".to_string(), "pw".to_string()); +//! let (mux, fut) = ServerMux::new( +//! rx, +//! tx, +//! 128, +//! Some(&[Box::new(PasswordProtocolExtensionBuilder::new_server(passwords))]) +//! ); +//! ``` +//! +//! Client: +//! ``` +//! let (mux, fut) = ClientMux::new( +//! rx, +//! tx, +//! 128, +//! Some(&[ +//! Box::new(PasswordProtocolExtensionBuilder::new_client( +//! "user1".to_string(), +//! "pw".to_string() +//! )) +//! ]) +//! ); +//! ``` +//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x02---password-authentication) + +use std::{collections::HashMap, error::Error, fmt::Display, string::FromUtf8Error}; + +use async_trait::async_trait; +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + Role, WispError, +}; + +use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; + +#[derive(Debug, Clone)] +/// Password protocol extension. +/// +/// **Passwords are sent in plain text!!** +/// **This extension will panic when encoding if the username's length does not fit within a u8 +/// or the password's length does not fit within a u16.** +pub struct PasswordProtocolExtension { + /// The username to log in with. + /// + /// This string's length must fit within a u8. + pub username: String, + /// The password to log in with. + /// + /// This string's length must fit within a u16. + pub password: String, + role: Role, +} + +impl PasswordProtocolExtension { + /// Password protocol extension ID. + pub const ID: u8 = 0x02; + + /// Create a new password protocol extension for the server. + /// + /// This signifies that the server requires a password. + pub fn new_server() -> Self { + Self { + username: String::new(), + password: String::new(), + role: Role::Server, + } + } + + /// Create a new password protocol extension for the client, with a username and password. + /// + /// The username's length must fit within a u8. The password's length must fit within a + /// u16. + pub fn new_client(username: String, password: String) -> Self { + Self { + username, + password, + role: Role::Client, + } + } +} + +#[async_trait] +impl ProtocolExtension for PasswordProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } + + fn encode(&self) -> Bytes { + match self.role { + Role::Server => Bytes::new(), + Role::Client => { + let username = Bytes::from(self.username.clone().into_bytes()); + let password = Bytes::from(self.password.clone().into_bytes()); + let username_len = u8::try_from(username.len()).expect("username was too long"); + let password_len = u16::try_from(password.len()).expect("password was too long"); + + let mut bytes = + BytesMut::with_capacity(3 + username_len as usize + password_len as usize); + bytes.put_u8(username_len); + bytes.put_u16_le(password_len); + bytes.extend(username); + bytes.extend(password); + bytes.freeze() + } + } + } + + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(self.clone()) + } +} + +#[derive(Debug)] +enum PasswordProtocolExtensionError { + Utf8Error(FromUtf8Error), + InvalidUsername, + InvalidPassword, +} + +impl Display for PasswordProtocolExtensionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use PasswordProtocolExtensionError as E; + match self { + E::Utf8Error(e) => write!(f, "{}", e), + E::InvalidUsername => write!(f, "Invalid username"), + E::InvalidPassword => write!(f, "Invalid password"), + } + } +} + +impl Error for PasswordProtocolExtensionError {} + +impl From for WispError { + fn from(value: PasswordProtocolExtensionError) -> Self { + WispError::ExtensionImplError(Box::new(value)) + } +} + +impl From for PasswordProtocolExtensionError { + fn from(value: FromUtf8Error) -> Self { + PasswordProtocolExtensionError::Utf8Error(value) + } +} + +impl From for AnyProtocolExtension { + fn from(value: PasswordProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } +} + +/// Password protocol extension builder. +/// +/// **Passwords are sent in plain text!!** +pub struct PasswordProtocolExtensionBuilder { + /// Map of users and their passwords to allow. Only used on server. + pub users: HashMap, + /// Username to authenticate with. Only used on client. + pub username: String, + /// Password to authenticate with. Only used on client. + pub password: String, +} + +impl PasswordProtocolExtensionBuilder { + /// Create a new password protocol extension builder for the server, with a map of users + /// and passwords to allow. + pub fn new_server(users: HashMap) -> Self { + Self { + users, + username: String::new(), + password: String::new(), + } + } + + /// Create a new password protocol extension builder for the client, with a username and + /// password to authenticate with. + pub fn new_client(username: String, password: String) -> Self { + Self { + users: HashMap::new(), + username, + password, + } + } +} + +impl ProtocolExtensionBuilder for PasswordProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + PasswordProtocolExtension::ID + } + + fn build_from_bytes( + &self, + mut payload: Bytes, + role: crate::Role, + ) -> Result { + match role { + Role::Server => { + if payload.remaining() < 3 { + return Err(WispError::PacketTooSmall); + } + + let username_len = payload.get_u8(); + let password_len = payload.get_u16_le(); + if payload.remaining() < (password_len + username_len as u16) as usize { + return Err(WispError::PacketTooSmall); + } + + use PasswordProtocolExtensionError as EError; + let username = + String::from_utf8(payload.copy_to_bytes(username_len as usize).to_vec()) + .map_err(|x| WispError::from(EError::from(x)))?; + let password = + String::from_utf8(payload.copy_to_bytes(password_len as usize).to_vec()) + .map_err(|x| WispError::from(EError::from(x)))?; + + let Some(user) = self.users.iter().find(|x| *x.0 == username) else { + return Err(EError::InvalidUsername.into()); + }; + + if *user.1 != password { + return Err(EError::InvalidPassword.into()); + } + + Ok(PasswordProtocolExtension { + username, + password, + role, + } + .into()) + } + Role::Client => { + Ok(PasswordProtocolExtension::new_client(String::new(), String::new()).into()) + } + } + } + + fn build_to_extension(&self, role: Role) -> AnyProtocolExtension { + match role { + Role::Server => PasswordProtocolExtension::new_server(), + Role::Client => { + PasswordProtocolExtension::new_client(self.username.clone(), self.password.clone()) + } + } + .into() + } +} diff --git a/wisp/src/extensions/udp.rs b/wisp/src/extensions/udp.rs new file mode 100644 index 0000000..068b5eb --- /dev/null +++ b/wisp/src/extensions/udp.rs @@ -0,0 +1,93 @@ +//! UDP protocol extension. +//! +//! # Example +//! ``` +//! let (mux, fut) = ServerMux::new( +//! rx, +//! tx, +//! 128, +//! Some(&[Box::new(UdpProtocolExtensionBuilder())]) +//! ); +//! ``` +//! See [the docs](https://github.com/MercuryWorkshop/wisp-protocol/blob/main/protocol.md#0x01---udp) +use async_trait::async_trait; +use bytes::Bytes; + +use crate::{ + ws::{LockedWebSocketWrite, WebSocketRead}, + WispError, +}; + +use super::{AnyProtocolExtension, ProtocolExtension, ProtocolExtensionBuilder}; + +#[derive(Debug)] +/// UDP protocol extension. +pub struct UdpProtocolExtension(); + +impl UdpProtocolExtension { + /// UDP protocol extension ID. + pub const ID: u8 = 0x01; +} + +#[async_trait] +impl ProtocolExtension for UdpProtocolExtension { + fn get_id(&self) -> u8 { + Self::ID + } + + fn get_supported_packets(&self) -> &'static [u8] { + &[] + } + + fn encode(&self) -> Bytes { + Bytes::new() + } + + async fn handle_handshake( + &mut self, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + async fn handle_packet( + &mut self, + _: Bytes, + _: &mut dyn WebSocketRead, + _: &LockedWebSocketWrite, + ) -> Result<(), WispError> { + Ok(()) + } + + fn box_clone(&self) -> Box { + Box::new(Self()) + } +} + +impl From for AnyProtocolExtension { + fn from(value: UdpProtocolExtension) -> Self { + AnyProtocolExtension(Box::new(value)) + } +} + +/// UDP protocol extension builder. +pub struct UdpProtocolExtensionBuilder(); + +impl ProtocolExtensionBuilder for UdpProtocolExtensionBuilder { + fn get_id(&self) -> u8 { + UdpProtocolExtension::ID + } + + fn build_from_bytes( + &self, + _: Bytes, + _: crate::Role, + ) -> Result { + Ok(UdpProtocolExtension().into()) + } + + fn build_to_extension(&self, _: crate::Role) -> AnyProtocolExtension { + UdpProtocolExtension().into() + } +} diff --git a/wisp/src/fastwebsockets.rs b/wisp/src/fastwebsockets.rs index a2c2f7d..e05de91 100644 --- a/wisp/src/fastwebsockets.rs +++ b/wisp/src/fastwebsockets.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use async_trait::async_trait; use bytes::BytesMut; use fastwebsockets::{ - CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite + CloseCode, FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, }; use tokio::io::{AsyncRead, AsyncWrite}; @@ -79,6 +79,8 @@ impl crate::ws::WebSocketWrite for WebSocketWrite< } async fn wisp_close(&mut self) -> Result<(), WispError> { - self.write_frame(Frame::close(CloseCode::Normal.into(), b"")).await.map_err(|e| e.into()) + self.write_frame(Frame::close(CloseCode::Normal.into(), b"")) + .await + .map_err(|e| e.into()) } } diff --git a/wisp/src/stream.rs b/wisp/src/stream.rs index 8a074a7..e980bec 100644 --- a/wisp/src/stream.rs +++ b/wisp/src/stream.rs @@ -189,7 +189,10 @@ impl Drop for MuxStreamWrite { if !self.is_closed.load(Ordering::Acquire) { self.is_closed.store(true, Ordering::Release); let (tx, _) = oneshot::channel(); - let _ = self.tx.send(WsEvent::Close(Packet::new_close(self.stream_id, CloseReason::Unknown), tx)); + let _ = self.tx.send(WsEvent::Close( + Packet::new_close(self.stream_id, CloseReason::Unknown), + tx, + )); } } }