diff --git a/Cargo.toml b/Cargo.toml index 260c4acb..8015de54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,7 @@ native-tls = { version = "0.2", optional = true } url = { version = "2.1.0", optional = true } ## WS async-std = { version = "1.5.0", optional = true } -soketto = { version = "0.3.2", optional = true } +soketto = { version = "0.4.1", optional = true } [dev-dependencies] # For examples diff --git a/src/transports/ws.rs b/src/transports/ws.rs index 09e2ce5a..a14136dd 100644 --- a/src/transports/ws.rs +++ b/src/transports/ws.rs @@ -12,12 +12,11 @@ use crate::{BatchTransport, DuplexTransport, Error, RequestId, Transport}; use futures::channel::{mpsc, oneshot}; use futures::{ task::{Context, Poll}, - Future, FutureExt, StreamExt, + Future, FutureExt, Stream, StreamExt, }; use async_std::net::TcpStream; use soketto::connection; -use soketto::data::Incoming; use soketto::handshake::{Client, ServerResponse}; impl From for Error { @@ -84,8 +83,7 @@ impl WsServerTask { mut subscriptions, } = self; - let receiver = connection::into_stream(receiver); - let receiver = receiver.fuse(); + let receiver = as_data_stream(receiver).fuse(); let requests = requests.fuse(); pin_mut!(receiver); pin_mut!(requests); @@ -116,9 +114,9 @@ impl WsServerTask { } None => {} }, - message = receiver.next() => match message { - Some(Ok(message)) => { - handle_message(message, &subscriptions, &mut pending); + res = receiver.next() => match res { + Some(Ok(data)) => { + handle_message(&data, &subscriptions, &mut pending); }, Some(Err(e)) => { log::error!("WS connection error: {:?}", e); @@ -132,60 +130,67 @@ impl WsServerTask { } } +fn as_data_stream( + receiver: soketto::connection::Receiver, +) -> impl Stream, soketto::connection::Error>> { + futures::stream::unfold(receiver, |mut receiver| async move { + let mut data = Vec::new(); + Some(match receiver.receive_data(&mut data).await { + Ok(_) => (Ok(data), receiver), + Err(e) => (Err(e), receiver), + }) + }) +} + fn handle_message( - message: Incoming, + data: &[u8], subscriptions: &BTreeMap, pending: &mut BTreeMap, ) { - log::trace!("Message received: {:?}", message); - match message { - Incoming::Pong(_) => {} - Incoming::Data(t) => { - if let Ok(notification) = helpers::to_notification_from_slice(t.as_ref()) { - if let rpc::Params::Map(params) = notification.params { - let id = params.get("subscription"); - let result = params.get("result"); - - if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) { - let id: SubscriptionId = id.clone().into(); - if let Some(stream) = subscriptions.get(&id) { - if let Err(e) = stream.unbounded_send(result.clone()) { - log::error!("Error sending notification: {:?} (id: {:?}", e, id); - } - } else { - log::warn!("Got notification for unknown subscription (id: {:?})", id); - } - } else { - log::error!("Got unsupported notification (id: {:?})", id); + log::trace!("Message received: {:?}", data); + if let Ok(notification) = helpers::to_notification_from_slice(data) { + if let rpc::Params::Map(params) = notification.params { + let id = params.get("subscription"); + let result = params.get("result"); + + if let (Some(&rpc::Value::String(ref id)), Some(result)) = (id, result) { + let id: SubscriptionId = id.clone().into(); + if let Some(stream) = subscriptions.get(&id) { + if let Err(e) = stream.unbounded_send(result.clone()) { + log::error!("Error sending notification: {:?} (id: {:?}", e, id); } + } else { + log::warn!("Got notification for unknown subscription (id: {:?})", id); } } else { - let response = helpers::to_response_from_slice(t.as_ref()); - let outputs = match response { - Ok(rpc::Response::Single(output)) => vec![output], - Ok(rpc::Response::Batch(outputs)) => outputs, - _ => vec![], - }; - - let id = match outputs.get(0) { - Some(&rpc::Output::Success(ref success)) => success.id.clone(), - Some(&rpc::Output::Failure(ref failure)) => failure.id.clone(), - None => rpc::Id::Num(0), - }; - - if let rpc::Id::Num(num) = id { - if let Some(request) = pending.remove(&(num as usize)) { - log::trace!("Responding to (id: {:?}) with {:?}", num, outputs); - if let Err(err) = request.send(helpers::to_results_from_outputs(outputs)) { - log::warn!("Sending a response to deallocated channel: {:?}", err); - } - } else { - log::warn!("Got response for unknown request (id: {:?})", num); - } - } else { - log::warn!("Got unsupported response (id: {:?})", id); + log::error!("Got unsupported notification (id: {:?})", id); + } + } + } else { + let response = helpers::to_response_from_slice(data); + let outputs = match response { + Ok(rpc::Response::Single(output)) => vec![output], + Ok(rpc::Response::Batch(outputs)) => outputs, + _ => vec![], + }; + + let id = match outputs.get(0) { + Some(&rpc::Output::Success(ref success)) => success.id.clone(), + Some(&rpc::Output::Failure(ref failure)) => failure.id.clone(), + None => rpc::Id::Num(0), + }; + + if let rpc::Id::Num(num) = id { + if let Some(request) = pending.remove(&(num as usize)) { + log::trace!("Responding to (id: {:?}) with {:?}", num, outputs); + if let Err(err) = request.send(helpers::to_results_from_outputs(outputs)) { + log::warn!("Sending a response to deallocated channel: {:?}", err); } + } else { + log::warn!("Got response for unknown request (id: {:?})", num); } + } else { + log::warn!("Got unsupported response (id: {:?})", id); } } } @@ -391,10 +396,11 @@ mod tests { server.send_response(&accept).await.unwrap(); let (mut sender, mut receiver) = server.into_builder().finish(); loop { - match receiver.receive_data().await { - Ok(data) if data.is_text() => { + let mut data = Vec::new(); + match receiver.receive_data(&mut data).await { + Ok(data_type) if data_type.is_text() => { assert_eq!( - std::str::from_utf8(data.as_ref()), + std::str::from_utf8(&data), Ok(r#"{"jsonrpc":"2.0","method":"eth_accounts","params":["1"],"id":1}"#) ); sender