diff --git a/server/src/main.rs b/server/src/main.rs index eac6387..0b1f6f3 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -4,12 +4,12 @@ use std::io::Error; use bytes::Bytes; use clap::Parser; use fastwebsockets::{ - upgrade, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, - WebSocket, WebSocketError, + upgrade::{self, UpgradeFut}, CloseCode, FragmentCollector, FragmentCollectorRead, Frame, OpCode, Payload, + WebSocketError, }; use futures_util::{SinkExt, StreamExt, TryFutureExt}; use hyper::{ - body::Incoming, server::conn::http1, service::service_fn, upgrade::Upgraded, Request, Response, + body::Incoming, server::conn::http1, service::service_fn, Request, Response, StatusCode, }; use hyper_util::rt::TokioIo; @@ -54,9 +54,6 @@ struct Cli { /// Whether the server should block ports other than 80 or 443 #[arg(long)] block_non_http: bool, - /// Maximum WebSocket frame size allowed - #[arg(long, short, default_value_t = 64 << 20)] - frame_size: usize, } #[cfg(not(unix))] @@ -143,7 +140,6 @@ async fn main() -> Result<(), Error> { while let Ok((stream, addr)) = socket.accept().await { let prefix = prefix.clone(); tokio::spawn(async move { - let io = TokioIo::new(stream); let service = service_fn(move |res| { accept_http( res, @@ -152,11 +148,10 @@ async fn main() -> Result<(), Error> { opt.block_local, opt.block_udp, opt.block_non_http, - opt.frame_size, ) }); let conn = http1::Builder::new() - .serve_connection(io, service) + .serve_connection(TokioIo::new(stream), service) .with_upgrades(); if let Err(err) = conn.await { println!("failed to serve conn: {:?}", err); @@ -174,7 +169,6 @@ async fn accept_http( block_local: bool, block_udp: bool, block_non_http: bool, - max_size: usize, ) -> Result, WebSocketError> { let uri = req.uri().path().to_string(); if upgrade::is_upgrade_request(&req) @@ -182,17 +176,13 @@ async fn accept_http( { let (res, fut) = upgrade::upgrade(&mut req)?; - let mut ws = fut.await?; - - ws.set_max_message_size(max_size); - if uri.is_empty() { tokio::spawn(async move { - accept_ws(ws, addr.clone(), block_local, block_udp, block_non_http).await + accept_ws(fut, addr.clone(), block_local, block_udp, block_non_http).await }); } else if let Some(uri) = uri.strip_prefix('/').map(|x| x.to_string()) { tokio::spawn(async move { - accept_wsproxy(ws, uri, addr.clone(), block_local, block_non_http).await + accept_wsproxy(fut, uri, addr.clone(), block_local, block_non_http).await }); } @@ -260,13 +250,13 @@ async fn handle_mux(packet: ConnectPacket, mut stream: MuxStream) -> Result>, + ws: UpgradeFut, addr: String, block_local: bool, block_non_http: bool, block_udp: bool, ) -> Result<(), Box> { - let (rx, tx) = ws.split(tokio::io::split); + let (rx, tx) = ws.await?.split(tokio::io::split); let rx = FragmentCollectorRead::new(rx); println!("{:?}: connected", addr); @@ -334,13 +324,13 @@ async fn accept_ws( } async fn accept_wsproxy( - ws: WebSocket>, + ws: UpgradeFut, incoming_uri: String, addr: String, block_local: bool, block_non_http: bool, ) -> Result<(), Box> { - let mut ws_stream = FragmentCollector::new(ws); + let mut ws_stream = FragmentCollector::new(ws.await?); println!("{:?}: connected (wsproxy): {:?}", addr, incoming_uri);