From 7ac0f0808eab25b8e69e9068fce15de4ad82f0e1 Mon Sep 17 00:00:00 2001 From: Victor Adossi Date: Thu, 30 Jan 2025 15:19:52 +0900 Subject: [PATCH] feat: support error context in stream/error operations Signed-off-by: Victor Adossi --- crates/guest-rust/rt/src/async_support.rs | 51 +++++++++--- .../rt/src/async_support/future_support.rs | 74 ++++++++++++++---- .../rt/src/async_support/stream_support.rs | 77 ++++++++++++++----- crates/rust/src/interface.rs | 62 +++++++++------ 4 files changed, 190 insertions(+), 74 deletions(-) diff --git a/crates/guest-rust/rt/src/async_support.rs b/crates/guest-rust/rt/src/async_support.rs index b60cabaf8..51adc778c 100644 --- a/crates/guest-rust/rt/src/async_support.rs +++ b/crates/guest-rust/rt/src/async_support.rs @@ -53,6 +53,9 @@ pub enum Handle { LocalClosed, Read, Write, + // Local end is closed with an error + // NOTE: this is only valid for write ends + WriteClosedErr(Option), } /// The current task being polled (or null if none). @@ -176,7 +179,7 @@ pub async unsafe fn await_result( STATUS_RETURNED | STATUS_DONE => { alloc::dealloc(params, params_layout); } - _ => unreachable!(), + _ => unreachable!("unrecognized async call status"), } } @@ -187,13 +190,40 @@ mod results { pub const CANCELED: u32 = 0; } +/// Result of awaiting a asynchronous read or write +#[doc(hidden)] +pub enum AsyncWaitResult { + /// Used when a value was successfully sent or received + Values(usize), + /// Represents a successful but error-indicating read + Error(u32), + /// Represents a failed read (closed, canceled, etc) + End, +} + +impl AsyncWaitResult { + /// Interpret the results from an async operation that is known to *not* be blocked + fn from_nonblocked_async_result(v: u32) -> Self { + match v { + results::CLOSED | results::CANCELED => Self::End, + v => { + if v & results::CLOSED != 0 { + Self::Error(v & !results::CLOSED) + } else { + Self::Values(v as usize) + } + } + } + } +} + /// Await the completion of a future read or write. #[doc(hidden)] pub async unsafe fn await_future_result( import: unsafe extern "C" fn(u32, *mut u8) -> u32, future: u32, address: *mut u8, -) -> bool { +) -> AsyncWaitResult { let result = import(future, address); match result { results::BLOCKED => { @@ -201,12 +231,9 @@ pub async unsafe fn await_future_result( (*CURRENT).todo += 1; let (tx, rx) = oneshot::channel(); CALLS.insert(future as _, tx); - let v = rx.await.unwrap(); - v == 1 + AsyncWaitResult::from_nonblocked_async_result(rx.await.unwrap()) } - results::CLOSED | results::CANCELED => false, - 1 => true, - _ => unreachable!(), + v => AsyncWaitResult::from_nonblocked_async_result(v), } } @@ -217,7 +244,7 @@ pub async unsafe fn await_stream_result( stream: u32, address: *mut u8, count: u32, -) -> Option { +) -> AsyncWaitResult { let result = import(stream, address, count); match result { results::BLOCKED => { @@ -227,13 +254,12 @@ pub async unsafe fn await_stream_result( CALLS.insert(stream as _, tx); let v = rx.await.unwrap(); if let results::CLOSED | results::CANCELED = v { - None + AsyncWaitResult::End } else { - Some(usize::try_from(v).unwrap()) + AsyncWaitResult::Values(usize::try_from(v).unwrap()) } } - results::CLOSED | results::CANCELED => None, - v => Some(usize::try_from(v).unwrap()), + v => AsyncWaitResult::from_nonblocked_async_result(v), } } @@ -310,6 +336,7 @@ pub unsafe fn callback(ctx: *mut u8, event0: i32, event1: i32, event2: i32) -> i } /// Represents the Component Model `error-context` type. +#[derive(PartialEq, Eq)] pub struct ErrorContext { handle: u32, } diff --git a/crates/guest-rust/rt/src/async_support/future_support.rs b/crates/guest-rust/rt/src/async_support/future_support.rs index 8477ec481..4a082ae47 100644 --- a/crates/guest-rust/rt/src/async_support/future_support.rs +++ b/crates/guest-rust/rt/src/async_support/future_support.rs @@ -1,6 +1,7 @@ extern crate std; use { + super::ErrorContext, super::Handle, futures::{ channel::oneshot, @@ -20,10 +21,10 @@ use { #[doc(hidden)] pub struct FutureVtable { pub write: fn(future: u32, value: T) -> Pin>>, - pub read: fn(future: u32) -> Pin>>>, + pub read: fn(future: u32) -> Pin>>>>, pub cancel_write: fn(future: u32), pub cancel_read: fn(future: u32), - pub close_writable: fn(future: u32), + pub close_writable: fn(future: u32, err_ctx: u32), pub close_readable: fn(future: u32), } @@ -78,7 +79,8 @@ impl CancelableWrite { Handle::LocalOpen | Handle::LocalWaiting(_) | Handle::Read - | Handle::LocalClosed => unreachable!(), + | Handle::LocalClosed + | Handle::WriteClosedErr(_) => unreachable!(), Handle::LocalReady(..) => { entry.insert(Handle::LocalOpen); } @@ -126,7 +128,9 @@ impl FutureWriter { Poll::Pending } Handle::LocalReady(..) => Poll::Pending, - Handle::LocalClosed => Poll::Ready(()), + Handle::LocalClosed | Handle::WriteClosedErr(_) => { + Poll::Ready(()) + } Handle::LocalWaiting(_) | Handle::Read | Handle::Write => { unreachable!() } @@ -141,13 +145,29 @@ impl FutureWriter { _ = tx.send(Box::new(v)); Box::pin(future::ready(())) } - Handle::LocalClosed => Box::pin(future::ready(())), + Handle::LocalClosed | Handle::WriteClosedErr(_) => Box::pin(future::ready(())), Handle::Read | Handle::LocalReady(..) => unreachable!(), Handle::Write => Box::pin((vtable.write)(handle, v).map(drop)), }, }), } } + + /// Close the writer with an error that will be returned as the last value + /// + /// Note that this error is not sent immediately, but only when the + /// writer closes, which is normally a result of a `drop()` + pub fn close_with_error(&mut self, err: ErrorContext) { + super::with_entry(self.handle, move |entry| match entry { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(mut entry) => match entry.get_mut() { + // Regardless of current state, put the writer into a closed with error state + _ => { + entry.insert(Handle::WriteClosedErr(Some(err))); + } + }, + }); + } } impl Drop for FutureWriter { @@ -161,7 +181,18 @@ impl Drop for FutureWriter { Handle::Read => unreachable!(), Handle::Write | Handle::LocalClosed => { entry.remove(); - (self.vtable.close_writable)(self.handle); + (self.vtable.close_writable)(self.handle, 0); + } + Handle::WriteClosedErr(_) => { + match entry.remove() { + Handle::WriteClosedErr(None) => { + (self.vtable.close_writable)(self.handle, 0); + } + Handle::WriteClosedErr(Some(err_ctx)) => { + (self.vtable.close_writable)(self.handle, err_ctx.handle()); + } + _ => unreachable!(), + } } }, }); @@ -171,13 +202,13 @@ impl Drop for FutureWriter { /// Represents a read operation which may be canceled prior to completion. pub struct CancelableRead { reader: Option>, - future: Pin>>>, + future: Pin>>>>, } impl Future for CancelableRead { - type Output = Option; + type Output = Option>; - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll>> { let me = self.get_mut(); match me.future.poll_unpin(cx) { Poll::Ready(v) => { @@ -206,7 +237,8 @@ impl CancelableRead { Handle::LocalOpen | Handle::LocalReady(..) | Handle::Write - | Handle::LocalClosed => unreachable!(), + | Handle::LocalClosed + | Handle::WriteClosedErr(_) => unreachable!(), Handle::LocalWaiting(_) => { entry.insert(Handle::LocalOpen); } @@ -262,7 +294,8 @@ impl FutureReader { | Handle::LocalOpen | Handle::LocalReady(..) | Handle::LocalWaiting(_) - | Handle::LocalClosed => { + | Handle::LocalClosed + | Handle::WriteClosedErr(_) => { unreachable!() } }, @@ -286,7 +319,10 @@ impl FutureReader { Handle::Read | Handle::LocalClosed => { entry.remove(); } - Handle::LocalReady(..) | Handle::LocalWaiting(_) | Handle::Write => unreachable!(), + Handle::LocalReady(..) + | Handle::LocalWaiting(_) + | Handle::Write + | Handle::WriteClosedErr(_) => unreachable!(), }, }); @@ -295,7 +331,7 @@ impl FutureReader { } impl IntoFuture for FutureReader { - type Output = Option; + type Output = Option>; type IntoFuture = CancelableRead; /// Convert this object into a `Future` which will resolve when a value is @@ -308,8 +344,10 @@ impl IntoFuture for FutureReader { reader: Some(self), future: super::with_entry(handle, |entry| match entry { Entry::Vacant(_) => unreachable!(), - Entry::Occupied(mut entry) => match entry.get() { - Handle::Write | Handle::LocalWaiting(_) => unreachable!(), + Entry::Occupied(mut entry) => match entry.get_mut() { + Handle::Write | Handle::LocalWaiting(_) => { + unreachable!() + } Handle::Read => Box::pin(async move { (vtable.read)(handle).await }) as Pin>>, Handle::LocalOpen => { @@ -318,6 +356,10 @@ impl IntoFuture for FutureReader { Box::pin(async move { rx.await.ok().map(|v| *v.downcast().unwrap()) }) } Handle::LocalClosed => Box::pin(future::ready(None)), + Handle::WriteClosedErr(err_ctx) => match err_ctx.take() { + None => Box::pin(future::ready(None)), + Some(err_ctx) => Box::pin(future::ready(Some(Err(err_ctx)))), + }, Handle::LocalReady(..) => { let Handle::LocalReady(v, waker) = entry.insert(Handle::LocalClosed) else { unreachable!() @@ -353,7 +395,7 @@ impl Drop for FutureReader { entry.remove(); (self.vtable.close_readable)(handle); } - Handle::Write => unreachable!(), + Handle::Write | Handle::WriteClosedErr(_) => unreachable!(), }, }); } diff --git a/crates/guest-rust/rt/src/async_support/stream_support.rs b/crates/guest-rust/rt/src/async_support/stream_support.rs index d80b96d60..52fde7d09 100644 --- a/crates/guest-rust/rt/src/async_support/stream_support.rs +++ b/crates/guest-rust/rt/src/async_support/stream_support.rs @@ -1,6 +1,7 @@ extern crate std; use { + super::ErrorContext, super::Handle, futures::{ channel::oneshot, @@ -33,10 +34,10 @@ pub struct StreamVtable { pub read: fn( future: u32, values: &mut [MaybeUninit], - ) -> Pin> + '_>>, + ) -> Pin>> + '_>>, pub cancel_write: fn(future: u32), pub cancel_read: fn(future: u32), - pub close_writable: fn(future: u32), + pub close_writable: fn(future: u32, err_ctx: u32), pub close_readable: fn(future: u32), } @@ -54,7 +55,8 @@ impl Drop for CancelWriteOnDrop { Handle::LocalOpen | Handle::LocalWaiting(_) | Handle::Read - | Handle::LocalClosed => unreachable!(), + | Handle::LocalClosed + | Handle::WriteClosedErr(_) => unreachable!(), Handle::LocalReady(..) => { entry.insert(Handle::LocalOpen); } @@ -89,6 +91,22 @@ impl StreamWriter { assert!(self.future.is_some()); self.future = None; } + + /// Close the writer with an error that will be returned as the last value + /// + /// Note that this error is not sent immediately, but only when the + /// writer closes, which is normally a result of a `drop()` + pub fn close_with_error(self, err: ErrorContext) { + super::with_entry(self.handle, move |entry| match entry { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(mut entry) => match entry.get_mut() { + _ => { + // Note: the impending drop after this function runs should trigger + entry.insert(Handle::WriteClosedErr(Some(err))); + } + }, + }); + } } impl fmt::Debug for StreamWriter { @@ -147,7 +165,7 @@ impl Sink> for StreamWriter { } } Handle::LocalReady(..) => Poll::Pending, - Handle::LocalClosed => { + Handle::LocalClosed | Handle::WriteClosedErr(_) => { cancel_on_drop.take().unwrap().handle = None; Poll::Ready(()) } @@ -164,7 +182,7 @@ impl Sink> for StreamWriter { }; _ = tx.send(Box::new(item)); } - Handle::LocalClosed => (), + Handle::LocalClosed | Handle::WriteClosedErr(_) => (), Handle::Read | Handle::LocalReady(..) => unreachable!(), Handle::Write => { let handle = self.handle; @@ -206,7 +224,12 @@ impl Drop for StreamWriter { Handle::Read => unreachable!(), Handle::Write | Handle::LocalClosed => { entry.remove(); - (self.vtable.close_writable)(self.handle); + (self.vtable.close_writable)(self.handle, 0); + } + Handle::WriteClosedErr(err) => { + let err_ctx = err.take().as_ref().map(ErrorContext::handle).unwrap_or(0); + entry.remove(); + (self.vtable.close_writable)(self.handle, err_ctx); } }, }); @@ -227,7 +250,8 @@ impl Drop for CancelReadOnDrop { Handle::LocalOpen | Handle::LocalReady(..) | Handle::Write - | Handle::LocalClosed => unreachable!(), + | Handle::LocalClosed + | Handle::WriteClosedErr(_) => unreachable!(), Handle::LocalWaiting(_) => { entry.insert(Handle::LocalOpen); } @@ -241,7 +265,7 @@ impl Drop for CancelReadOnDrop { /// Represents the readable end of a Component Model `stream`. pub struct StreamReader { handle: AtomicU32, - future: Option>> + 'static>>>, + future: Option, ErrorContext>>> + 'static>>>, vtable: &'static StreamVtable, } @@ -287,7 +311,8 @@ impl StreamReader { | Handle::LocalOpen | Handle::LocalReady(..) | Handle::LocalWaiting(_) - | Handle::LocalClosed => { + | Handle::LocalClosed + | Handle::WriteClosedErr(_) => { unreachable!() } }, @@ -312,7 +337,10 @@ impl StreamReader { Handle::Read | Handle::LocalClosed => { entry.remove(); } - Handle::LocalReady(..) | Handle::LocalWaiting(_) | Handle::Write => unreachable!(), + Handle::LocalReady(..) + | Handle::LocalWaiting(_) + | Handle::Write + | Handle::WriteClosedErr(_) => unreachable!(), }, }); @@ -321,7 +349,7 @@ impl StreamReader { } impl Stream for StreamReader { - type Item = Vec; + type Item = Result, ErrorContext>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let me = self.get_mut(); @@ -331,8 +359,10 @@ impl Stream for StreamReader { me.handle.load(Relaxed), |entry| match entry { Entry::Vacant(_) => unreachable!(), - Entry::Occupied(mut entry) => match entry.get() { - Handle::Write | Handle::LocalWaiting(_) => unreachable!(), + Entry::Occupied(mut entry) => match entry.get_mut() { + Handle::Write | Handle::LocalWaiting(_) => { + unreachable!() + } Handle::Read => { let handle = me.handle.load(Relaxed); let vtable = me.vtable; @@ -345,15 +375,16 @@ impl Stream for StreamReader { .take(ceiling(64 * 1024, mem::size_of::().max(1))) .collect::>(); - let result = - if let Some(count) = (vtable.read)(handle, &mut buffer).await { + let result = match (vtable.read)(handle, &mut buffer).await { + Some(Ok(count)) => { buffer.truncate(count); - Some(unsafe { + Some(Ok(unsafe { mem::transmute::>, Vec>(buffer) - }) - } else { - None - }; + })) + } + Some(Err(err)) => Some(Err(err)), + None => None, + }; cancel_on_drop.handle = None; drop(cancel_on_drop); result @@ -375,6 +406,10 @@ impl Stream for StreamReader { }) } Handle::LocalClosed => Box::pin(future::ready(None)), + Handle::WriteClosedErr(err_ctx) => match err_ctx.take() { + None => Box::pin(future::ready(None)), + Some(err_ctx) => Box::pin(future::ready(Some(Err(err_ctx)))), + }, Handle::LocalReady(..) => { let Handle::LocalReady(v, waker) = entry.insert(Handle::LocalOpen) else { @@ -422,7 +457,7 @@ impl Drop for StreamReader { entry.remove(); (self.vtable.close_readable)(handle); } - Handle::Write => unreachable!(), + Handle::Write | Handle::WriteClosedErr(_) => unreachable!(), }, }); } diff --git a/crates/rust/src/interface.rs b/crates/rust/src/interface.rs index b7bf24f12..d39d242de 100644 --- a/crates/rust/src/interface.rs +++ b/crates/rust/src/interface.rs @@ -560,11 +560,15 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8) -> u32; }} - unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} + match unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{ + {async_support}::AsyncWaitResult::Values(_) => true, + {async_support}::AsyncWaitResult::End => false, + {async_support}::AsyncWaitResult::Error(_) => unreachable!("received error while performing write"), + }} }}) }} - fn read(future: u32) -> ::core::pin::Pin<{box_}>>> {{ + fn read(future: u32) -> ::core::pin::Pin<{box_}>>>> {{ {box_}::pin(async move {{ struct Buffer([::core::mem::MaybeUninit::; {size}]); let mut buffer = Buffer([::core::mem::MaybeUninit::uninit(); {size}]); @@ -582,11 +586,15 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8) -> u32; }} - if unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{ - {lift} - Some(value) - }} else {{ - None + match unsafe {{ {async_support}::await_future_result(wit_import, future, address).await }} {{ + {async_support}::AsyncWaitResult::Values(v) => {{ + {lift} + Some(Ok(value)) + }}, + {async_support}::AsyncWaitResult::Error(e) => {{ + Some(Err({async_support}::ErrorContext::from_handle(e))) + }}, + {async_support}::AsyncWaitResult::End => None, }} }}) }} @@ -625,7 +633,7 @@ pub mod vtable{ordinal} {{ }} }} - fn close_writable(writer: u32) {{ + fn close_writable(writer: u32, err_ctx: u32) {{ #[cfg(not(target_arch = "wasm32"))] {{ unreachable!(); @@ -638,7 +646,7 @@ pub mod vtable{ordinal} {{ #[link_name = "[future-close-writable-{index}]{func_name}"] fn drop(_: u32, _: u32); }} - unsafe {{ drop(writer, 0) }} + unsafe {{ drop(writer, err_ctx) }} }} }} @@ -790,26 +798,28 @@ pub mod vtable{ordinal} {{ let mut total = 0; while total < values.len() {{ - let count = unsafe {{ + + match unsafe {{ {async_support}::await_stream_result( wit_import, stream, address.add(total * {size}), u32::try_from(values.len()).unwrap() ).await - }}; - - if let Some(count) = count {{ - total += count; - }} else {{ - break + }} {{ + {async_support}::AsyncWaitResult::Values(count) => total += count, + {async_support}::AsyncWaitResult::Error(_) => unreachable!("encountered error during write"), + {async_support}::AsyncWaitResult::End => break, }} }} total }}) }} - fn read(stream: u32, values: &mut [::core::mem::MaybeUninit::<{name}>]) -> ::core::pin::Pin<{box_}> + '_>> {{ + fn read( + stream: u32, + values: &mut [::core::mem::MaybeUninit::<{name}>] + ) -> ::core::pin::Pin<{box_}>> + '_>> {{ {box_}::pin(async move {{ {lift_address} @@ -825,19 +835,21 @@ pub mod vtable{ordinal} {{ fn wit_import(_: u32, _: *mut u8, _: u32) -> u32; }} - let count = unsafe {{ + match unsafe {{ {async_support}::await_stream_result( wit_import, stream, address, u32::try_from(values.len()).unwrap() ).await - }}; - #[allow(unused)] - if let Some(count) = count {{ - {lift} + }} {{ + {async_support}::AsyncWaitResult::Values(count) => {{ + {lift} + Some(Ok(count)) + }}, + {async_support}::AsyncWaitResult::Error(e) => Some(Err({async_support}::ErrorContext::from_handle(e))), + {async_support}::AsyncWaitResult::End => None, }} - count }}) }} @@ -875,7 +887,7 @@ pub mod vtable{ordinal} {{ }} }} - fn close_writable(writer: u32) {{ + fn close_writable(writer: u32, err_ctx: u32) {{ #[cfg(not(target_arch = "wasm32"))] {{ unreachable!(); @@ -888,7 +900,7 @@ pub mod vtable{ordinal} {{ #[link_name = "[stream-close-writable-{index}]{func_name}"] fn drop(_: u32, _: u32); }} - unsafe {{ drop(writer, 0) }} + unsafe {{ drop(writer, err_ctx) }} }} }}