Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion reflectapi-demo/clients/rust/generated/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ edition = "2021"
workspace = true

[dependencies]
reflectapi = { workspace = true, features = ["rt", "reqwest", "chrono"] }
reflectapi = { workspace = true, features = ["rt", "rt-sse", "reqwest", "chrono"] }
chrono = { version = "0.4.37", features = ["serde"] }
tracing = "0.1"
serde = { version = "1.0.218", features = ["derive"] }
Expand Down
103 changes: 70 additions & 33 deletions reflectapi/src/rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,52 @@ impl<AE: std::error::Error + 'static, NE: std::error::Error + 'static> std::erro

pub type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send + 'static>>;

pub type StreamResponse<T, AE, NE> = Result<BoxStream<Result<T, Error<AE, NE>>>, Error<AE, NE>>;
/// Error type for individual stream items.
///
/// Unlike [`Error`], this does not include an `Application` variant because
/// application-level errors can only occur during the initial request/response
/// cycle (stream creation), not per-item during streaming.
pub enum StreamItemError<NE> {
Network(NE),
Protocol {
info: String,
stage: ProtocolErrorStage,
},
}

impl<NE: core::fmt::Debug> core::fmt::Debug for StreamItemError<NE> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
StreamItemError::Network(err) => write!(f, "network error: {err:?}"),
StreamItemError::Protocol { info, stage } => {
write!(f, "protocol error: {info} at {stage:?}")
}
}
}
}

impl<NE: core::fmt::Display> core::fmt::Display for StreamItemError<NE> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
StreamItemError::Network(err) => write!(f, "network error: {err}"),
StreamItemError::Protocol { info, stage } => {
write!(f, "protocol error: {info} at {stage}")
}
}
}
}

impl<NE: std::error::Error + 'static> std::error::Error for StreamItemError<NE> {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
StreamItemError::Network(err) => Some(err),
StreamItemError::Protocol { .. } => None,
}
}
}

pub type StreamResponse<T, AE, NE> =
Result<BoxStream<Result<T, StreamItemError<NE>>>, Error<AE, NE>>;

pub enum ProtocolErrorStage {
SerializeRequestBody,
Expand Down Expand Up @@ -216,6 +261,7 @@ where
}
}

#[cfg(feature = "rt-sse")]
fn __serialize_headers_for_stream<H: serde::Serialize>(
headers: H,
) -> Result<http::HeaderMap, (String, ProtocolErrorStage)> {
Expand Down Expand Up @@ -256,12 +302,13 @@ fn __serialize_headers_for_stream<H: serde::Serialize>(
}

#[doc(hidden)]
#[cfg(feature = "rt-sse")]
pub async fn __stream_request_impl<C, I, H, O, E>(
client: &C,
url: Url,
body: I,
headers: H,
) -> Result<BoxStream<Result<O, Error<E, C::Error>>>, Error<E, C::Error>>
) -> Result<BoxStream<Result<O, StreamItemError<C::Error>>>, Error<E, C::Error>>
where
C: Client,
C::Error: Send + 'static,
Expand All @@ -270,6 +317,11 @@ where
O: serde::de::DeserializeOwned + Send + 'static,
E: serde::de::DeserializeOwned + Send + 'static,
{
use futures_util::StreamExt;
use sseer::event_stream::EventStream;
use sseer::json_stream::JsonStream;
use sseer::{errors::EventStreamError, json_stream::JsonStreamError};

let body = serde_json::to_vec(&body).map_err(|e| Error::Protocol {
info: e.to_string(),
stage: ProtocolErrorStage::SerializeRequestBody,
Expand All @@ -284,40 +336,24 @@ where
.map_err(Error::Network)?;

if status.is_success() {
#[cfg(feature = "rt-sse")]
{
use futures_util::StreamExt;
use sseer::event_stream::EventStream;
use sseer::json_stream::JsonStream;
use sseer::{errors::EventStreamError, json_stream::JsonStreamError};

let event_stream = EventStream::new(byte_stream);
let json_stream = JsonStream::<O, _>::new_default(event_stream);
let stream = json_stream.map(|item| {
item.map_err(|err| match err {
JsonStreamError::Stream(err) => match err {
EventStreamError::Transport(err) => Error::Network(err),
EventStreamError::Utf8Error(err) => Error::Protocol {
info: err.to_string(),
stage: ProtocolErrorStage::DeserializeResponseBody(bytes::Bytes::new()),
},
},
JsonStreamError::Deserialize(err) => Error::Protocol {
let event_stream = EventStream::new(byte_stream);
let json_stream = JsonStream::<O, _>::new_default(event_stream);
let stream = json_stream.map(|item| {
item.map_err(|err| match err {
JsonStreamError::Stream(err) => match err {
EventStreamError::Transport(err) => StreamItemError::Network(err),
EventStreamError::Utf8Error(err) => StreamItemError::Protocol {
info: err.to_string(),
stage: ProtocolErrorStage::DeserializeResponseBody(bytes::Bytes::new()),
},
})
});
return Ok(Box::pin(stream));
}

#[cfg(not(feature = "rt-sse"))]
{
return Err(Error::Protocol {
info: "SSE streaming requires the 'rt-sse' feature flag".to_string(),
stage: ProtocolErrorStage::DeserializeResponseBody(bytes::Bytes::new()),
});
}
},
JsonStreamError::Deserialize(err) => StreamItemError::Protocol {
info: err.to_string(),
stage: ProtocolErrorStage::DeserializeResponseBody(bytes::Bytes::new()),
},
})
});
return Ok(Box::pin(stream));
}

let body = __collect_byte_stream(byte_stream)
Expand All @@ -333,6 +369,7 @@ where
}
}

#[cfg(feature = "rt-sse")]
async fn __collect_byte_stream<E>(
stream: Pin<Box<dyn Stream<Item = Result<bytes::Bytes, E>> + Send>>,
) -> Result<bytes::Bytes, E> {
Expand Down
Loading