diff --git a/crates/bindings-macro/src/lib.rs b/crates/bindings-macro/src/lib.rs index cd85b80ef65..ee0e0cc020d 100644 --- a/crates/bindings-macro/src/lib.rs +++ b/crates/bindings-macro/src/lib.rs @@ -8,6 +8,7 @@ // // (private documentation for the macro authors is totally fine here and you SHOULD write that!) +mod procedure; mod reducer; mod sats; mod table; @@ -104,6 +105,14 @@ mod sym { } } +#[proc_macro_attribute] +pub fn procedure(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream { + cvt_attr::(args, item, quote!(), |args, original_function| { + let args = procedure::ProcedureArgs::parse(args)?; + procedure::procedure_impl(args, original_function) + }) +} + #[proc_macro_attribute] pub fn reducer(args: StdTokenStream, item: StdTokenStream) -> StdTokenStream { cvt_attr::(args, item, quote!(), |args, original_function| { diff --git a/crates/bindings-macro/src/procedure.rs b/crates/bindings-macro/src/procedure.rs new file mode 100644 index 00000000000..8d549609370 --- /dev/null +++ b/crates/bindings-macro/src/procedure.rs @@ -0,0 +1,102 @@ +use crate::reducer::{assert_only_lifetime_generics, extract_typed_args}; +use crate::sym; +use crate::util::{check_duplicate, ident_to_litstr, match_meta}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::parse::Parser as _; +use syn::{ItemFn, LitStr}; + +#[derive(Default)] +pub(crate) struct ProcedureArgs { + name: Option, +} + +impl ProcedureArgs { + pub(crate) fn parse(input: TokenStream) -> syn::Result { + let mut args = Self::default(); + syn::meta::parser(|meta| { + match_meta!(match meta { + sym::name => { + check_duplicate(&args.name, &meta)?; + args.name = Some(meta.value()?.parse()?); + } + }); + Ok(()) + }) + .parse2(input)?; + Ok(args) + } +} + +pub(crate) fn procedure_impl(args: ProcedureArgs, original_function: &ItemFn) -> syn::Result { + let func_name = &original_function.sig.ident; + let vis = &original_function.vis; + + let procedure_name = args.name.unwrap_or_else(|| 
ident_to_litstr(func_name)); + + assert_only_lifetime_generics(original_function, "procedures")?; + + let typed_args = extract_typed_args(original_function)?; + + // Extract all function parameter names. + let opt_arg_names = typed_args.iter().map(|arg| { + if let syn::Pat::Ident(i) = &*arg.pat { + let name = i.ident.to_string(); + quote!(Some(#name)) + } else { + quote!(None) + } + }); + + let arg_tys = typed_args.iter().map(|arg| arg.ty.as_ref()).collect::>(); + let first_arg_ty = arg_tys.first().into_iter(); + let rest_arg_tys = arg_tys.iter().skip(1); + + // Extract the return type. + let ret_ty = match &original_function.sig.output { + syn::ReturnType::Default => None, + syn::ReturnType::Type(_, t) => Some(&**t), + } + .into_iter(); + + let register_describer_symbol = format!("__preinit__20_register_describer_{}", procedure_name.value()); + + let lifetime_params = &original_function.sig.generics; + let lifetime_where_clause = &lifetime_params.where_clause; + + let generated_describe_function = quote! { + #[export_name = #register_describer_symbol] + pub extern "C" fn __register_describer() { + spacetimedb::rt::register_procedure::<_, _, #func_name>(#func_name) + } + }; + + Ok(quote! 
{ + const _: () = { + #generated_describe_function + }; + #[allow(non_camel_case_types)] + #vis struct #func_name { _never: ::core::convert::Infallible } + const _: () = { + fn _assert_args #lifetime_params () #lifetime_where_clause { + #(let _ = <#first_arg_ty as spacetimedb::rt::ProcedureContextArg>::_ITEM;)* + #(let _ = <#rest_arg_tys as spacetimedb::rt::ProcedureArg>::_ITEM;)* + #(let _ = <#ret_ty as spacetimedb::rt::IntoProcedureResult>::into_result;)* + } + }; + impl #func_name { + fn invoke(__ctx: spacetimedb::ProcedureContext, __args: &[u8]) -> spacetimedb::ProcedureResult { + spacetimedb::rt::invoke_procedure(#func_name, __ctx, __args) + } + } + #[automatically_derived] + impl spacetimedb::rt::ExportFunctionInfo for #func_name { + const NAME: &'static str = #procedure_name; + const ARG_NAMES: &'static [Option<&'static str>] = &[#(#opt_arg_names),*]; + } + #[automatically_derived] + impl spacetimedb::rt::ProcedureInfo for #func_name { + const INVOKE: spacetimedb::rt::ProcedureFn = #func_name::invoke; + } + }) +} diff --git a/crates/bindings-macro/src/reducer.rs b/crates/bindings-macro/src/reducer.rs index ff98cc6250b..72627aa53c7 100644 --- a/crates/bindings-macro/src/reducer.rs +++ b/crates/bindings-macro/src/reducer.rs @@ -4,7 +4,7 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; use syn::parse::Parser as _; use syn::spanned::Spanned; -use syn::{FnArg, Ident, ItemFn, LitStr}; +use syn::{FnArg, Ident, ItemFn, LitStr, PatType}; #[derive(Default)] pub(crate) struct ReducerArgs { @@ -59,25 +59,28 @@ impl ReducerArgs { } } -pub(crate) fn reducer_impl(args: ReducerArgs, original_function: &ItemFn) -> syn::Result { - let func_name = &original_function.sig.ident; - let vis = &original_function.vis; - - let reducer_name = args.name.unwrap_or_else(|| ident_to_litstr(func_name)); - +pub(crate) fn assert_only_lifetime_generics(original_function: &ItemFn, function_kind_plural: &str) -> syn::Result<()> { for param in 
&original_function.sig.generics.params { let err = |msg| syn::Error::new_spanned(param, msg); match param { syn::GenericParam::Lifetime(_) => {} - syn::GenericParam::Type(_) => return Err(err("type parameters are not allowed on reducers")), - syn::GenericParam::Const(_) => return Err(err("const parameters are not allowed on reducers")), + syn::GenericParam::Type(_) => { + return Err(err(format!( + "type parameters are not allowed on {function_kind_plural}" + ))) + } + syn::GenericParam::Const(_) => { + return Err(err(format!( + "const parameters are not allowed on {function_kind_plural}" + ))) + } } } + Ok(()) +} - let lifecycle = args.lifecycle.iter().filter_map(|lc| lc.to_lifecycle_value()); - - // Extract all function parameters, except for `self` ones that aren't allowed. - let typed_args = original_function +pub(crate) fn extract_typed_args(original_function: &ItemFn) -> syn::Result> { + original_function .sig .inputs .iter() @@ -85,7 +88,21 @@ pub(crate) fn reducer_impl(args: ReducerArgs, original_function: &ItemFn) -> syn FnArg::Typed(arg) => Ok(arg), _ => Err(syn::Error::new_spanned(arg, "expected typed argument")), }) - .collect::>>()?; + .collect() +} + +pub(crate) fn reducer_impl(args: ReducerArgs, original_function: &ItemFn) -> syn::Result { + let func_name = &original_function.sig.ident; + let vis = &original_function.vis; + + let reducer_name = args.name.unwrap_or_else(|| ident_to_litstr(func_name)); + + assert_only_lifetime_generics(original_function, "reducers")?; + + let lifecycle = args.lifecycle.iter().filter_map(|lc| lc.to_lifecycle_value()); + + // Extract all function parameters, except for `self` ones that aren't allowed. + let typed_args = extract_typed_args(original_function)?; // Extract all function parameter names. 
let opt_arg_names = typed_args.iter().map(|arg| { @@ -139,10 +156,13 @@ pub(crate) fn reducer_impl(args: ReducerArgs, original_function: &ItemFn) -> syn } } #[automatically_derived] - impl spacetimedb::rt::ReducerInfo for #func_name { + impl spacetimedb::rt::ExportFunctionInfo for #func_name { const NAME: &'static str = #reducer_name; - #(const LIFECYCLE: Option = Some(#lifecycle);)* const ARG_NAMES: &'static [Option<&'static str>] = &[#(#opt_arg_names),*]; + } + #[automatically_derived] + impl spacetimedb::rt::ReducerInfo for #func_name { + #(const LIFECYCLE: Option = Some(#lifecycle);)* const INVOKE: spacetimedb::rt::ReducerFn = #func_name::invoke; } }) diff --git a/crates/bindings-macro/src/table.rs b/crates/bindings-macro/src/table.rs index 4e2b079992f..e5350236fbd 100644 --- a/crates/bindings-macro/src/table.rs +++ b/crates/bindings-macro/src/table.rs @@ -40,7 +40,7 @@ impl TableAccess { struct ScheduledArg { span: Span, - reducer: Path, + reducer_or_procedure: Path, at: Option, } @@ -113,7 +113,7 @@ impl TableArgs { impl ScheduledArg { fn parse_meta(meta: ParseNestedMeta) -> syn::Result { let span = meta.path.span(); - let mut reducer = None; + let mut reducer_or_procedure = None; let mut at = None; meta.parse_nested_meta(|meta| { @@ -126,16 +126,26 @@ impl ScheduledArg { } }) } else { - check_duplicate_msg(&reducer, &meta, "can only specify one scheduled reducer")?; - reducer = Some(meta.path); + check_duplicate_msg( + &reducer_or_procedure, + &meta, + "can only specify one scheduled reducer or procedure", + )?; + reducer_or_procedure = Some(meta.path); } Ok(()) })?; - let reducer = reducer.ok_or_else(|| { - meta.error("must specify scheduled reducer associated with the table: scheduled(reducer_name)") + let reducer_or_procedure = reducer_or_procedure.ok_or_else(|| { + meta.error( + "must specify scheduled reducer or procedure associated with the table: scheduled(function_name)", + ) })?; - Ok(Self { span, reducer, at }) + Ok(Self { + span, + 
reducer_or_procedure, + at, + }) } } @@ -818,17 +828,17 @@ pub(crate) fn table_impl(mut args: TableArgs, item: &syn::DeriveInput) -> syn::R ) })?; - let reducer = &sched.reducer; + let reducer_or_procedure = &sched.reducer_or_procedure; let scheduled_at_id = scheduled_at_column.index; let desc = quote!(spacetimedb::table::ScheduleDesc { - reducer_name: <#reducer as spacetimedb::rt::ReducerInfo>::NAME, + reducer_or_procedure_name: <#reducer_or_procedure as spacetimedb::rt::ExportFunctionInfo>::NAME, scheduled_at_column: #scheduled_at_id, }); let primary_key_ty = primary_key_column.ty; let scheduled_at_ty = scheduled_at_column.ty; let typecheck = quote! { - spacetimedb::rt::scheduled_reducer_typecheck::<#original_struct_ident>(#reducer); + spacetimedb::rt::scheduled_typecheck::<#original_struct_ident>(#reducer_or_procedure); spacetimedb::rt::assert_scheduled_table_primary_key::<#primary_key_ty>(); let _ = |x: #scheduled_at_ty| { let _: spacetimedb::ScheduleAt = x; }; }; diff --git a/crates/bindings-sys/src/lib.rs b/crates/bindings-sys/src/lib.rs index 861f6feadc3..46101d09969 100644 --- a/crates/bindings-sys/src/lib.rs +++ b/crates/bindings-sys/src/lib.rs @@ -593,6 +593,21 @@ pub mod raw { // See comment on previous `extern "C"` block re: ABI version. #[link(wasm_import_module = "spacetime_10.1")] extern "C" { + /// Suspends execution of this WASM instance until approximately `wake_at_micros_since_unix_epoch`. + /// + /// Returns immediately if `wake_at_micros_since_unix_epoch` is in the past. + /// + /// Upon resuming, returns the current timestamp as microseconds since the Unix epoch. + /// + /// Not particularly useful, except for testing SpacetimeDB internals related to suspending procedure execution. + /// # Traps + /// + /// Traps if: + /// + /// - The calling WASM instance is holding open a transaction. + /// - The calling WASM instance is not executing a procedure. 
+ pub fn procedure_sleep_until(wake_at_micros_since_unix_epoch: i64) -> i64; + /// Read the remaining length of a [`BytesSource`] and write it to `out`. /// /// Note that the host automatically frees byte sources which are exhausted. @@ -1169,3 +1184,12 @@ impl Drop for RowIter { } } } + +pub mod procedure { + #[inline] + pub fn sleep_until(wake_at_timestamp: i64) -> i64 { + // Safety: Just calling an `extern "C"` function. + // Nothing weird happening here. + unsafe { super::raw::procedure_sleep_until(wake_at_timestamp) } + } +} diff --git a/crates/bindings/src/lib.rs b/crates/bindings/src/lib.rs index b7bbb18101d..0e718ea4eb6 100644 --- a/crates/bindings/src/lib.rs +++ b/crates/bindings/src/lib.rs @@ -48,6 +48,8 @@ pub use table::{ pub type ReducerResult = core::result::Result<(), Box>; +pub type ProcedureResult = Vec; + pub use spacetimedb_bindings_macro::duration; /// Generates code for registering a row-level security rule. @@ -666,6 +668,75 @@ pub use spacetimedb_bindings_macro::table; #[doc(inline)] pub use spacetimedb_bindings_macro::reducer; +// TODO: document +#[doc(inline)] +pub use spacetimedb_bindings_macro::procedure; + +/// The context that any procedure is provided with. +/// +/// Each procedure must accept `&mut ProcedureContext` as its first argument. +/// +/// Includes information about the client calling the procedure and the time of invocation, +/// and exposes methods for running transactions and performing side-effecting operations. +/// +/// If the crate was compiled with the `rand` feature, +/// also includes faculties for random number generation. +pub struct ProcedureContext { + /// The `Identity` of the client that invoked the procedure. + pub sender: Identity, + + /// The time at which the procedure was started. + pub timestamp: Timestamp, + + /// The `ConnectionId` of the client that invoked the procedure. + /// + /// Will be `None` for certain scheduled procedures. 
+ pub connection_id: Option, + + #[cfg(feature = "rand08")] + rng: std::cell::OnceCell, +} + +impl ProcedureContext { + /// Read the current module's [`Identity`]. + pub fn identity(&self) -> Identity { + // Hypothetically, we *could* read the module identity out of the system tables. + // However, this would be: + // - Onerous, because we have no tooling to inspect the system tables from module code. + // - Slow (at least relatively), + // because it would involve multiple host calls which hit the datastore, + // as compared to a single host call which does not. + // As such, we've just defined a host call + // which reads the module identity out of the `InstanceEnv`. + Identity::from_byte_array(spacetimedb_bindings_sys::identity()) + } + + /// Suspend execution until approximately `timestamp`. + /// + /// This will update `self.timestamp` to the new time after execution resumes. + /// + /// The actual wake-up time may not be exactly equal to `timestamp`. + /// Callers should read `self.timestamp` after resuming to determine the new time. + /// + /// ```no-run + /// # use std::time::Duration; + /// # #[procedure] + /// # fn sleep_one_second(ctx: &mut ProcedureContext) { + /// let prev_time = ctx.timestamp; + /// let target = prev_time + Duration::from_secs(1); + /// ctx.sleep_until(target); + /// let new_time = ctx.timestamp; + /// let actual_delta = new_time - prev_time; + /// log::info!("Slept from {prev_time} to {new_time}, a total of {actual_delta}"); + /// # } + /// ``` + pub fn sleep_until(&mut self, timestamp: Timestamp) { + let new_time = sys::procedure::sleep_until(timestamp.to_micros_since_unix_epoch()); + let new_time = Timestamp::from_micros_since_unix_epoch(new_time); + self.timestamp = new_time; + } +} + /// One of two possible types that can be passed as the first argument to a `#[view]`. /// The other is [`ViewContext`]. /// Use this type if the view does not depend on the caller's identity.
@@ -706,11 +777,8 @@ pub struct ReducerContext { /// The `ConnectionId` of the client that invoked the reducer. /// - /// `None` if no `ConnectionId` was supplied to the `/database/call` HTTP endpoint, - /// or via the CLI's `spacetime call` subcommand. - /// - /// For automatic reducers, i.e. `init`, `client_connected`, `client_disconnected`, and scheduled reducers, - /// this will be the module's `ConnectionId`. + /// Will be `None` for certain reducers invoked automatically by the host, + /// including `init` and scheduled reducers. pub connection_id: Option, /// Allows accessing the local database attached to a module. diff --git a/crates/bindings/src/rng.rs b/crates/bindings/src/rng.rs index 07415e14354..0bf44f7b496 100644 --- a/crates/bindings/src/rng.rs +++ b/crates/bindings/src/rng.rs @@ -1,13 +1,57 @@ use std::cell::UnsafeCell; use std::marker::PhantomData; -use crate::rand; +use crate::{rand, ProcedureContext, ReducerContext}; use rand::distributions::{Distribution, Standard}; use rand::rngs::StdRng; use rand::{RngCore, SeedableRng}; -use crate::ReducerContext; +impl ProcedureContext { + /// Generates a random value. + /// + /// Similar to [`rand::random()`], but using [`StdbRng`] instead. + /// + /// See also [`ReducerContext::rng()`]. + pub fn random(&self) -> T + where + Standard: Distribution, + { + Standard.sample(&mut self.rng()) + } + + /// Retrieve the random number generator for this procedure invocation, + /// seeded by the timestamp of the procedure call. + /// + /// If you only need a single random value, you can use [`ProcedureContext::random()`]. 
+ /// + /// # Examples + /// + /// ```no_run + /// # #[cfg(target_arch = "wasm32")] mod demo { + /// use spacetimedb::{procedure, ProcedureContext}; + /// use rand::Rng; + /// + /// #[spacetimedb::procedure] + /// fn rng_demo(ctx: &mut spacetimedb::ProcedureContext) { + /// // Can be used in method chaining style: + /// let digit = ctx.rng().gen_range(0..=9); + /// + /// // Or, cache locally for reuse: + /// let mut rng = ctx.rng(); + /// let floats: Vec = rng.sample_iter(rand::distributions::Standard).collect(); + /// } + /// # } + /// ``` + /// + /// For more information, see [`StdbRng`] and [`rand::Rng`]. + pub fn rng(&self) -> &StdbRng { + self.rng.get_or_init(|| StdbRng { + rng: StdRng::seed_from_u64(self.timestamp.to_micros_since_unix_epoch() as u64).into(), + _marker: PhantomData, + }) + } +} impl ReducerContext { /// Generates a random value. diff --git a/crates/bindings/src/rt.rs b/crates/bindings/src/rt.rs index 6cdddde830d..add8fb7170f 100644 --- a/crates/bindings/src/rt.rs +++ b/crates/bindings/src/rt.rs @@ -1,7 +1,7 @@ #![deny(unsafe_op_in_unsafe_fn)] use crate::table::IndexAlgo; -use crate::{sys, IterBuf, ReducerContext, ReducerResult, SpacetimeType, Table}; +use crate::{sys, IterBuf, ProcedureContext, ProcedureResult, ReducerContext, ReducerResult, SpacetimeType, Table}; pub use spacetimedb_lib::db::raw_def::v9::Lifecycle as LifecycleReducer; use spacetimedb_lib::db::raw_def::v9::{RawIndexAlgorithm, RawModuleDefV9Builder, TableType}; use spacetimedb_lib::de::{self, Deserialize, Error as _, SeqProductAccess}; @@ -29,6 +29,25 @@ pub fn invoke_reducer<'a, A: Args<'a>>( reducer.invoke(&ctx, args) } + +/// Invoke `procedure` +pub fn invoke_procedure<'a, A: Args<'a>, Ret: IntoProcedureResult>( + procedure: impl Procedure<'a, A, Ret>, + mut ctx: ProcedureContext, + args: &'a [u8], +) -> ProcedureResult { + // Deserialize the arguments from a bsatn encoding. 
+ let SerDeArgs(args) = bsatn::from_slice(args).expect("unable to decode args"); + + let ret = procedure.invoke(&mut ctx, args); + + ret.into_result() +} + +/// Marker supertrait for [`Reducer`] and [`Procedure`], +/// used for typechecking by [`scheduled_typecheck`]. +pub trait ExportFunction<'de, A: Args<'de>> {} + /// A trait for types representing the *execution logic* of a reducer. #[diagnostic::on_unimplemented( message = "invalid reducer signature", @@ -39,26 +58,56 @@ pub fn invoke_reducer<'a, A: Args<'a>>( note = "where each `Ti` type implements `SpacetimeType`.", note = "" )] -pub trait Reducer<'de, A: Args<'de>> { +pub trait Reducer<'de, A: Args<'de>>: ExportFunction<'de, A> { fn invoke(&self, ctx: &ReducerContext, args: A) -> ReducerResult; } /// A trait for types that can *describe* a reducer. -pub trait ReducerInfo { - /// The name of the reducer. - const NAME: &'static str; - +/// +/// The `#[reducer]` macro generates an empty struct which implements this trait, +/// along with [`ExportFunctionInfo`]. +pub trait ReducerInfo: ExportFunctionInfo { /// The lifecycle of the reducer, if there is one. const LIFECYCLE: Option = None; - /// A description of the parameter names of the reducer. - const ARG_NAMES: &'static [Option<&'static str>]; - /// The function to call to invoke the reducer. const INVOKE: ReducerFn; } -/// A trait of types representing the arguments of a reducer. +#[diagnostic::on_unimplemented( + message = "invalid procedure signature", + label = "this procedure signature is not valid", + note = "", + note = "procedure signatures must match the following pattern:", + note = " `Fn(&mut ProcedureContext, [T1, ...]) [-> Tn]`", + note = "where each `Ti` implements `SpacetimeType`.", + note = "" +)] +pub trait Procedure<'de, A: Args<'de>, Ret: IntoProcedureResult>: ExportFunction<'de, A> { + fn invoke(&self, ctx: &mut ProcedureContext, args: A) -> Ret; +} + +/// A trait for types that can *describe* a procedure. 
+/// +/// The `#[procedure]` macro generates an empty struct which implements this trait, +/// along with [`ExportFunctionInfo`]. +pub trait ProcedureInfo: ExportFunctionInfo { + /// The function to invoke the procedure. + const INVOKE: ProcedureFn; +} + +/// Shared super-trait of [`ProcedureInfo`] and [`ReducerInfo`]. +pub trait ExportFunctionInfo { + const NAME: &'static str; + + const ARG_NAMES: &'static [Option<&'static str>]; +} + +/// A trait of types representing the arguments of a reducer or procedure. +/// +/// This does not include the `ReducerContext` or `ProcedureContext` first argument, +/// only the client-provided args. +/// As such, the same trait can be used for both procedures and reducers. pub trait Args<'de>: Sized { /// How many arguments does the reducer accept? const LEN: usize; @@ -69,8 +118,8 @@ pub trait Args<'de>: Sized { /// Serialize the arguments in `self` into the sequence `prod` according to the type `S`. fn serialize_seq_product(&self, prod: &mut S) -> Result<(), S::Error>; - /// Returns the schema for this reducer provided a `typespace`. - fn schema(typespace: &mut impl TypespaceBuilder) -> ProductType; + /// Returns the arguments schema [`ProductType`] for this reducer or procedure, provided a `typespace`. + fn schema(typespace: &mut impl TypespaceBuilder) -> ProductType; } /// A trait of types representing the result of executing a reducer. 
@@ -96,6 +145,18 @@ impl IntoReducerResult for Result<(), E> { } } +#[diagnostic::on_unimplemented( + message = "The procdure return type `{Self}` does not implement `SpacetimeType`", + note = "if you own the type, try adding `#[derive(SpacetimeType)]` to its definition" +)] +pub trait IntoProcedureResult: SpacetimeType + Serialize { + #[inline] + fn into_result(&self) -> ProcedureResult { + bsatn::to_vec(&self).expect("Failed to serialize procedure result") + } +} +impl IntoProcedureResult for T {} + #[diagnostic::on_unimplemented( message = "the first argument of a reducer must be `&ReducerContext`", label = "first argument must be `&ReducerContext`" @@ -119,8 +180,31 @@ pub trait ReducerArg { } impl ReducerArg for T {} -/// Assert that a reducer type-checks with a given type. -pub const fn scheduled_reducer_typecheck<'de, Row>(_x: impl ReducerForScheduledTable<'de, Row>) +#[diagnostic::on_unimplemented( + message = "the first argument of a procedure must be `&mut ProcedureContext`", + label = "first argument must be `&mut ProcedureContext`" +)] +pub trait ProcedureContextArg { + // a little hack used in the macro to make error messages nicer. it generates ::_ITEM + #[doc(hidden)] + const _ITEM: () = (); +} +impl ProcedureContextArg for &mut ProcedureContext {} + +/// A trait of types that can be an argument of a procedure. +#[diagnostic::on_unimplemented( + message = "the procedure argument `{Self}` does not implement `SpacetimeType`", + note = "if you own the type, try adding `#[derive(SpacetimeType)]` to its definition" +)] +pub trait ProcedureArg { + // a little hack used in the macro to make error messages nicer. it generates ::_ITEM + #[doc(hidden)] + const _ITEM: () = (); +} +impl ProcedureArg for T {} + +/// Assert that a reducer or procedure type-checks with a given argument type. 
+pub const fn scheduled_typecheck<'de, Row>(_x: impl ExportFunctionForScheduledTable<'de, Row>) where Row: SpacetimeType + Serialize + Deserialize<'de>, { @@ -128,13 +212,14 @@ where } #[diagnostic::on_unimplemented( - message = "invalid signature for scheduled table reducer", - note = "the scheduled reducer must take `{TableRow}` as its sole argument", - note = "e.g: `fn scheduled_reducer(ctx: &ReducerContext, arg: {TableRow})`" + message = "invalid signature for scheduled table reducer or procedure", + note = "the scheduled reducer or procedure must take `{TableRow}` as its sole argument", + note = "e.g: `fn scheduled_reducer(ctx: &ReducerContext, arg: {TableRow})`", + note = "or `fn scheduled_procedure(ctx: &mut ProcedureContext, arg: {TableRow})`" )] -pub trait ReducerForScheduledTable<'de, TableRow> {} -impl<'de, TableRow: SpacetimeType + Serialize + Deserialize<'de>, R: Reducer<'de, (TableRow,)>> - ReducerForScheduledTable<'de, TableRow> for R +pub trait ExportFunctionForScheduledTable<'de, TableRow> {} +impl<'de, TableRow: SpacetimeType + Serialize + Deserialize<'de>, R: ExportFunction<'de, (TableRow,)>> + ExportFunctionForScheduledTable<'de, TableRow> for R { } @@ -206,15 +291,15 @@ impl<'de, A: Args<'de>> de::ProductVisitor<'de> for ArgsVisitor { } } -macro_rules! impl_reducer { +macro_rules! impl_reducer_and_procedure { ($($T1:ident $(, $T:ident)*)?) => { - impl_reducer!(@impl $($T1 $(, $T)*)?); - $(impl_reducer!($($T),*);)? + impl_reducer_and_procedure!(@impl $($T1 $(, $T)*)?); + $(impl_reducer_and_procedure!($($T),*);)? }; (@impl $($T:ident),*) => { // Implement `Args` for the tuple type `($($T,)*)`. impl<'de, $($T: SpacetimeType + Deserialize<'de> + Serialize),*> Args<'de> for ($($T,)*) { - const LEN: usize = impl_reducer!(@count $($T)*); + const LEN: usize = impl_reducer_and_procedure!(@count $($T)*); #[allow(non_snake_case)] #[allow(unused)] fn visit_seq_product>(mut prod: Acc) -> Result { @@ -239,7 +324,7 @@ macro_rules! 
impl_reducer { #[inline] #[allow(non_snake_case, irrefutable_let_patterns)] - fn schema(_typespace: &mut impl TypespaceBuilder) -> ProductType { + fn schema(_typespace: &mut impl TypespaceBuilder) -> ProductType { // Extract the names of the arguments. let [.., $($T),*] = Info::ARG_NAMES else { panic!() }; ProductType::new(vec![ @@ -251,6 +336,8 @@ macro_rules! impl_reducer { } } + impl<'de, Func, $($T: SpacetimeType + Deserialize<'de> + Serialize),*> ExportFunction<'de, ($($T,)*)> for Func {} + // Implement `Reducer<..., ContextArg>` for the tuple type `($($T,)*)`. impl<'de, Func, Ret, $($T: SpacetimeType + Deserialize<'de> + Serialize),*> Reducer<'de, ($($T,)*)> for Func where @@ -264,15 +351,28 @@ macro_rules! impl_reducer { } } + impl<'de, Func, Ret, $($T: SpacetimeType + Deserialize<'de> + Serialize),*> Procedure<'de, ($($T,)*), Ret> for Func + where + Func: Fn(&mut ProcedureContext, $($T),*) -> Ret, + Ret: IntoProcedureResult, + { + #[allow(non_snake_case)] + fn invoke(&self, ctx: &mut ProcedureContext, args: ($($T,)*)) -> Ret { + let ($($T,)*) = args; + self(ctx, $($T),*) + } + } }; // Counts the number of elements in the tuple. (@count $($T:ident)*) => { - 0 $(+ impl_reducer!(@drop $T 1))* + 0 $(+ impl_reducer_and_procedure!(@drop $T 1))* }; (@drop $a:tt $b:tt) => { $b }; } -impl_reducer!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z, AA, AB, AC, AD, AE, AF); +impl_reducer_and_procedure!( + A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z, AA, AB, AC, AD, AE, AF +); /// Provides deserialization and serialization for any type `A: Args`. 
struct SerDeArgs(A); @@ -339,7 +439,7 @@ pub fn register_table() { table = table.with_column_sequence(col); } if let Some(schedule) = T::SCHEDULE { - table = table.with_schedule(schedule.reducer_name, schedule.scheduled_at_column); + table = table.with_schedule(schedule.reducer_or_procedure_name, schedule.scheduled_at_column); } for col in T::get_default_col_values().iter_mut() { @@ -370,6 +470,20 @@ pub fn register_reducer<'a, A: Args<'a>, I: ReducerInfo>(_: impl Reducer<'a, A>) }) } +pub fn register_procedure<'a, A, Ret, I>(_: impl Procedure<'a, A, Ret>) +where + A: Args<'a>, + Ret: SpacetimeType + Serialize, + I: ProcedureInfo, +{ + register_describer(|module| { + let params = A::schema::(&mut module.inner); + let ret_ty = ::make_type(&mut module.inner); + module.inner.add_procedure(I::NAME, params, ret_ty); + module.procedures.push(I::INVOKE); + }) +} + /// Registers a row-level security policy. pub fn register_row_level_security(sql: &'static str) { register_describer(|module| { @@ -384,6 +498,8 @@ struct ModuleBuilder { inner: RawModuleDefV9Builder, /// The reducers of the module. reducers: Vec, + /// The procedures of the module. + procedures: Vec, } // Not actually a mutex; because WASM is single-threaded this basically just turns into a refcell. @@ -394,6 +510,9 @@ static DESCRIBERS: Mutex>> = Mutex::new(Vec::new()); pub type ReducerFn = fn(ReducerContext, &[u8]) -> ReducerResult; static REDUCERS: OnceLock> = OnceLock::new(); +pub type ProcedureFn = fn(ProcedureContext, &[u8]) -> ProcedureResult; +static PROCEDURES: OnceLock> = OnceLock::new(); + /// Called by the host when the module is initialized /// to describe the module into a serialized form that is returned. /// @@ -425,6 +544,8 @@ extern "C" fn __describe_module__(description: BytesSink) { // Write the set of reducers. REDUCERS.set(module.reducers).ok().unwrap(); + PROCEDURES.set(module.procedures).ok().unwrap(); + // Write the bsatn data into the sink. 
write_to_sink(description, &bytes); } @@ -475,16 +596,10 @@ extern "C" fn __call_reducer__( error: BytesSink, ) -> i16 { // Piece together `sender_i` into an `Identity`. - let sender = [sender_0, sender_1, sender_2, sender_3]; - let sender: [u8; 32] = bytemuck::must_cast(sender); - let sender = Identity::from_byte_array(sender); // The LITTLE-ENDIAN constructor. + let sender = reconstruct_sender_identity(sender_0, sender_1, sender_2, sender_3); // Piece together `conn_id_i` into a `ConnectionId`. - // The all-zeros `ConnectionId` (`ConnectionId::ZERO`) is interpreted as `None`. - let conn_id = [conn_id_0, conn_id_1]; - let conn_id: [u8; 16] = bytemuck::must_cast(conn_id); - let conn_id = ConnectionId::from_le_byte_array(conn_id); // The LITTLE-ENDIAN constructor. - let conn_id = (conn_id != ConnectionId::ZERO).then_some(conn_id); + let conn_id = reconstruct_connection_id(conn_id_0, conn_id_1); // Assemble the `ReducerContext`. let timestamp = Timestamp::from_micros_since_unix_epoch(timestamp as i64); @@ -501,15 +616,117 @@ extern "C" fn __call_reducer__( // Dispatch to it with the arguments read. let res = with_read_args(args, |args| reducers[id](ctx, args)); // Convert any error message to an error code and writes to the `error` sink. + convert_err_to_errno(res, error) +} + +/// Reconstruct the `sender_i` args to [`__call_reducer__`] and [`__call_procedure__`] into an [`Identity`]. +fn reconstruct_sender_identity(sender_0: u64, sender_1: u64, sender_2: u64, sender_3: u64) -> Identity { + let sender = [sender_0, sender_1, sender_2, sender_3]; + let sender: [u8; 32] = bytemuck::must_cast(sender); + let sender = Identity::from_byte_array(sender); // The LITTLE-ENDIAN constructor. + sender +} + +/// Reconstruct the `conn_id_i` args to [`__call_reducer__`] and [`__call_procedure__`] into a [`ConnectionId`]. +/// +/// The all-zeros `ConnectionId` (`ConnectionId::ZERO`) is interpreted as `None`. 
+fn reconstruct_connection_id(conn_id_0: u64, conn_id_1: u64) -> Option { + // Piece together `conn_id_i` into a `ConnectionId`. + // The all-zeros `ConnectionId` (`ConnectionId::ZERO`) is interpreted as `None`. + let conn_id = [conn_id_0, conn_id_1]; + let conn_id: [u8; 16] = bytemuck::must_cast(conn_id); + let conn_id = ConnectionId::from_le_byte_array(conn_id); // The LITTLE-ENDIAN constructor. + let conn_id = (conn_id != ConnectionId::ZERO).then_some(conn_id); + conn_id +} + +/// If `res` is `Err`, write the message to `out` and return non-zero. +/// If `res` is `Ok`, return zero. +/// +/// Called by [`__call_reducer__`] +/// to convert the user-returned `Result` into a low-level errno return. +fn convert_err_to_errno(res: Result<(), Box>, out: BytesSink) -> i16 { match res { Ok(()) => 0, Err(msg) => { - write_to_sink(error, msg.as_bytes()); + write_to_sink(out, msg.as_bytes()); errno::HOST_CALL_FAILURE.get() as i16 } } } +/// Called by the host to execute a procedure +/// when the `sender` calls the procedure identified by `id` at `timestamp` with `args`. +/// +/// The `sender_{0-3}` are the pieces of a `[u8; 32]` (`u256`) representing the sender's `Identity`. +/// They are encoded as follows (assuming `identity.to_byte_array(): [u8; 32]`): +/// - `sender_0` contains bytes `[0 ..8 ]`. +/// - `sender_1` contains bytes `[8 ..16]`. +/// - `sender_2` contains bytes `[16..24]`. +/// - `sender_3` contains bytes `[24..32]`. +/// +/// Note that `to_byte_array` uses LITTLE-ENDIAN order! This matches most host systems. +/// +/// The `conn_id_{0-1}` are the pieces of a `[u8; 16]` (`u128`) representing the caller's [`ConnectionId`]. +/// They are encoded as follows (assuming `conn_id.as_le_byte_array(): [u8; 16]`): +/// - `conn_id_0` contains bytes `[0 ..8 ]`. +/// - `conn_id_1` contains bytes `[8 ..16]`. +/// +/// Again, note that `to_byte_array` uses LITTLE-ENDIAN order! This matches most host systems.
+/// +/// The `args` is a `BytesSource`, registered on the host side, +/// which can be read with `bytes_source_read`. +/// The contents of the buffer are the BSATN-encoding of the arguments to the procedure. +/// In the case of empty arguments, `args` will be 0, that is, invalid. +/// +/// The `result_sink` is a `BytesSink`, registered on the host side, +/// which can be written to with `bytes_sink_write`. +/// Procedures are expected to always write to this sink +/// the BSATN-serialized bytes of a value of the procedure's return type. +/// +/// Procedures always return the error 0. All other return values are reserved. +#[no_mangle] +extern "C" fn __call_procedure__( + id: usize, + sender_0: u64, + sender_1: u64, + sender_2: u64, + sender_3: u64, + conn_id_0: u64, + conn_id_1: u64, + timestamp: u64, + args: BytesSource, + result_sink: BytesSink, +) -> i16 { + // Piece together `sender_i` into an `Identity`. + let sender = reconstruct_sender_identity(sender_0, sender_1, sender_2, sender_3); + + // Piece together `conn_id_i` into a `ConnectionId`. + let conn_id = reconstruct_connection_id(conn_id_0, conn_id_1); + + let timestamp = Timestamp::from_micros_since_unix_epoch(timestamp as i64); + + // Assemble the `ProcedureContext`. + let ctx = ProcedureContext { + connection_id: conn_id, + sender, + timestamp, + rng: std::cell::OnceCell::new(), + }; + + // Grab the list of procedures, which is populated by the preinit functions. + let procedures = PROCEDURES.get().unwrap(); + + // Deserialize the args and pass them to the actual procedure. + let res = with_read_args(args, |args| procedures[id](ctx, args)); + + // Write the result bytes to the `result_sink`. + write_to_sink(result_sink, &res); + + // Return 0 for no error. Procedures always either trap or return 0. + 0 +} + /// Run `logic` with `args` read from the host into a `&[u8]`.
fn with_read_args(args: BytesSource, logic: impl FnOnce(&[u8]) -> R) -> R { if args == BytesSource::INVALID { diff --git a/crates/bindings/src/table.rs b/crates/bindings/src/table.rs index fa090eb8930..fcb521736ee 100644 --- a/crates/bindings/src/table.rs +++ b/crates/bindings/src/table.rs @@ -149,7 +149,7 @@ pub enum IndexAlgo<'a> { } pub struct ScheduleDesc<'a> { - pub reducer_name: &'a str, + pub reducer_or_procedure_name: &'a str, pub scheduled_at_column: u16, } diff --git a/crates/client-api-messages/src/websocket.rs b/crates/client-api-messages/src/websocket.rs index cf24e6e5e87..699edabc939 100644 --- a/crates/client-api-messages/src/websocket.rs +++ b/crates/client-api-messages/src/websocket.rs @@ -105,6 +105,8 @@ pub enum ClientMessage { /// Remove a subscription to a SQL query that was added with SubscribeSingle. Unsubscribe(Unsubscribe), UnsubscribeMulti(UnsubscribeMulti), + /// Request a procedure run. + CallProcedure(CallProcedure), } impl ClientMessage { @@ -127,6 +129,17 @@ impl ClientMessage { ClientMessage::Subscribe(x) => ClientMessage::Subscribe(x), ClientMessage::SubscribeMulti(x) => ClientMessage::SubscribeMulti(x), ClientMessage::UnsubscribeMulti(x) => ClientMessage::UnsubscribeMulti(x), + ClientMessage::CallProcedure(CallProcedure { + procedure, + args, + request_id, + flags, + }) => ClientMessage::CallProcedure(CallProcedure { + procedure, + args: f(args), + request_id, + flags, + }), } } } @@ -292,6 +305,37 @@ pub struct OneOffQuery { pub query_string: Box, } +#[derive(SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct CallProcedure { + /// The name of the procedure to call. + pub procedure: Box, + /// The arguments to the procedure. + /// + /// In the wire format, this will be a [`Bytes`], BSATN or JSON encoded according to the procedure's argument schema + /// and the enclosing message format. + pub args: Args, + /// An identifier for a client request. 
+ /// + /// The server will include the same ID in the response [`ProcedureResult`]. + pub request_id: u32, + /// Reserved space for future extensions. + pub flags: CallProcedureFlags, +} + +#[derive(Clone, Copy, Default, PartialEq, Eq)] +pub enum CallProcedureFlags { + #[default] + Default, +} + +impl_st!([] CallProcedureFlags, AlgebraicType::U8); +impl_serialize!([] CallProcedureFlags, (self, ser) => ser.serialize_u8(*self as u8)); +impl_deserialize!([] CallProcedureFlags, de => match de.deserialize_u8()? { + 0 => Ok(Self::Default), + x => Err(D::Error::custom(format_args!("invalid call procedure flag {x}"))), +}); + /// The tag recognized by the host and SDKs to mean no compression of a [`ServerMessage`]. pub const SERVER_MSG_COMPRESSION_TAG_NONE: u8 = 0; @@ -326,6 +370,8 @@ pub enum ServerMessage { SubscribeMultiApplied(SubscribeMultiApplied), /// Sent in response to an `UnsubscribeMulti` message. This contains the matching rows. UnsubscribeMultiApplied(UnsubscribeMultiApplied), + /// Sent in response to a `CallProcedure` message. This contains the return value. + ProcedureResult(ProcedureResult), } /// The matching rows of a subscription query. @@ -705,6 +751,44 @@ pub struct OneOffTable { pub rows: F::List, } +/// Received by client from database in response to a [`CallProcedure`] +/// after the procedure finished running. +#[derive(SpacetimeType, Debug)] +#[sats(crate = spacetimedb_lib)] +pub struct ProcedureResult { + /// The status of the procedure run. + /// + /// Contains the return value if successful, or the error message if not. + pub status: ProcedureStatus, + /// The time when the procedure started. + /// + /// Note that [`Timestamp`] serializes as `i64` nanoseconds since the Unix epoch. + pub timestamp: Timestamp, + /// The time the procedure took to run. + pub total_host_execution_duration: TimeDuration, + /// The same client-provided identifier as in the original [`CallProcedure`] request. 
+ /// + /// Clients use this to correlate the response with the original request. + pub request_id: u32, +} + +#[derive(SpacetimeType, Debug)] +#[sats(crate = spacetimedb_lib)] +pub enum ProcedureStatus { + /// The procedure ran and returned the enclosed value. + /// + /// All user error handling happens within here; + /// the returned value may be a `Result` or `Option`, + /// or any other type to which the user may ascribe arbitrary meaning. + Returned(F::Single), + /// The procedure was interrupted due to insufficient energy/funds. + /// + /// The procedure may have performed some observable side effects before being interrupted. + OutOfEnergy, + /// The call failed in the host, e.g. due to a type error or unknown procedure name. + InternalError(String), +} + /// Used whenever different formats need to coexist. #[derive(Debug, Clone)] pub enum FormatSwitch { diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index ea056f57b2d..a49ba5dbb5f 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -19,11 +19,10 @@ use futures::StreamExt; use http::StatusCode; use serde::Deserialize; use spacetimedb::database_logger::DatabaseLogger; -use spacetimedb::host::module_host::ClientConnectedError; -use spacetimedb::host::ReducerCallError; -use spacetimedb::host::ReducerOutcome; -use spacetimedb::host::UpdateDatabaseResult; -use spacetimedb::host::{MigratePlanResult, ReducerArgs}; +use spacetimedb::host::{ + ClientConnectedError, MigratePlanResult, ProcedureCallError, ReducerArgs, ReducerCallError, ReducerOutcome, + UpdateDatabaseResult, +}; use spacetimedb::identity::Identity; use spacetimedb::messages::control_db::{Database, HostType}; use spacetimedb_client_api_messages::name::{ @@ -31,7 +30,7 @@ use spacetimedb_client_api_messages::name::{ }; use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9; use spacetimedb_lib::identity::AuthCtx; -use spacetimedb_lib::{sats, ProductValue, 
Timestamp}; +use spacetimedb_lib::sats::{self, timestamp::Timestamp, AlgebraicValue, ProductValue}; use spacetimedb_schema::auto_migrate::{ MigrationPolicy as SchemaMigrationPolicy, MigrationToken, PrettyPrintStyle as AutoMigratePrettyPrintStyle, }; @@ -70,7 +69,7 @@ pub async fn call( log::error!("Could not find database: {}", db_identity.to_hex()); NO_SUCH_DATABASE })?; - let identity = database.owner_identity; + let owner_identity = database.owner_identity; let leader = worker_ctx .leader(database.id) @@ -83,35 +82,13 @@ pub async fn call( // so generate one. let connection_id = generate_random_connection_id(); - match module.call_identity_connected(auth.into(), connection_id).await { - // If `call_identity_connected` returns `Err(Rejected)`, then the `client_connected` reducer errored, - // meaning the connection was refused. Return 403 forbidden. - Err(ClientConnectedError::Rejected(msg)) => return Err((StatusCode::FORBIDDEN, msg).into()), - // If `call_identity_connected` returns `Err(OutOfEnergy)`, - // then, well, the database is out of energy. - // Return 503 service unavailable. - Err(err @ ClientConnectedError::OutOfEnergy) => { - return Err((StatusCode::SERVICE_UNAVAILABLE, err.to_string()).into()) - } - // If `call_identity_connected` returns `Err(ReducerCall)`, - // something went wrong while invoking the `client_connected` reducer. - // I (pgoldman 2025-03-27) am not really sure how this would happen, - // but we returned 404 not found in this case prior to my editing this code, - // so I guess let's keep doing that. - Err(ClientConnectedError::ReducerCall(e)) => { - return Err((StatusCode::NOT_FOUND, format!("{:#}", anyhow::anyhow!(e))).into()) - } - // If `call_identity_connected` returns `Err(DBError)`, - // then the module didn't define `client_connected`, - // but something went wrong when we tried to insert into `st_client`. - // That's weird and scary, so return 500 internal error. 
- Err(e @ ClientConnectedError::DBError(_)) => { - return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into()) - } + // Call the database's `client_connected` reducer, if any. + // If it fails or rejects the connection, bail. + module + .call_identity_connected(auth.into(), connection_id) + .await + .map_err(client_connected_error_to_response)?; - // If `call_identity_connected` returns `Ok`, then we can actually call the reducer we want. - Ok(()) => (), - } let result = match module .call_reducer(caller_identity, Some(connection_id), None, None, None, &reducer, args) .await @@ -139,17 +116,14 @@ pub async fn call( } }; - if let Err(e) = module.call_identity_disconnected(caller_identity, connection_id).await { - // If `call_identity_disconnected` errors, something is very wrong: - // it means we tried to delete the `st_client` row but failed. - // Note that `call_identity_disconnected` swallows errors from the `client_disconnected` reducer. - // Slap a 500 on it and pray. - return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("{:#}", anyhow::anyhow!(e))).into()); - } + module + .call_identity_disconnected(caller_identity, connection_id) + .await + .map_err(client_disconnected_error_to_response)?; match result { Ok(result) => { - let (status, body) = reducer_outcome_response(&identity, &reducer, result.outcome); + let (status, body) = reducer_outcome_response(&owner_identity, &reducer, result.outcome); Ok(( status, TypedHeader(SpacetimeEnergyUsed(result.energy_used)), @@ -161,7 +135,7 @@ pub async fn call( } } -fn reducer_outcome_response(identity: &Identity, reducer: &str, outcome: ReducerOutcome) -> (StatusCode, String) { +fn reducer_outcome_response(owner_identity: &Identity, reducer: &str, outcome: ReducerOutcome) -> (StatusCode, String) { match outcome { ReducerOutcome::Committed => (StatusCode::OK, "".to_owned()), ReducerOutcome::Failed(errmsg) => { @@ -169,7 +143,7 @@ fn reducer_outcome_response(identity: &Identity, reducer: &str, outcome: Reducer 
(StatusCode::from_u16(530).unwrap(), errmsg) } ReducerOutcome::BudgetExceeded => { - log::warn!("Node's energy budget exceeded for identity: {identity} while executing {reducer}"); + log::warn!("Node's energy budget exceeded for identity: {owner_identity} while executing {reducer}"); ( StatusCode::PAYMENT_REQUIRED, "Module energy budget exhausted.".to_owned(), @@ -178,6 +152,38 @@ fn reducer_outcome_response(identity: &Identity, reducer: &str, outcome: Reducer } } +fn client_connected_error_to_response(err: ClientConnectedError) -> ErrorResponse { + match err { + // If `call_identity_connected` returns `Err(Rejected)`, then the `client_connected` reducer errored, + // meaning the connection was refused. Return 403 forbidden. + ClientConnectedError::Rejected(msg) => (StatusCode::FORBIDDEN, msg).into(), + // If `call_identity_connected` returns `Err(OutOfEnergy)`, + // then, well, the database is out of energy. + // Return 503 service unavailable. + ClientConnectedError::OutOfEnergy => (StatusCode::SERVICE_UNAVAILABLE, err.to_string()).into(), + // If `call_identity_connected` returns `Err(ReducerCall)`, + // something went wrong while invoking the `client_connected` reducer. + // I (pgoldman 2025-03-27) am not really sure how this would happen, + // but we returned 404 not found in this case prior to my editing this code, + // so I guess let's keep doing that. + ClientConnectedError::ReducerCall(e) => (StatusCode::NOT_FOUND, format!("{:#}", anyhow::anyhow!(e))).into(), + // If `call_identity_connected` returns `Err(DBError)`, + // then the module didn't define `client_connected`, + // but something went wrong when we tried to insert into `st_client`. + // That's weird and scary, so return 500 internal error. + ClientConnectedError::DBError(_) => (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into(), + } +} + +/// If `call_identity_disconnected` errors, something is very wrong: +/// it means we tried to delete the `st_client` row but failed. 
+/// +/// Note that `call_identity_disconnected` swallows errors from the `client_disconnected` reducer. +/// Slap a 500 on it and pray. +fn client_disconnected_error_to_response(err: ReducerCallError) -> ErrorResponse { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{:#}", anyhow::anyhow!(err))).into() +} + #[derive(Debug, derive_more::From)] pub enum DBCallErr { HandlerError(ErrorResponse), @@ -185,6 +191,107 @@ pub enum DBCallErr { InstanceNotScheduled, } +#[derive(Deserialize)] +pub struct ProcedureParams { + name_or_identity: NameOrIdentity, + procedure: String, +} + +async fn procedure( + State(worker_ctx): State, + Extension(auth): Extension, + Path(ProcedureParams { + name_or_identity, + procedure, + }): Path, + TypedHeader(content_type): TypedHeader, + ByteStringBody(body): ByteStringBody, +) -> axum::response::Result { + if content_type != headers::ContentType::json() { + return Err(axum::extract::rejection::MissingJsonContentType::default().into()); + } + let caller_identity = auth.claims.identity; + + let args = ReducerArgs::Json(body); + + let db_identity = name_or_identity.resolve(&worker_ctx).await?; + let database = worker_ctx_find_database(&worker_ctx, &db_identity) + .await? + .ok_or_else(|| { + log::error!("Could not find database: {}", db_identity.to_hex()); + NO_SUCH_DATABASE + })?; + + let leader = worker_ctx + .leader(database.id) + .await + .map_err(log_and_500)? + .ok_or(StatusCode::NOT_FOUND)?; + let module = leader.module().await.map_err(log_and_500)?; + + // HTTP callers always need a connection ID to provide to connect/disconnect, + // so generate one. + let connection_id = generate_random_connection_id(); + + // Call the database's `client_connected` reducer, if any. + // If it fails or rejects the connection, bail. 
+ module + .call_identity_connected(auth.into(), connection_id) + .await + .map_err(client_connected_error_to_response)?; + + let result = match module + .call_procedure(caller_identity, Some(connection_id), None, &procedure, args) + .await + { + Ok(res) => Ok(res), + Err(e) => { + let status_code = match e { + ProcedureCallError::Args(_) => { + log::debug!("Attempt to call reducer with invalid arguments"); + StatusCode::BAD_REQUEST + } + ProcedureCallError::NoSuchModule(_) => StatusCode::NOT_FOUND, + ProcedureCallError::NoSuchProcedure => { + log::debug!("Attempt to call non-existent procedure {procedure}"); + StatusCode::NOT_FOUND + } + ProcedureCallError::OutOfEnergy => StatusCode::PAYMENT_REQUIRED, + ProcedureCallError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, + }; + log::error!("Error while invoking procedure {e:#}"); + Err((status_code, format!("{:#}", anyhow::anyhow!(e)))) + } + }; + + module + .call_identity_disconnected(caller_identity, connection_id) + .await + .map_err(client_disconnected_error_to_response)?; + + match result { + Ok(result) => { + // Procedures don't assign a special meaning to error returns, unlike reducers, + // as there's no transaction for them to automatically abort. + // Instead, we just pass on their return value with the OK status so long as we successfully invoked the procedure. 
+ let (status, body) = procedure_outcome_response(result.return_val); + Ok(( + status, + TypedHeader(SpacetimeExecutionDurationMicros(result.execution_duration)), + body, + )) + } + Err(e) => Err((e.0, e.1).into()), + } +} + +fn procedure_outcome_response(return_val: AlgebraicValue) -> (StatusCode, axum::response::Response) { + ( + StatusCode::OK, + axum::Json(sats::serde::SerdeWrapper(return_val)).into_response(), + ) +} + #[derive(Deserialize)] pub struct SchemaParams { name_or_identity: NameOrIdentity, @@ -950,6 +1057,8 @@ pub struct DatabaseRoutes { pub subscribe_get: MethodRouter, /// POST: /database/:name_or_identity/call/:reducer pub call_reducer_post: MethodRouter, + /// POST: /database/:name_or_identity/procedure/:procedure + pub call_procedure_post: MethodRouter, /// GET: /database/:name_or_identity/schema pub schema_get: MethodRouter, /// GET: /database/:name_or_identity/logs @@ -979,6 +1088,7 @@ where identity_get: get(get_identity::), subscribe_get: get(handle_websocket::), call_reducer_post: post(call::), + call_procedure_post: post(procedure::), schema_get: get(schema::), logs_get: get(logs::), sql_post: post(sql::), @@ -1003,6 +1113,7 @@ where .route("/identity", self.identity_get) .route("/subscribe", self.subscribe_get) .route("/call/:reducer", self.call_reducer_post) + .route("/procedure/:procedure", self.call_procedure_post) .route("/schema", self.schema_get) .route("/logs", self.logs_get) .route("/sql", self.sql_post) diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index 85d43482d16..b9558e37460 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -9,11 +9,13 @@ use std::time::{Instant, SystemTime}; use super::messages::{OneOffQueryResponseMessage, SerializableMessage}; use super::{message_handlers, ClientActorId, MessageHandleError}; +use crate::client::messages::ProcedureResultMessage; use crate::db::relational_db::RelationalDB; use 
crate::error::DBError; use crate::host::module_host::ClientConnectedError; use crate::host::{ModuleHost, NoSuchModule, ReducerArgs, ReducerCallError, ReducerCallResult}; use crate::messages::websocket::Subscribe; +use crate::subscription::module_subscription_manager::BroadcastError; use crate::util::asyncify; use crate::util::prometheus_handle::IntGaugeExt; use crate::worker_metrics::WORKER_METRICS; @@ -834,6 +836,29 @@ impl ClientConnection { .await } + pub async fn call_procedure( + &self, + procedure: &str, + args: ReducerArgs, + request_id: RequestId, + timer: Instant, + ) -> Result<(), BroadcastError> { + let res = self + .module() + .call_procedure( + self.id.identity, + Some(self.id.connection_id), + Some(timer), + procedure, + args, + ) + .await; + + self.module() + .subscriptions() + .send_procedure_message(self.sender(), ProcedureResultMessage::from_result(&res, request_id)) + } + pub async fn subscribe_single( &self, subscription: SubscribeSingle, diff --git a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index e2077d948e4..a272ee9f586 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -2,14 +2,16 @@ use super::messages::{SubscriptionUpdateMessage, SwitchedServerMessage, ToProtoc use super::{ClientConnection, DataMessage, Protocol}; use crate::energy::EnergyQuanta; use crate::host::module_host::{EventStatus, ModuleEvent, ModuleFunctionCall}; -use crate::host::{ReducerArgs, ReducerId}; +use crate::host::ReducerArgs; use crate::identity::Identity; use crate::messages::websocket::{CallReducer, ClientMessage, OneOffQuery}; use crate::worker_metrics::WORKER_METRICS; +use spacetimedb_client_api_messages::websocket::CallProcedure; use spacetimedb_datastore::execution_context::WorkloadType; use spacetimedb_lib::de::serde::DeserializeWrapper; use spacetimedb_lib::identity::RequestId; use spacetimedb_lib::{bsatn, ConnectionId, Timestamp}; +use 
spacetimedb_primitives::ReducerId; use std::borrow::Cow; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -129,9 +131,27 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst .observe(timer.elapsed().as_secs_f64()); res.map_err(|err| (None, None, err)) } + ClientMessage::CallProcedure(CallProcedure { + ref procedure, + args, + request_id, + flags: _, + }) => { + let res = client.call_procedure(procedure, args, request_id, timer).await; + WORKER_METRICS + .request_round_trip + .with_label_values(&WorkloadType::Procedure, &database_identity, procedure) + .observe(timer.elapsed().as_secs_f64()); + if let Err(e) = res { + log::warn!("Procedure call failed: {e:#}"); + } + // `ClientConnection::call_procedure` handles sending the error message to the client if the call fails, + // so we don't need to return an `Err` here. + Ok(()) + } }; - res.map_err(|(reducer, reducer_id, err)| MessageExecutionError { - reducer: reducer.cloned(), + res.map_err(|(reducer_name, reducer_id, err)| MessageExecutionError { + reducer: reducer_name.cloned(), reducer_id, caller_identity: client.id.identity, caller_connection_id: Some(client.id.connection_id), diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index 67f3b90397b..bd5cbdfcdad 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -1,6 +1,6 @@ use super::{ClientConfig, DataMessage, Protocol}; use crate::host::module_host::{EventStatus, ModuleEvent}; -use crate::host::ArgsTuple; +use crate::host::{ArgsTuple, ProcedureCallError, ProcedureCallResult}; use crate::messages::websocket as ws; use crate::subscription::websocket_building::{brotli_compress, decide_compression, gzip_compress}; use bytes::{BufMut, Bytes, BytesMut}; @@ -13,7 +13,7 @@ use spacetimedb_client_api_messages::websocket::{ use spacetimedb_datastore::execution_context::WorkloadType; use spacetimedb_lib::identity::RequestId; use 
spacetimedb_lib::ser::serde::SerializeWrapper; -use spacetimedb_lib::{ConnectionId, TimeDuration}; +use spacetimedb_lib::{AlgebraicValue, ConnectionId, TimeDuration, Timestamp}; use spacetimedb_primitives::TableId; use spacetimedb_sats::bsatn; use std::sync::Arc; @@ -167,6 +167,7 @@ pub enum SerializableMessage { Subscribe(SubscriptionUpdateMessage), Subscription(SubscriptionMessage), TxUpdate(TransactionUpdateMessage), + ProcedureResult(ProcedureResultMessage), } impl SerializableMessage { @@ -177,7 +178,7 @@ impl SerializableMessage { Self::Subscribe(msg) => Some(msg.num_rows()), Self::Subscription(msg) => Some(msg.num_rows()), Self::TxUpdate(msg) => Some(msg.num_rows()), - Self::Identity(_) => None, + Self::Identity(_) | Self::ProcedureResult(_) => None, } } @@ -194,6 +195,7 @@ impl SerializableMessage { }, Self::TxUpdate(_) => Some(WorkloadType::Update), Self::Identity(_) => None, + Self::ProcedureResult(_) => Some(WorkloadType::Procedure), } } } @@ -208,6 +210,7 @@ impl ToProtocol for SerializableMessage { SerializableMessage::Subscribe(msg) => msg.to_protocol(protocol), SerializableMessage::TxUpdate(msg) => msg.to_protocol(protocol), SerializableMessage::Subscription(msg) => msg.to_protocol(protocol), + SerializableMessage::ProcedureResult(msg) => msg.to_protocol(protocol), } } } @@ -584,3 +587,87 @@ fn convert(msg: OneOffQueryResponseMessage) -> ws::Server total_host_execution_duration: msg.total_host_execution_duration, }) } + +#[derive(Debug)] +pub enum ProcedureStatus { + Returned(AlgebraicValue), + OutOfEnergy, + InternalError(String), +} + +#[derive(Debug)] +pub struct ProcedureResultMessage { + status: ProcedureStatus, + timestamp: Timestamp, + total_host_execution_duration: TimeDuration, + request_id: u32, +} + +impl ProcedureResultMessage { + pub fn from_result(res: &Result, request_id: RequestId) -> Self { + let (status, timestamp, execution_duration) = match res { + Ok(ProcedureCallResult { + return_val, + execution_duration, + start_timestamp, + 
}) => ( + ProcedureStatus::Returned(return_val.clone()), + *start_timestamp, + TimeDuration::from(*execution_duration), + ), + Err(err) => ( + match err { + ProcedureCallError::OutOfEnergy => ProcedureStatus::OutOfEnergy, + _ => ProcedureStatus::InternalError(format!("{err}")), + }, + Timestamp::UNIX_EPOCH, + TimeDuration::ZERO, + ), + }; + + ProcedureResultMessage { + status, + timestamp, + total_host_execution_duration: execution_duration, + request_id, + } + } +} + +impl ToProtocol for ProcedureResultMessage { + type Encoded = SwitchedServerMessage; + + fn to_protocol(self, protocol: Protocol) -> Self::Encoded { + fn convert( + msg: ProcedureResultMessage, + serialize_value: impl Fn(AlgebraicValue) -> F::Single, + ) -> ws::ServerMessage { + let ProcedureResultMessage { + status, + timestamp, + total_host_execution_duration, + request_id, + } = msg; + let status = match status { + ProcedureStatus::InternalError(msg) => ws::ProcedureStatus::InternalError(msg), + ProcedureStatus::OutOfEnergy => ws::ProcedureStatus::OutOfEnergy, + ProcedureStatus::Returned(val) => ws::ProcedureStatus::Returned(serialize_value(val)), + }; + ws::ServerMessage::ProcedureResult(ws::ProcedureResult { + status, + timestamp, + total_host_execution_duration, + request_id, + }) + } + + // Note that procedure returns are sent only to the caller, not broadcast to all subscribers, + // so we don't have to bother with memoizing the serialization the way we do for reducer args. 
+ match protocol { + Protocol::Binary => FormatSwitch::Bsatn(convert(self, |val| bsatn::to_vec(&val).unwrap().into())), + Protocol::Text => FormatSwitch::Json(convert(self, |val| { + serde_json::to_string(&SerializeWrapper(val)).unwrap().into() + })), + } + } +} diff --git a/crates/core/src/host/host_controller.rs b/crates/core/src/host/host_controller.rs index f508bbbc5be..7decab10802 100644 --- a/crates/core/src/host/host_controller.rs +++ b/crates/core/src/host/host_controller.rs @@ -29,7 +29,7 @@ use spacetimedb_datastore::db_metrics::data_size::DATA_SIZE_METRICS; use spacetimedb_datastore::db_metrics::DB_METRICS; use spacetimedb_datastore::traits::Program; use spacetimedb_durability::{self as durability}; -use spacetimedb_lib::{hash_bytes, Identity}; +use spacetimedb_lib::{hash_bytes, AlgebraicValue, Identity, Timestamp}; use spacetimedb_paths::server::{ReplicaDir, ServerDataDir}; use spacetimedb_paths::FromPathUnchecked; use spacetimedb_sats::hash::Hash; @@ -170,6 +170,13 @@ impl From<&EventStatus> for ReducerOutcome { } } +#[derive(Clone, Debug)] +pub struct ProcedureCallResult { + pub return_val: AlgebraicValue, + pub execution_duration: Duration, + pub start_timestamp: Timestamp, +} + impl HostController { pub fn new( data_dir: Arc, diff --git a/crates/core/src/host/instance_env.rs b/crates/core/src/host/instance_env.rs index bcc552003ba..38406037c8b 100644 --- a/crates/core/src/host/instance_env.rs +++ b/crates/core/src/host/instance_env.rs @@ -178,7 +178,7 @@ impl InstanceEnv { } /// Signal to this `InstanceEnv` that a reducer call is beginning. 
- pub fn start_reducer(&mut self, ts: Timestamp) { + pub fn start_funcall(&mut self, ts: Timestamp) { self.start_time = ts; } diff --git a/crates/core/src/host/mod.rs b/crates/core/src/host/mod.rs index 05538542bbe..66d3570acc8 100644 --- a/crates/core/src/host/mod.rs +++ b/crates/core/src/host/mod.rs @@ -6,9 +6,8 @@ use enum_map::Enum; use once_cell::sync::OnceCell; use spacetimedb_lib::bsatn; use spacetimedb_lib::de::serde::SeedWrapper; -use spacetimedb_lib::de::DeserializeSeed; use spacetimedb_lib::ProductValue; -use spacetimedb_schema::def::deserialize::ReducerArgsDeserializeSeed; +use spacetimedb_schema::def::deserialize::{ArgsSeed, ProcedureArgsDeserializeSeed, ReducerArgsDeserializeSeed}; mod disk_storage; mod host_controller; @@ -25,12 +24,17 @@ mod wasm_common; pub use disk_storage::DiskStorage; pub use host_controller::{ - extract_schema, ExternalDurability, ExternalStorage, HostController, MigratePlanResult, ProgramStorage, - ReducerCallResult, ReducerOutcome, + extract_schema, ExternalDurability, ExternalStorage, HostController, MigratePlanResult, ProcedureCallResult, + ProgramStorage, ReducerCallResult, ReducerOutcome, +}; +pub use module_host::{ + ClientConnectedError, ModuleHost, NoSuchModule, ProcedureCallError, ReducerCallError, UpdateDatabaseResult, }; -pub use module_host::{ModuleHost, NoSuchModule, ReducerCallError, UpdateDatabaseResult}; pub use scheduler::Scheduler; +/// Encoded arguments to a database function. +/// +/// Despite the name, this may be arguments to either a reducer or a procedure. 
#[derive(Debug)] pub enum ReducerArgs { Json(ByteString), @@ -39,13 +43,22 @@ pub enum ReducerArgs { } impl ReducerArgs { + fn into_tuple_for_procedure( + self, + seed: ProcedureArgsDeserializeSeed, + ) -> Result { + self._into_tuple(seed).map_err(|err| InvalidProcedureArguments { + err, + procedure: (*seed.inner_def().name).into(), + }) + } fn into_tuple(self, seed: ReducerArgsDeserializeSeed) -> Result { self._into_tuple(seed).map_err(|err| InvalidReducerArguments { err, - reducer: (*seed.reducer_def().name).into(), + reducer: (*seed.inner_def().name).into(), }) } - fn _into_tuple(self, seed: ReducerArgsDeserializeSeed) -> anyhow::Result { + fn _into_tuple(self, seed: impl ArgsSeed) -> anyhow::Result { Ok(match self { ReducerArgs::Json(json) => ArgsTuple { tuple: from_json_seed(&json, SeedWrapper(seed))?, @@ -58,10 +71,7 @@ impl ReducerArgs { json: OnceCell::new(), }, ReducerArgs::Nullary => { - anyhow::ensure!( - seed.reducer_def().params.elements.is_empty(), - "failed to typecheck args" - ); + anyhow::ensure!(seed.params().elements.is_empty(), "failed to typecheck args"); ArgsTuple::nullary() } }) @@ -114,6 +124,14 @@ pub struct InvalidReducerArguments { reducer: Box, } +#[derive(thiserror::Error, Debug)] +#[error("invalid arguments for procedure {procedure}: {err}")] +pub struct InvalidProcedureArguments { + #[source] + err: anyhow::Error, + procedure: Box, +} + fn from_json_seed<'de, T: serde::de::DeserializeSeed<'de>>(s: &'de str, seed: T) -> anyhow::Result { let mut de = serde_json::Deserializer::from_str(s); let mut track = serde_path_to_error::Track::new(); @@ -147,4 +165,6 @@ pub enum AbiCall { Identity, VolatileNonatomicScheduleImmediate, + + ProcedureSleepUntil, } diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 8d03f012f05..cbce5641769 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -1,4 +1,8 @@ -use super::{ArgsTuple, InvalidReducerArguments, ReducerArgs, 
ReducerCallResult, ReducerId, ReducerOutcome, Scheduler}; +use super::host_controller::ProcedureCallResult; +use super::{ + ArgsTuple, InvalidProcedureArguments, InvalidReducerArguments, ReducerArgs, ReducerCallResult, ReducerId, + ReducerOutcome, Scheduler, +}; use crate::client::messages::{OneOffQueryResponseMessage, SerializableMessage}; use crate::client::{ClientActorId, ClientConnectionSender}; use crate::database_logger::{LogLevel, Record}; @@ -43,16 +47,17 @@ use spacetimedb_lib::identity::{AuthCtx, RequestId}; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::ConnectionId; use spacetimedb_lib::Timestamp; -use spacetimedb_primitives::TableId; +use spacetimedb_primitives::{ProcedureId, TableId}; use spacetimedb_query::compile_subscription; use spacetimedb_sats::ProductValue; use spacetimedb_schema::auto_migrate::{AutoMigrateError, MigrationPolicy}; -use spacetimedb_schema::def::deserialize::ReducerArgsDeserializeSeed; -use spacetimedb_schema::def::{ModuleDef, ReducerDef, TableDef}; +use spacetimedb_schema::def::deserialize::{ProcedureArgsDeserializeSeed, ReducerArgsDeserializeSeed}; +use spacetimedb_schema::def::{ModuleDef, ProcedureDef, ReducerDef, TableDef}; use spacetimedb_schema::schema::{Schema, TableSchema}; use spacetimedb_vm::relation::RelValue; use std::collections::VecDeque; use std::fmt; +use std::future::Future; use std::sync::atomic::AtomicBool; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; @@ -395,6 +400,13 @@ impl Instance { Instance::Js(inst) => inst.call_reducer(tx, params), } } + + async fn call_procedure(&mut self, params: CallProcedureParams) -> Result { + match self { + Instance::Wasm(inst) => inst.call_procedure(params).await, + Instance::Js(inst) => inst.call_procedure(params).await, + } + } } /// Creates the table for `table_def` in `stdb`. 
@@ -493,6 +505,15 @@ pub struct CallReducerParams { pub args: ArgsTuple, } +pub struct CallProcedureParams { + pub timestamp: Timestamp, + pub caller_identity: Identity, + pub caller_connection_id: ConnectionId, + pub timer: Option, + pub procedure_id: ProcedureId, + pub args: ArgsTuple, +} + /// Holds a [`Module`] and a set of [`Instance`]s from it, /// and allocates the [`Instance`]s to be used for function calls. /// @@ -641,6 +662,20 @@ pub enum ReducerCallError { LifecycleReducer(Lifecycle), } +#[derive(thiserror::Error, Debug)] +pub enum ProcedureCallError { + #[error(transparent)] + Args(#[from] InvalidProcedureArguments), + #[error(transparent)] + NoSuchModule(#[from] NoSuchModule), + #[error("no such procedure")] + NoSuchProcedure, + #[error("Procedure terminated due to insufficient budget")] + OutOfEnergy, + #[error("The WASM instance encountered a fatal error: {0}")] + InternalError(String), +} + #[derive(thiserror::Error, Debug)] pub enum InitDatabaseError { #[error(transparent)] @@ -761,6 +796,35 @@ impl ModuleHost { }) } + async fn call_async_with_instance(&self, label: &str, f: Fun) -> Result + where + Fun: (FnOnce(Instance) -> Fut) + Send + 'static, + Fut: Future + Send + 'static, + R: Send + 'static, + { + self.guard_closed()?; + let timer_guard = self.start_call_timer(label); + + scopeguard::defer_on_unwind!({ + log::warn!("procedure {label} panicked"); + (self.on_panic)(); + }); + + let instance = self.instance_manager.lock().await.get_instance(); + + let (res, instance) = self + .executor + .run_job(async move { + drop(timer_guard); + f(instance).await + }) + .await; + + self.instance_manager.lock().await.return_instance(instance); + + Ok(res) + } + /// Run a function on the JobThread for this module which has access to the module instance. 
async fn call(&self, label: &str, f: F) -> Result where @@ -778,6 +842,7 @@ impl ModuleHost { // If a reducer call panics, we **must** ensure to call `self.on_panic` // so that the module is discarded by the host controller. + // TODO(pgoldman,procedures): Determine if this is still true. scopeguard::defer_on_unwind!({ log::warn!("reducer {label} panicked"); (self.on_panic)(); @@ -1194,6 +1259,79 @@ impl ModuleHost { res } + pub async fn call_procedure( + &self, + caller_identity: Identity, + caller_connection_id: Option, + timer: Option, + procedure_name: &str, + args: ReducerArgs, + ) -> Result { + let res = async { + let (procedure_id, procedure_def) = self + .info + .module_def + .procedure_full(procedure_name) + .ok_or(ProcedureCallError::NoSuchProcedure)?; + self.call_procedure_inner( + caller_identity, + caller_connection_id, + timer, + procedure_id, + procedure_def, + args, + ) + .await + } + .await; + + let log_message = match &res { + Err(ProcedureCallError::NoSuchProcedure) => Some(format!( + "External attempt to call nonexistent procedure \"{procedure_name}\" failed. Have you run `spacetime generate` recently?" + )), + Err(ProcedureCallError::Args(_)) => Some(format!( + "External attempt to call procedure \"{procedure_name}\" failed, invalid arguments.\n\ + This is likely due to a mismatched client schema, have you run `spacetime generate` recently?" 
+ )), + _ => None, + }; + + if let Some(log_message) = log_message { + self.inject_logs(LogLevel::Error, procedure_name, &log_message) + } + + res + } + + async fn call_procedure_inner( + &self, + caller_identity: Identity, + caller_connection_id: Option, + timer: Option, + procedure_id: ProcedureId, + procedure_def: &ProcedureDef, + args: ReducerArgs, + ) -> Result { + let procedure_seed = ProcedureArgsDeserializeSeed(self.info.module_def.typespace().with_type(procedure_def)); + let args = args.into_tuple_for_procedure(procedure_seed)?; + let caller_connection_id = caller_connection_id.unwrap_or(ConnectionId::ZERO); + + self.call_async_with_instance(&procedure_def.name, async move |mut inst| { + let res = inst + .call_procedure(CallProcedureParams { + timestamp: Timestamp::now(), + caller_identity, + caller_connection_id, + timer, + procedure_id, + args, + }) + .await; + (res, inst) + }) + .await? + } + // Scheduled reducers require a different function here to call their reducer // because their reducer arguments are stored in the database and need to be fetched // within the same transaction as the reducer call. 
@@ -1277,11 +1415,11 @@ impl ModuleHost { self.module.scheduler().closed().await; } - pub fn inject_logs(&self, log_level: LogLevel, reducer_name: &str, message: &str) { + pub fn inject_logs(&self, log_level: LogLevel, function_name: &str, message: &str) { self.replica_ctx().logger.write( log_level, &Record { - function: Some(reducer_name), + function: Some(function_name), ..Record::injected(message) }, &(), diff --git a/crates/core/src/host/v8/mod.rs b/crates/core/src/host/v8/mod.rs index 809e3da097f..6585dfa4887 100644 --- a/crates/core/src/host/v8/mod.rs +++ b/crates/core/src/host/v8/mod.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use super::module_common::{build_common_module_from_raw, run_describer, ModuleCommon}; -use super::module_host::{CallReducerParams, Module, ModuleInfo, ModuleRuntime}; +use super::module_host::{CallProcedureParams, CallReducerParams, Module, ModuleInfo, ModuleRuntime}; use super::UpdateDatabaseResult; use crate::host::instance_env::{ChunkPool, InstanceEnv}; use crate::host::wasm_common::instrumentation::CallTimes; @@ -52,7 +52,7 @@ pub struct V8Runtime { impl ModuleRuntime for V8Runtime { fn make_actor(&self, mcc: ModuleCreationContext<'_>) -> anyhow::Result { - V8_RUNTIME_GLOBAL.make_actor(mcc) + V8_RUNTIME_GLOBAL.make_actor(mcc).map(Module::Js) } } @@ -375,6 +375,13 @@ impl JsInstance { (tx, exec_result) }) } + + pub async fn call_procedure( + &mut self, + _params: CallProcedureParams, + ) -> Result { + todo!("JS/TS module procedure support") + } } fn with_script( diff --git a/crates/core/src/host/wasm_common.rs b/crates/core/src/host/wasm_common.rs index 08a492f9a4b..8b91314905e 100644 --- a/crates/core/src/host/wasm_common.rs +++ b/crates/core/src/host/wasm_common.rs @@ -14,6 +14,8 @@ use spacetimedb_table::table::UniqueConstraintViolation; pub const CALL_REDUCER_DUNDER: &str = "__call_reducer__"; +pub const CALL_PROCEDURE_DUNDER: &str = "__call_procedure__"; + pub const DESCRIBE_MODULE_DUNDER: &str = "__describe_module__"; /// 
functions with this prefix run prior to __setup__, initializing global variables and the like @@ -384,8 +386,8 @@ pub struct AbiRuntimeError { } macro_rules! abi_funcs { - ($mac:ident) => { - $mac! { + ($link_sync:ident , $link_async:ident) => { + $link_sync! { "spacetime_10.0"::table_id_from_name, "spacetime_10.0"::datastore_table_row_count, "spacetime_10.0"::datastore_table_scan_bsatn, @@ -411,6 +413,9 @@ macro_rules! abi_funcs { "spacetime_10.1"::bytes_source_remaining_length, } + $link_async! { + "spacetime_10.1"::procedure_sleep_until, + } }; } pub(crate) use abi_funcs; diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index 16d2d585819..300fa2b6b01 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -1,6 +1,10 @@ +use bytes::Bytes; use prometheus::{Histogram, IntCounter, IntGauge}; use spacetimedb_lib::db::raw_def::v9::Lifecycle; +use spacetimedb_lib::de::DeserializeSeed; +use spacetimedb_primitives::ProcedureId; use spacetimedb_schema::auto_migrate::{MigratePlan, MigrationPolicy, MigrationPolicyError}; +use std::future::Future; use std::sync::Arc; use std::time::Duration; use tracing::span::EnteredSpan; @@ -12,9 +16,12 @@ use crate::energy::{EnergyMonitor, ReducerBudget, ReducerFingerprint}; use crate::host::instance_env::InstanceEnv; use crate::host::module_common::{build_common_module_from_raw, ModuleCommon}; use crate::host::module_host::{ - CallReducerParams, DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall, ModuleInfo, + CallProcedureParams, CallReducerParams, DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall, ModuleInfo, +}; +use crate::host::{ + ArgsTuple, ProcedureCallError, ProcedureCallResult, ReducerCallResult, ReducerId, ReducerOutcome, Scheduler, + UpdateDatabaseResult, }; -use crate::host::{ArgsTuple, ReducerCallResult, ReducerId, ReducerOutcome, Scheduler, 
UpdateDatabaseResult}; use crate::identity::Identity; use crate::messages::control_db::HostType; use crate::module_host_context::ModuleCreationContext; @@ -48,6 +55,7 @@ pub trait WasmInstancePre: Send + Sync + 'static { fn instantiate(&self, env: InstanceEnv, func_names: &FuncNames) -> Result; } +#[async_trait::async_trait] pub trait WasmInstance: Send + Sync + 'static { fn extract_descriptions(&mut self) -> Result, DescribeError>; @@ -56,6 +64,8 @@ pub trait WasmInstance: Send + Sync + 'static { fn call_reducer(&mut self, op: ReducerOp<'_>, budget: ReducerBudget) -> ExecuteResult; fn log_traceback(func_type: &str, func: &str, trap: &anyhow::Error); + + async fn call_procedure(&mut self, op: ProcedureOp, budget: ReducerBudget) -> ProcedureExecuteResult; } pub struct EnergyStats { @@ -64,6 +74,11 @@ pub struct EnergyStats { } impl EnergyStats { + pub const ZERO: Self = Self { + budget: ReducerBudget::ZERO, + remaining: ReducerBudget::ZERO, + }; + /// Returns the used energy amount. fn used(&self) -> ReducerBudget { (self.budget.get() - self.remaining.get()).into() @@ -75,6 +90,17 @@ pub struct ExecutionTimings { pub wasm_instance_env_call_times: CallTimes, } +impl ExecutionTimings { + /// Not a `const` because there doesn't seem to be any way to `const` construct an `enum_map::EnumMap`, + /// which `CallTimes` uses. 
+ pub fn zero() -> Self { + Self { + total_duration: Duration::ZERO, + wasm_instance_env_call_times: CallTimes::new(), + } + } +} + pub struct ExecuteResult { pub energy: EnergyStats, pub timings: ExecutionTimings, @@ -82,6 +108,15 @@ pub struct ExecuteResult { pub call_result: Result>, anyhow::Error>, } +pub struct ProcedureExecuteResult { + #[allow(unused)] + pub energy: EnergyStats, + #[allow(unused)] + pub timings: ExecutionTimings, + pub memory_allocation: usize, + pub call_result: Result, +} + pub struct WasmModuleHostActor { module: T::InstancePre, initial_instance: Option>>, @@ -229,6 +264,19 @@ impl WasmModuleInstance { pub fn call_reducer(&mut self, tx: Option, params: CallReducerParams) -> ReducerCallResult { crate::callgrind_flag::invoke_allowing_callgrind(|| self.call_reducer_with_tx(tx, params)) } + + pub async fn call_procedure( + &mut self, + params: CallProcedureParams, + ) -> Result { + self.common + .call_procedure( + params, + |ty, fun, err| T::log_traceback(ty, fun, err), + |op, budget| self.instance.call_procedure(op, budget), + ) + .await + } } impl WasmModuleInstance { @@ -331,6 +379,103 @@ impl InstanceCommon { } } + async fn call_procedure>( + &mut self, + params: CallProcedureParams, + log_traceback: impl FnOnce(&str, &str, &anyhow::Error), + vm_call_procedure: impl FnOnce(ProcedureOp, ReducerBudget) -> F, + ) -> Result { + let CallProcedureParams { + timestamp, + caller_identity, + caller_connection_id, + timer, + procedure_id, + args, + } = params; + + // We've already validated by this point that the procedure exists, + // so it's fine to use the panicking `procedure_by_id`. + let procedure_def = self.info.module_def.procedure_by_id(procedure_id); + let procedure_name: &str = &procedure_def.name; + + // TODO(observability): Add tracing spans, energy, metrics? + // These will require further thinking once we implement procedure suspend/resume, + // and so are not worth doing yet. 
+ + let op = ProcedureOp { + id: procedure_id, + name: procedure_name.into(), + caller_identity, + caller_connection_id, + timestamp, + arg_bytes: args.get_bsatn().clone(), + }; + + let energy_fingerprint = ReducerFingerprint { + module_hash: self.info.module_hash, + module_identity: self.info.owner_identity, + caller_identity, + reducer_name: &procedure_def.name, + }; + + // TODO: replace with call to separate function `procedure_budget`. + let budget = self.energy_monitor.reducer_budget(&energy_fingerprint); + + let result = vm_call_procedure(op, budget).await; + + let ProcedureExecuteResult { + memory_allocation, + call_result, + // TODO: Do something with timing and energy. + .. + } = result; + + if self.allocated_memory != memory_allocation { + self.metric_wasm_memory_bytes.set(memory_allocation as i64); + self.allocated_memory = memory_allocation; + } + + match call_result { + Err(err) => { + log_traceback("procedure", &procedure_def.name, &err); + + WORKER_METRICS + .wasm_instance_errors + .with_label_values( + &caller_identity, + &self.info.module_hash, + &caller_connection_id, + procedure_name, + ) + .inc(); + + self.trapped = true; + + // if energy.remaining.get() == 0 { + // return Err(ProcedureCallError::OutOfEnergy); + // } else + { + return Err(ProcedureCallError::InternalError(format!("{err}"))); + } + } + Ok(return_val) => { + // TODO: deserialize return_val at its appropriate type, which you get out of the procedure def, + // then return it in `Ok`. + let return_type = &procedure_def.return_type; + let seed = spacetimedb_sats::WithTypespace::new(self.info.module_def.typespace(), return_type); + let return_val = seed + .deserialize(bsatn::Deserializer::new(&mut &return_val[..])) + .map_err(|err| ProcedureCallError::InternalError(format!("{err}")))?; + Ok(ProcedureCallResult { + return_val, + execution_duration: timer.map(|timer| timer.elapsed()).unwrap_or_default(), + start_timestamp: timestamp, + }) + } + } + } + /// Execute a reducer. 
/// /// If `Some` [`MutTxId`] is supplied, the reducer is called within the @@ -669,3 +814,13 @@ impl From> for execution_context::ReducerContext { } } } + +#[derive(Clone, Debug)] +pub struct ProcedureOp { + pub id: ProcedureId, + pub name: Box, + pub caller_identity: Identity, + pub caller_connection_id: ConnectionId, + pub timestamp: Timestamp, + pub arg_bytes: Bytes, +} diff --git a/crates/core/src/host/wasmtime/wasm_instance_env.rs b/crates/core/src/host/wasmtime/wasm_instance_env.rs index c41c4737034..ce6c570b659 100644 --- a/crates/core/src/host/wasmtime/wasm_instance_env.rs +++ b/crates/core/src/host/wasmtime/wasm_instance_env.rs @@ -1,5 +1,9 @@ #![allow(clippy::too_many_arguments)] +use std::future::Future; +use std::num::NonZeroU32; +use std::time::Instant; + use super::{Mem, MemView, NullableMemOp, WasmError, WasmPointee, WasmPtr}; use crate::database_logger::{BacktraceFrame, BacktraceProvider, ModuleBacktrace, Record}; use crate::host::instance_env::{ChunkPool, InstanceEnv}; @@ -102,8 +106,8 @@ pub(super) struct WasmInstanceEnv { /// Track time spent in module-defined spans. timing_spans: TimingSpanSet, - /// The point in time the last reducer call started at. - reducer_start: Instant, + /// The point in time the last, or current, reducer or procedure call started at. + funcall_start: Instant, /// Track time spent in all wasm instance env calls (aka syscall time). /// @@ -111,8 +115,7 @@ pub(super) struct WasmInstanceEnv { /// to this tracker. call_times: CallTimes, - /// The last, including current, reducer to be executed by this environment. - reducer_name: String, + funcall_name: String, /// A pool of unused allocated chunks that can be reused. // TODO(Centril): consider using this pool for `console_timer_start` and `bytes_sink_write`. @@ -129,7 +132,7 @@ type RtResult = anyhow::Result; impl WasmInstanceEnv { /// Create a new `WasmEnstanceEnv` from the given `InstanceEnv`. 
pub fn new(instance_env: InstanceEnv) -> Self { - let reducer_start = Instant::now(); + let funcall_start = Instant::now(); Self { instance_env, mem: None, @@ -138,9 +141,9 @@ impl WasmInstanceEnv { standard_bytes_sink: None, iters: Default::default(), timing_spans: Default::default(), - reducer_start, + funcall_start, call_times: CallTimes::new(), - reducer_name: String::from(""), + funcall_name: String::from(""), chunk_pool: <_>::default(), } } @@ -223,44 +226,50 @@ impl WasmInstanceEnv { /// /// Returns the handle used by reducers to read from `args` /// as well as the handle used to write the error message, if any. - pub fn start_reducer(&mut self, name: &str, args: bytes::Bytes, ts: Timestamp) -> (BytesSourceId, u32) { + pub fn start_funcall(&mut self, name: &str, args: bytes::Bytes, ts: Timestamp) -> (BytesSourceId, u32) { + // Create the output sink. + // Reducers which fail will write their error message here. + // Procedures will write their result here. let errors = self.setup_standard_bytes_sink(); let args = self.create_bytes_source(args).unwrap(); - self.reducer_start = Instant::now(); - name.clone_into(&mut self.reducer_name); - self.instance_env.start_reducer(ts); + self.funcall_start = Instant::now(); + name.clone_into(&mut self.funcall_name); + self.instance_env.start_funcall(ts); (args, errors) } - /// Returns the name of the most recent reducer to be run in this environment. - pub fn reducer_name(&self) -> &str { - &self.reducer_name + /// Returns the name of the most recent reducer or procedure to be run in this environment. + pub fn funcall_name(&self) -> &str { + &self.funcall_name + } + + /// Returns the start time of the current or most recent reducer or procedure to be run in this environment. + pub fn funcall_start(&self) -> Instant { + self.funcall_start } /// Returns the name of the most recent reducer to be run in this environment, /// or `None` if no reducer is actively being invoked. 
fn log_record_function(&self) -> Option<&str> { - let function = self.reducer_name(); + let function = self.funcall_name(); (!function.is_empty()).then_some(function) } - /// Returns the name of the most recent reducer to be run in this environment. - pub fn reducer_start(&self) -> Instant { - self.reducer_start - } - - /// Signal to this `WasmInstanceEnv` that a reducer call is over. - /// This resets all of the state associated to a single reducer call, - /// and returns instrumentation records. - pub fn finish_reducer(&mut self) -> (ExecutionTimings, Vec) { + /// Signal to this `WasmInstanceEnv` that a reducer or procedure call is over. + /// + /// Returns time measurements which can be recorded as metrics, + /// and the errors written by the WASM code to the standard error sink. + /// + /// This resets the call times and clears the arguments source and error sink. + pub fn finish_funcall(&mut self) -> (ExecutionTimings, Vec) { // For the moment, // we only explicitly clear the source/sink buffers and the "syscall" times. // TODO: should we be clearing `iters` and/or `timing_spans`? - let total_duration = self.reducer_start.elapsed(); + let total_duration = self.funcall_start.elapsed(); // Taking the call times record also resets timings to 0s for the next call.
let wasm_instance_env_call_times = self.call_times.take(); @@ -1326,6 +1335,36 @@ impl WasmInstanceEnv { Ok(()) }) } + + pub fn procedure_sleep_until<'caller>( + mut caller: Caller<'caller, Self>, + (wake_at_micros_since_unix_epoch,): (i64,), + ) -> Box + Send + 'caller> { + Box::new(async move { + use std::time::SystemTime; + let span_start = span::CallSpanStart::new(AbiCall::ProcedureSleepUntil); + + let get_current_time = || Timestamp::now().to_micros_since_unix_epoch(); + + if wake_at_micros_since_unix_epoch < 0 { + return get_current_time(); + } + + let wake_at = Timestamp::from_micros_since_unix_epoch(wake_at_micros_since_unix_epoch); + let Ok(duration) = SystemTime::from(wake_at).duration_since(SystemTime::now()) else { + return get_current_time(); + }; + + tokio::time::sleep(duration).await; + + let res = get_current_time(); + + let span = span_start.end(); + span::record_span(&mut caller.data_mut().call_times, span); + + res + }) + } } impl BacktraceProvider for wasmtime::StoreContext<'_, T> { diff --git a/crates/core/src/host/wasmtime/wasmtime_module.rs b/crates/core/src/host/wasmtime/wasmtime_module.rs index c2906f65ee3..37fad0fb4aa 100644 --- a/crates/core/src/host/wasmtime/wasmtime_module.rs +++ b/crates/core/src/host/wasmtime/wasmtime_module.rs @@ -9,6 +9,7 @@ use crate::host::wasm_common::module_host_actor::{DescribeError, InitializationE use crate::host::wasm_common::*; use crate::util::string_from_utf8_lossy_owned; use futures_util::FutureExt; +use spacetimedb_lib::{ConnectionId, Identity}; use spacetimedb_primitives::errno::HOST_CALL_FAILURE; use wasmtime::{ AsContext, AsContextMut, ExternType, Instance, InstancePre, Linker, Store, TypedFunc, WasmBacktrace, WasmParams, @@ -49,7 +50,13 @@ impl WasmtimeModule { linker$(.func_wrap($module, stringify!($func), WasmInstanceEnv::$func)?)*; } } - abi_funcs!(link_functions); + macro_rules! 
link_async_functions { + ($($module:literal :: $func:ident,)*) => { + #[allow(deprecated)] + linker$(.func_wrap_async($module, stringify!($func), WasmInstanceEnv::$func)?)*; + } + } + abi_funcs!(link_functions, link_async_functions); Ok(()) } } @@ -126,9 +133,9 @@ impl module_host_actor::WasmInstancePre for WasmtimeModule { store.epoch_deadline_callback(|store| { let env = store.data(); let database = env.instance_env().replica_ctx.database_identity; - let reducer = env.reducer_name(); - let dur = env.reducer_start().elapsed(); - tracing::warn!(reducer, ?database, "Wasm has been running for {dur:?}"); + let funcall = env.funcall_name(); + let dur = env.funcall_start().elapsed(); + tracing::warn!(funcall, ?database, "Wasm has been running for {dur:?}"); Ok(wasmtime::UpdateDeadline::Continue(EPOCH_TICKS_PER_SECOND)) }); @@ -162,22 +169,80 @@ impl module_host_actor::WasmInstancePre for WasmtimeModule { .get_typed_func(&mut store, CALL_REDUCER_DUNDER) .expect("no call_reducer"); + let call_procedure = get_call_procedure(&mut store, &instance); + Ok(WasmtimeInstance { store, instance, call_reducer, + call_procedure, }) } } -type CallReducerType = TypedFunc<(u32, u64, u64, u64, u64, u64, u64, u64, u32, u32), i32>; +/// Look up the `instance`'s export named by [`CALL_PROCEDURE_DUNDER`]. +/// +/// Return `None` if the `instance` has no such export. +/// Modules from before the introduction of procedures will not have a `__call_procedure__` export, +/// which is fine because they also won't define any procedures. +/// +/// Panics if the `instance` has an export at the expected name, +/// but it is not a function or is a function of an inappropriate type. +/// For new modules, this will be caught during publish. +/// Old modules from before the introduction of procedures might have an export at that name, +/// but it follows the double-underscore pattern of reserved names, +/// so we're fine to break those modules.
+fn get_call_procedure(store: &mut Store, instance: &Instance) -> Option { + // Wasmtime uses `anyhow` for error reporting, vexing library users the world over. + // This means we can't distinguish between the failure modes of `Instance::get_typed_func`. + // Instead, we type out the body of that method ourselves, + // but with error handling appropriate to our needs. + let export = instance.get_export(store.as_context_mut(), CALL_PROCEDURE_DUNDER)?; + + Some( + export + .into_func() + .expect(&format!("{CALL_PROCEDURE_DUNDER} export is not a function")) + .typed(store) + .expect(&format!( + "{CALL_PROCEDURE_DUNDER} export is a function with incorrect type" + )), + ) +} + +type CallReducerType = TypedFunc< + ( + // Reducer ID, + u32, + // Sender `Identity` + u64, + u64, + u64, + u64, + // Sender `ConnectionId`, or 0 for none. + u64, + u64, + // Start timestamp. + u64, + // Args byte source. + u32, + // Errors byte sink. + u32, + ), + // Errno. + i32, +>; +// `__call_procedure__` takes the same arguments as `__call_reducer__`. +type CallProcedureType = CallReducerType; pub struct WasmtimeInstance { store: Store, instance: Instance, call_reducer: CallReducerType, + call_procedure: Option, } +#[async_trait::async_trait] impl module_host_actor::WasmInstance for WasmtimeInstance { fn extract_descriptions(&mut self) -> Result, DescribeError> { let describer_func_name = DESCRIBE_MODULE_DUNDER; @@ -206,18 +271,16 @@ impl module_host_actor::WasmInstance for WasmtimeInstance { #[tracing::instrument(level = "trace", skip_all)] fn call_reducer(&mut self, op: ReducerOp<'_>, budget: ReducerBudget) -> module_host_actor::ExecuteResult { let store = &mut self.store; - // Set the fuel budget in WASM. - set_store_fuel(store, budget.into()); - store.set_epoch_deadline(EPOCH_TICKS_PER_SECOND); + let original_fuel = prepare_store_for_call_and_get_starting_fuel(store, budget); // Prepare sender identity and connection ID, as LITTLE-ENDIAN byte arrays. 
- let [sender_0, sender_1, sender_2, sender_3] = bytemuck::must_cast(op.caller_identity.to_byte_array()); - let [conn_id_0, conn_id_1] = bytemuck::must_cast(op.caller_connection_id.as_le_byte_array()); + let [sender_0, sender_1, sender_2, sender_3] = prepare_identity_for_call(*op.caller_identity); + let [conn_id_0, conn_id_1] = prepare_connection_id_for_call(*op.caller_connection_id); // Prepare arguments to the reducer + the error sink & start timings. let args_bytes = op.args.get_bsatn().clone(); - let (args_source, errors_sink) = store.data_mut().start_reducer(op.name, args_bytes, op.timestamp); + let (args_source, errors_sink) = store.data_mut().start_funcall(op.name, args_bytes, op.timestamp); let call_result = call_sync_typed_func( &self.call_reducer, @@ -239,7 +302,7 @@ impl module_host_actor::WasmInstance for WasmtimeInstance { // Signal that this reducer call is finished. This gets us the timings // associated to our reducer call, and clears all of the instance state // associated to the call. - let (timings, error) = store.data_mut().finish_reducer(); + let (timings, error) = store.data_mut().finish_funcall(); let call_result = call_result.map(|code| handle_error_sink_code(code, error)); @@ -257,6 +320,77 @@ impl module_host_actor::WasmInstance for WasmtimeInstance { } } + #[tracing::instrument(level = "trace", skip_all)] + async fn call_procedure( + &mut self, + op: module_host_actor::ProcedureOp, + budget: ReducerBudget, + ) -> module_host_actor::ProcedureExecuteResult { + let store = &mut self.store; + let original_fuel = prepare_store_for_call_and_get_starting_fuel(store, budget); + + // Prepare sender identity and connection ID, as LITTLE-ENDIAN byte arrays. + let [sender_0, sender_1, sender_2, sender_3] = prepare_identity_for_call(op.caller_identity); + let [conn_id_0, conn_id_1] = prepare_connection_id_for_call(op.caller_connection_id); + + // Prepare arguments to the procedure + the result sink & start timings.
+ let (args_source, result_sink) = store.data_mut().start_funcall(&op.name, op.arg_bytes, op.timestamp); + + let Some(call_procedure) = self.call_procedure.as_ref() else { + return module_host_actor::ProcedureExecuteResult { + energy: module_host_actor::EnergyStats::ZERO, + timings: module_host_actor::ExecutionTimings::zero(), + memory_allocation: get_memory_size(store), + call_result: Err(anyhow::anyhow!( + "Module defines procedure {} but does not export `{}`", + op.name, + CALL_PROCEDURE_DUNDER, + )), + }; + }; + let call_result = call_procedure + .call_async( + &mut *store, + ( + op.id.0, + sender_0, + sender_1, + sender_2, + sender_3, + conn_id_0, + conn_id_1, + op.timestamp.to_micros_since_unix_epoch() as u64, + args_source.0, + result_sink, + ), + ) + .await; + + // Close the timing span for this procedure and get the BSATN bytes of its result. + let (timings, result_bytes) = store.data_mut().finish_funcall(); + + let call_result = call_result.and_then(|code| { + (code == 0).then_some(result_bytes.into()).ok_or_else(|| { + anyhow::anyhow!( + "{CALL_PROCEDURE_DUNDER} returned unexpected code {code}. Procedures should return code 0 or trap." 
+ ) + }) + }); + + let remaining_fuel = get_store_fuel(store); + let remaining = ReducerBudget::from(remaining_fuel); + + let energy = module_host_actor::EnergyStats { budget, remaining }; + let memory_allocation = get_memory_size(store); + + module_host_actor::ProcedureExecuteResult { + energy, + timings, + memory_allocation, + call_result, + } + } + fn log_traceback(func_type: &str, func: &str, trap: &anyhow::Error) { log_traceback(func_type, func, trap) } @@ -270,6 +404,55 @@ fn get_store_fuel(store: &impl AsContext) -> WasmtimeFuel { WasmtimeFuel(store.as_context().get_fuel().unwrap()) } +fn prepare_store_for_call_and_get_starting_fuel( + store: &mut Store, + budget: ReducerBudget, +) -> WasmtimeFuel { + // note that ReducerBudget being a u64 is load-bearing here - although we convert budget right back into + // EnergyQuanta at the end of this function, from_energy_quanta clamps it to a u64 range. + // otherwise, we'd return something like `used: i128::MAX - u64::MAX`, which is inaccurate. + set_store_fuel(store, budget.into()); + let original_fuel = get_store_fuel(store); + + // This seems odd, as we don't use epoch interruption, at least as far as I (pgoldman 2025-08-22) know. + // But this call was here prior to my last edit. + // The previous line git-blames to https://github.com/clockworklabs/spacetimeDB/pull/2738 . + store.set_epoch_deadline(EPOCH_TICKS_PER_SECOND); + + original_fuel +} + +/// Convert `caller_identity` to the format used by `__call_reducer__` and `__call_procedure__`, +/// i.e. an array of 4 `u64`s. 
+/// +/// Callers should destructure this like: +/// ```rust +/// # let identity = Identity::ZERO; +/// let [sender_0, sender_1, sender_2, sender_3] = prepare_identity_for_call(identity); +/// ``` +fn prepare_identity_for_call(caller_identity: Identity) -> [u64; 4] { + // Encode this as a LITTLE-ENDIAN byte array + bytemuck::must_cast(caller_identity.to_byte_array()) +} + +/// Convert `caller_connection_id` to the format used by `__call_reducer__` and `__call_procedure__`, +/// i.e. an array of 2 `u64`s. +/// +/// Callers should destructure this like: +/// ```rust +/// # let connection_id = ConnectionId::ZERO; +/// let [conn_id_0, conn_id_1] = prepare_connection_id_for_call(connection_id); +/// ``` +/// +fn prepare_connection_id_for_call(caller_connection_id: ConnectionId) -> [u64; 2] { + // Encode this as a LITTLE-ENDIAN byte array + bytemuck::must_cast(caller_connection_id.as_le_byte_array()) +} + +fn get_memory_size(store: &Store) -> usize { + store.data().get_mem().memory.data_size(store) +} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 41eb62e60ba..edcda7d16a6 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -7,8 +7,8 @@ use super::query::compile_query_with_hashes; use super::tx::DeltaTx; use super::{collect_table_update, TableUpdateType}; use crate::client::messages::{ - SerializableMessage, SubscriptionData, SubscriptionError, SubscriptionMessage, SubscriptionResult, - SubscriptionRows, SubscriptionUpdateMessage, TransactionUpdateMessage, + ProcedureResultMessage, SerializableMessage, SubscriptionData, SubscriptionError, SubscriptionMessage, + SubscriptionResult, SubscriptionRows, SubscriptionUpdateMessage, TransactionUpdateMessage, }; use crate::client::{ClientActorId, ClientConnectionSender, Protocol}; use crate::db::relational_db::{MutTx, 
RelationalDB, Tx}; @@ -634,6 +634,22 @@ impl ModuleSubscriptions { Ok((plans, auth, scopeguard::ScopeGuard::into_inner(tx), compile_timer)) } + /// Like [`Self::send_client_message`], + /// but doesn't require a `TxId` because procedures don't hold a transaction open. + pub fn send_procedure_message( + &self, + recipient: Arc, + message: ProcedureResultMessage, + ) -> Result<(), BroadcastError> { + self.broadcast_queue.send_client_message( + recipient, + // TODO(procedure-tx): We'll need some mechanism for procedures to report their last-referenced TxOffset, + // and to pass it here. + // This is currently moot, as procedures have no way to open a transaction yet. + None, message, + ) + } + /// Send a message to a client connection. /// This will eventually be sent by the send-worker. /// This takes a `TxId`, because this should be called while still holding a lock on the database. diff --git a/crates/core/src/util/jobs.rs b/crates/core/src/util/jobs.rs index 6634d1e57d2..64f187b0408 100644 --- a/crates/core/src/util/jobs.rs +++ b/crates/core/src/util/jobs.rs @@ -71,6 +71,7 @@ impl CoreInfo { // However, `max_blocking_threads` will panic if passed 0, so we set a limit of 1 // and use `on_thread_start` to log an error when spawning a blocking task. 
.max_blocking_threads(1) + .enable_time() .on_thread_start({ use std::sync::atomic::{AtomicBool, Ordering}; let already_spawned_worker = AtomicBool::new(false); diff --git a/crates/datastore/src/execution_context.rs b/crates/datastore/src/execution_context.rs index f4fbcaf84ae..ef72cbbdf25 100644 --- a/crates/datastore/src/execution_context.rs +++ b/crates/datastore/src/execution_context.rs @@ -131,6 +131,7 @@ pub enum WorkloadType { Unsubscribe, Update, Internal, + Procedure, } impl Default for WorkloadType { diff --git a/crates/datastore/src/locking_tx_datastore/datastore.rs b/crates/datastore/src/locking_tx_datastore/datastore.rs index 6ebaa2fc275..99de3312237 100644 --- a/crates/datastore/src/locking_tx_datastore/datastore.rs +++ b/crates/datastore/src/locking_tx_datastore/datastore.rs @@ -3313,7 +3313,7 @@ mod tests { table_id: TableId::SENTINEL, schedule_id: ScheduleId::SENTINEL, schedule_name: "schedule".into(), - reducer_name: "reducer".into(), + function_name: "reducer".into(), at_column: 1.into(), }; let sum_ty = AlgebraicType::sum([("foo", AlgebraicType::Bool), ("bar", AlgebraicType::U16)]); diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index 06889c0ed5c..aee33e34d41 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -244,7 +244,7 @@ impl MutTxId { table_id: schedule.table_id, schedule_id: schedule.schedule_id, schedule_name: schedule.schedule_name, - reducer_name: schedule.reducer_name, + reducer_name: schedule.function_name, at_column: schedule.at_column, }; let id = self diff --git a/crates/datastore/src/system_tables.rs b/crates/datastore/src/system_tables.rs index 1ecd69df51f..f5515e379b8 100644 --- a/crates/datastore/src/system_tables.rs +++ b/crates/datastore/src/system_tables.rs @@ -1232,6 +1232,11 @@ impl TryFrom> for StVarRow { pub struct StScheduledRow { pub(crate) schedule_id: ScheduleId, 
pub(crate) table_id: TableId, + /// The name of the reducer or procedure which will run when this table's rows reach their execution time. + /// + /// Note that, despite the column name, this may refer to either a reducer or a procedure. + /// We cannot change the schema of existing system tables, + /// so we are unable to rename this column. pub(crate) reducer_name: Box, pub(crate) schedule_name: Box, pub(crate) at_column: ColId, @@ -1254,7 +1259,7 @@ impl From for ScheduleSchema { fn from(row: StScheduledRow) -> Self { Self { table_id: row.table_id, - reducer_name: row.reducer_name, + function_name: row.reducer_name, schedule_id: row.schedule_id, schedule_name: row.schedule_name, at_column: row.at_column, diff --git a/crates/lib/src/db/raw_def/v9.rs b/crates/lib/src/db/raw_def/v9.rs index 56518336e5d..ced99c8b73c 100644 --- a/crates/lib/src/db/raw_def/v9.rs +++ b/crates/lib/src/db/raw_def/v9.rs @@ -82,9 +82,17 @@ pub struct RawModuleDefV9 { pub types: Vec, /// Miscellaneous additional module exports. + /// + /// The enum `RawMiscModuleExportV9` can have new variants added + /// without breaking existing compiled modules. + /// As such, this acts as a sort of dumping ground for any exports added after we defined `RawModuleDefV9`. + /// Currently, this contains only procedure definitions. + /// + /// If/when we define `RawModuleDefV10`, these should be moved out of `misc_exports` and into their own fields, + /// and the new `misc_exports` should once again be initially empty. pub misc_exports: Vec, - /// Low level security definitions. + /// Row level security definitions. /// /// Each definition must have a unique name. pub row_level_security: Vec, @@ -294,7 +302,7 @@ pub fn direct(col: impl Into) -> RawIndexAlgorithm { RawIndexAlgorithm::Direct { column: col.into() } } -/// Marks a table as a timer table for a scheduled reducer. +/// Marks a table as a timer table for a scheduled reducer or procedure. 
/// /// The table must have columns: /// - `scheduled_id` of type `u64`. @@ -307,7 +315,9 @@ pub struct RawScheduleDefV9 { /// Even though there is ABSOLUTELY NO REASON TO. pub name: Option>, - /// The name of the reducer to call. + /// The name of the reducer or procedure to call. + /// + /// Despite the field name here, this may be either a reducer or a procedure. pub reducer_name: RawIdentifier, /// The column of the `scheduled_at` field of this scheduled table. @@ -364,6 +374,10 @@ pub struct RawRowLevelSecurityDefV9 { #[non_exhaustive] pub enum RawMiscModuleExportV9 { ColumnDefaultValue(RawColumnDefaultValueV9), + /// A procedure definition. + // Included here because procedures were added after the format of [`RawModuleDefV9`] was already stabilized. + // If/when we define `RawModuleDefV10`, this should be moved out of `misc_exports` and into its own field. + Procedure(RawProcedureDefV9), } /// Marks a particular table's column as having a particular default. @@ -459,6 +473,27 @@ pub enum Lifecycle { OnDisconnect, } +/// A procedure definition. +/// +/// Will be wrapped in [`RawMiscModuleExportV9`] and included in the [`RawModuleDefV9`]'s `misc_exports` vec. +#[derive(Debug, Clone, SpacetimeType)] +#[sats(crate = crate)] +#[cfg_attr(feature = "test", derive(PartialEq, Eq, PartialOrd, Ord))] +pub struct RawProcedureDefV9 { + /// The name of the procedure. + pub name: RawIdentifier, + + /// The types and optional names of the parameters, in order. + /// This `ProductType` need not be registered in the typespace. + pub params: ProductType, + + /// The type of the return value. + /// + /// If this is a user-defined product or sum type, + /// it should be registered in the typespace and indirected through an [`AlgebraicType::Ref`]. + pub return_type: AlgebraicType, +} + /// A builder for a [`RawModuleDefV9`]. 
#[derive(Default)] pub struct RawModuleDefV9Builder { @@ -631,6 +666,31 @@ impl RawModuleDefV9Builder { }); } + /// Add a procedure to the in-progress module. + /// + /// Accepts a `ProductType` of arguments. + /// The arguments `ProductType` need not be registered in the typespace. + /// + /// Also accepts an `AlgebraicType` return type. + /// If this is a user-defined product or sum type, + /// it should be registered in the typespace and indirected through an `AlgebraicType::Ref`. + /// + /// The `&mut ProcedureContext` first argument to the procedure should not be included in the `params`. + pub fn add_procedure( + &mut self, + name: impl Into, + params: spacetimedb_sats::ProductType, + return_type: spacetimedb_sats::AlgebraicType, + ) { + self.module + .misc_exports + .push(RawMiscModuleExportV9::Procedure(RawProcedureDefV9 { + name: name.into(), + params, + return_type, + })) + } + /// Add a row-level security policy to the module. /// /// The `sql` expression should be a valid SQL expression that will be used to filter rows. @@ -799,10 +859,10 @@ impl RawTableDefBuilder<'_> { /// The table must have the appropriate columns for a scheduled table. pub fn with_schedule( mut self, - reducer_name: impl Into, + function_name: impl Into, scheduled_at_column: impl Into, ) -> Self { - let reducer_name = reducer_name.into(); + let reducer_name = function_name.into(); let scheduled_at_column = scheduled_at_column.into(); self.table.schedule = Some(RawScheduleDefV9 { name: None, diff --git a/crates/primitives/src/ids.rs b/crates/primitives/src/ids.rs index fc9a9b69e61..afe5325aba1 100644 --- a/crates/primitives/src/ids.rs +++ b/crates/primitives/src/ids.rs @@ -116,3 +116,35 @@ system_id! { // This is never stored in a system table, but is useful to have defined here. pub struct ReducerId(pub u32); } + +system_id! { + /// The index of a procedure as defined in a module's procedure list. + // This is never stored in a system table, but is useful to have defined here. 
+ pub struct ProcedureId(pub u32); +} + +/// An id for a function exported from a module, which may be a reducer or a procedure. +// This is never stored in a system table, +// but is useful to have defined here to provide a shared language for downstream crates. +#[derive(Clone, Copy, Debug)] +pub enum FunctionId { + Reducer(ReducerId), + Procedure(ProcedureId), +} + +impl FunctionId { + pub fn as_reducer(self) -> Option { + if let Self::Reducer(id) = self { + Some(id) + } else { + None + } + } + pub fn as_procedure(self) -> Option { + if let Self::Procedure(id) = self { + Some(id) + } else { + None + } + } +} diff --git a/crates/primitives/src/lib.rs b/crates/primitives/src/lib.rs index 7ae37765514..d88e541f195 100644 --- a/crates/primitives/src/lib.rs +++ b/crates/primitives/src/lib.rs @@ -7,7 +7,7 @@ mod ids; pub use attr::{AttributeKind, ColumnAttribute, ConstraintKind, Constraints}; pub use col_list::{ColList, ColOrCols, ColSet}; -pub use ids::{ColId, ConstraintId, IndexId, ReducerId, ScheduleId, SequenceId, TableId}; +pub use ids::{ColId, ConstraintId, FunctionId, IndexId, ProcedureId, ReducerId, ScheduleId, SequenceId, TableId}; /// The minimum size of a chunk yielded by a wasm abi RowIter. 
pub const ROW_ITER_CHUNK_SIZE: usize = 32 * 1024; diff --git a/crates/schema/src/auto_migrate/formatter.rs b/crates/schema/src/auto_migrate/formatter.rs index 54b4ed87ce3..abdf443a9df 100644 --- a/crates/schema/src/auto_migrate/formatter.rs +++ b/crates/schema/src/auto_migrate/formatter.rs @@ -5,7 +5,7 @@ use std::io; use super::{AutoMigratePlan, IndexAlgorithm, ModuleDefLookup, TableDef}; use crate::{ auto_migrate::AutoMigrateStep, - def::{ConstraintData, ModuleDef, ScheduleDef}, + def::{ConstraintData, FunctionKind, ModuleDef, ScheduleDef}, identifier::Identifier, }; use itertools::Itertools; @@ -188,7 +188,8 @@ pub struct AccessChangeInfo { #[derive(Debug, Clone, PartialEq)] pub struct ScheduleInfo { pub table_name: String, - pub reducer_name: Identifier, + pub function_name: Identifier, + pub function_kind: FunctionKind, } #[derive(Debug, Clone, PartialEq)] @@ -314,7 +315,8 @@ fn extract_table_info( let schedule = table_def.schedule.as_ref().map(|schedule| ScheduleInfo { table_name: table_def.name.to_string().clone(), - reducer_name: schedule.reducer_name.clone(), + function_name: schedule.function_name.clone(), + function_kind: schedule.function_kind, }); Ok(TableInfo { @@ -438,7 +440,8 @@ fn extract_schedule_info( Ok(ScheduleInfo { table_name: schedule_def.name.to_string().clone(), - reducer_name: schedule_def.reducer_name.clone(), + function_name: schedule_def.function_name.clone(), + function_kind: schedule_def.function_kind, }) } diff --git a/crates/schema/src/auto_migrate/termcolor_formatter.rs b/crates/schema/src/auto_migrate/termcolor_formatter.rs index 85648e3c244..c5c05f7993a 100644 --- a/crates/schema/src/auto_migrate/termcolor_formatter.rs +++ b/crates/schema/src/auto_migrate/termcolor_formatter.rs @@ -219,7 +219,7 @@ impl MigrationFormatter for TermColorFormatter { if let Some(s) = &table.schedule { self.write_colored_line("Schedule:", Some(self.colors.section_header), true)?; self.indent(); - self.write_bullet(&format!("Calls reducer: {}", 
s.reducer_name))?; + self.write_bullet(&format!("Calls {}: {}", s.function_kind, s.function_name))?; self.dedent(); } @@ -276,7 +276,7 @@ impl MigrationFormatter for TermColorFormatter { self.buffer.write_all(b" schedule for table ")?; self.write_colored(&s.table_name, Some(self.colors.table_name), true)?; self.buffer - .write_all(format!(" calling reducer {}\n", s.reducer_name).as_bytes()) + .write_all(format!(" calling {} {}\n", s.function_kind, s.function_name).as_bytes()) } fn format_rls(&mut self, r: &RlsInfo, action: Action) -> io::Result<()> { diff --git a/crates/schema/src/def.rs b/crates/schema/src/def.rs index 1461dcca076..72b5fb9ba4f 100644 --- a/crates/schema/src/def.rs +++ b/crates/schema/src/def.rs @@ -33,12 +33,12 @@ use spacetimedb_data_structures::map::HashMap; use spacetimedb_lib::db::raw_def; use spacetimedb_lib::db::raw_def::v9::{ Lifecycle, RawColumnDefaultValueV9, RawConstraintDataV9, RawConstraintDefV9, RawIdentifier, RawIndexAlgorithm, - RawIndexDefV9, RawMiscModuleExportV9, RawModuleDefV9, RawReducerDefV9, RawRowLevelSecurityDefV9, RawScheduleDefV9, - RawScopedTypeNameV9, RawSequenceDefV9, RawSql, RawTableDefV9, RawTypeDefV9, RawUniqueConstraintDataV9, TableAccess, - TableType, + RawIndexDefV9, RawMiscModuleExportV9, RawModuleDefV9, RawProcedureDefV9, RawReducerDefV9, RawRowLevelSecurityDefV9, + RawScheduleDefV9, RawScopedTypeNameV9, RawSequenceDefV9, RawSql, RawTableDefV9, RawTypeDefV9, + RawUniqueConstraintDataV9, TableAccess, TableType, }; use spacetimedb_lib::{ProductType, RawModuleDef}; -use spacetimedb_primitives::{ColId, ColList, ColOrCols, ColSet, ReducerId, TableId}; +use spacetimedb_primitives::{ColId, ColList, ColOrCols, ColSet, ProcedureId, ReducerId, TableId}; use spacetimedb_sats::{AlgebraicType, AlgebraicValue}; use spacetimedb_sats::{AlgebraicTypeRef, Typespace}; @@ -103,6 +103,12 @@ pub struct ModuleDef { /// and must be preserved for future calls to `__call_reducer__`. 
reducers: IndexMap, + /// The procedures of the module definition. + /// + /// Like `reducers`, this uses [`IndexMap`] to preserve order + /// so that `__call_procedure__` receives stable integer IDs. + procedures: IndexMap, + /// A map from lifecycle reducer kind to reducer id. lifecycle_reducers: EnumMap>, @@ -161,6 +167,11 @@ impl ModuleDef { self.reducers.values() } + /// The procedures of the module definition. + pub fn procedures(&self) -> impl Iterator { + self.procedures.values() + } + /// The type definitions of the module definition. pub fn types(&self) -> impl Iterator { self.types.values() @@ -243,6 +254,23 @@ impl ModuleDef { self.reducers.get_index(id.idx()).map(|(_, def)| def) } + /// Convenience method to look up a procedure, possibly by a string, returning its id as well. + pub fn procedure_full>( + &self, + name: &K, + ) -> Option<(ProcedureId, &ProcedureDef)> { + // If the string IS a valid identifier, we can just look it up. + self.procedures.get_full(name).map(|(idx, _, def)| (idx.into(), def)) + } + + pub fn procedure_by_id(&self, id: ProcedureId) -> &ProcedureDef { + &self.procedures[id.idx()] + } + + pub fn get_procedure_by_id(&self, id: ProcedureId) -> Option<&ProcedureDef> { + self.procedures.get_index(id.idx()).map(|(_, def)| def) + } + /// Looks up a lifecycle reducer defined in the module. 
pub fn lifecycle_reducer(&self, lifecycle: Lifecycle) -> Option<(ReducerId, &ReducerDef)> { self.lifecycle_reducers[lifecycle].map(|i| (i, &self.reducers[i.idx()])) @@ -342,13 +370,17 @@ impl From for RawModuleDefV9 { typespace_for_generate: _, refmap: _, row_level_security_raw, + procedures, } = val; RawModuleDefV9 { tables: to_raw(tables), reducers: reducers.into_iter().map(|(_, def)| def.into()).collect(), types: to_raw(types), - misc_exports: vec![], + misc_exports: procedures + .into_iter() + .map(|(_, def)| RawMiscModuleExportV9::Procedure(def.into())) + .collect(), typespace, row_level_security: row_level_security_raw.into_iter().map(|(_, def)| def).collect(), } @@ -745,7 +777,30 @@ impl From for RawRowLevelSecurityDefV9 { } } -/// Marks a table as a timer table for a scheduled reducer. +#[derive(Copy, Clone, Eq, PartialEq, Debug, Ord, PartialOrd)] +pub enum FunctionKind { + /// Functions which have not yet been determined to be reducers or procedures. + /// + /// Used as a placeholder during module validation, + /// when pre-processing [`ScheduleDef`]s prior to validating their scheduled functions. + /// Will never appear in a fully-validated [`ModuleDef`], + /// and should not be placed in errors either. + Unknown, + Reducer, + Procedure, +} + +impl fmt::Display for FunctionKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match self { + FunctionKind::Unknown => "exported function", + FunctionKind::Reducer => "reducer", + FunctionKind::Procedure => "procedure", + }) + } +} + +/// Marks a table as a timer table for a scheduled reducer or procedure. #[derive(Debug, Clone, Eq, PartialEq)] #[non_exhaustive] pub struct ScheduleDef { @@ -762,16 +817,19 @@ pub struct ScheduleDef { /// Must be named `scheduled_id` and be of type `u64`. pub id_column: ColId, - /// The name of the reducer to call. Not yet an `Identifier` because + /// The name of the reducer or procedure to call. 
Not yet an `Identifier` because /// reducer names are not currently validated. - pub reducer_name: Identifier, + pub function_name: Identifier, + + /// Whether the `function_name` refers to a reducer or a procedure. + pub function_kind: FunctionKind, } impl From for RawScheduleDefV9 { fn from(val: ScheduleDef) -> Self { RawScheduleDefV9 { name: Some(val.name), - reducer_name: val.reducer_name.into(), + reducer_name: val.function_name.into(), scheduled_at_column: val.at_column, } } @@ -902,7 +960,9 @@ impl From for RawScopedTypeNameV9 { #[derive(Debug, Clone, Eq, PartialEq)] #[non_exhaustive] pub struct ReducerDef { - /// The name of the reducer. This must be unique within the module. + /// The name of the reducer. + /// + /// This must be unique within the module's set of reducers and procedures. pub name: Identifier, /// The parameters of the reducer. @@ -929,6 +989,47 @@ impl From for RawReducerDefV9 { } } +#[derive(Debug, Clone, Eq, PartialEq)] +#[non_exhaustive] +pub struct ProcedureDef { + /// The name of the procedure. + /// + /// This must be unique within the module's set of reducers and procedures. + pub name: Identifier, + + /// The parameters of the procedure. + /// + /// This `ProductType` need not be registered in the module's `Typespace`. + pub params: ProductType, + + /// The parameters of the procedure, formatted for client codegen. + /// + /// This `ProductType` need not be registered in the module's `TypespaceForGenerate`. + pub params_for_generate: ProductTypeDef, + + /// The return type of the procedure. + /// + /// If this is a non-special compound type, it should be registered in the module's `Typespace` + /// and indirected through an [`AlgebraicType::Ref`]. + pub return_type: AlgebraicType, + + /// The return type of the procedure. + /// + /// If this is a non-special compound type, it should be registered in the module's `TypespaceForGenerate` + /// and indirected through an [`AlgebraicTypeUse::Ref`]. 
+ pub return_type_for_generate: AlgebraicTypeUse, +} + +impl From for RawProcedureDefV9 { + fn from(val: ProcedureDef) -> Self { + RawProcedureDefV9 { + name: val.name.into(), + params: val.params, + return_type: val.return_type, + } + } +} + impl ModuleDefLookup for TableDef { type Key<'a> = &'a Identifier; diff --git a/crates/schema/src/def/deserialize.rs b/crates/schema/src/def/deserialize.rs index a5a14240584..02d3768ab38 100644 --- a/crates/schema/src/def/deserialize.rs +++ b/crates/schema/src/def/deserialize.rs @@ -1,46 +1,93 @@ //! Helpers to allow deserializing data using a ReducerDef. -use crate::def::ReducerDef; -use spacetimedb_lib::sats::{self, de, ProductValue}; - -/// Wrapper around a `ReducerDef` that allows deserializing to a `ProductValue` at the type -/// of the reducer's parameter `ProductType`. -#[derive(Clone, Copy)] -pub struct ReducerArgsDeserializeSeed<'a>(pub sats::WithTypespace<'a, ReducerDef>); - -impl<'a> ReducerArgsDeserializeSeed<'a> { - /// Get the reducer def of this seed. - pub fn reducer_def(&self) -> &'a ReducerDef { - self.0.ty() - } +use crate::def::{ProcedureDef, ReducerDef}; +use spacetimedb_lib::{ + sats::{self, de, impl_serialize, ser, ProductValue}, + ProductType, +}; + +pub trait ArgsSeed: for<'de> de::DeserializeSeed<'de, Output = ProductValue> { + fn params(&self) -> &ProductType; } -impl<'de> de::DeserializeSeed<'de> for ReducerArgsDeserializeSeed<'_> { - type Output = ProductValue; +/// Define `struct_name` as a newtype wrapper around [`WithTypespace`] of `inner_ty`, +/// and implement [`de::DeserializeSeed`] and [`de::ProductVisitor`] for that newtype. +/// +/// `ReducerArgs` (defined in the spacetimedb_core crate) will use this type +/// to deserialize the arguments to a reducer or procedure +/// at the appropriate type for that specific function, which is known only at runtime. +macro_rules! 
define_args_deserialize_seed { + ($struct_vis:vis struct $struct_name:ident($field_vis:vis $inner_ty:ty)) => { + #[doc = concat!( + "Wrapper around a [`", + stringify!($inner_ty), + "`] that allows deserializing to a [`ProductValue`] at the type of the def's parameter `ProductType`." + )] + #[derive(Clone, Copy)] + $struct_vis struct $struct_name<'a>($field_vis sats::WithTypespace<'a, $inner_ty> ); - fn deserialize>(self, deserializer: D) -> Result { - deserializer.deserialize_product(self) - } -} + impl<'a> $struct_name<'a> { + #[doc = concat!( + "Get the inner [`", + stringify!($inner_ty), + "`] of this seed." + )] + $struct_vis fn inner_def(&self) -> &'a $inner_ty { + self.0.ty() + } + } -impl<'de> de::ProductVisitor<'de> for ReducerArgsDeserializeSeed<'_> { - type Output = ProductValue; + impl<'de> de::DeserializeSeed<'de> for $struct_name<'_> { + type Output = ProductValue; - fn product_name(&self) -> Option<&str> { - Some(&self.0.ty().name) - } - fn product_len(&self) -> usize { - self.0.ty().params.elements.len() - } - fn product_kind(&self) -> de::ProductKind { - de::ProductKind::ReducerArgs - } + fn deserialize>(self, deserializer: D) -> Result { + deserializer.deserialize_product(self) + } + } - fn visit_seq_product>(self, tup: A) -> Result { - de::visit_seq_product(self.0.map(|r| &*r.params.elements), &self, tup) - } + impl<'de> de::ProductVisitor<'de> for $struct_name<'_> { + type Output = ProductValue; - fn visit_named_product>(self, tup: A) -> Result { - de::visit_named_product(self.0.map(|r| &*r.params.elements), &self, tup) + fn product_name(&self) -> Option<&str> { + Some(&self.0.ty().name) + } + fn product_len(&self) -> usize { + self.0.ty().params.elements.len() + } + fn product_kind(&self) -> de::ProductKind { + de::ProductKind::ReducerArgs + } + + fn visit_seq_product>(self, tup: A) -> Result { + de::visit_seq_product(self.0.map(|r| &*r.params.elements), &self, tup) + } + + fn visit_named_product>(self, tup: A) -> Result { + 
de::visit_named_product(self.0.map(|r| &*r.params.elements), &self, tup) + } + } + + impl<'a> ArgsSeed for $struct_name<'a> { + fn params(&self) -> &ProductType { + &self.0.ty().params + } + } } } + +define_args_deserialize_seed!(pub struct ReducerArgsDeserializeSeed(pub ReducerDef)); +define_args_deserialize_seed!(pub struct ProcedureArgsDeserializeSeed(pub ProcedureDef)); + +pub struct ReducerArgsWithSchema<'a> { + value: &'a ProductValue, + ty: sats::WithTypespace<'a, ReducerDef>, +} +impl_serialize!([] ReducerArgsWithSchema<'_>, (self, ser) => { + use itertools::Itertools; + use ser::SerializeSeqProduct; + let mut seq = ser.serialize_seq_product(self.value.elements.len())?; + for (value, elem) in self.value.elements.iter().zip_eq(&*self.ty.ty().params.elements) { + seq.serialize_element(&self.ty.with(&elem.algebraic_type).with_value(value))?; + } + seq.end() +}); diff --git a/crates/schema/src/def/validate/v8.rs b/crates/schema/src/def/validate/v8.rs index 89a853a995e..3a24597fe50 100644 --- a/crates/schema/src/def/validate/v8.rs +++ b/crates/schema/src/def/validate/v8.rs @@ -57,7 +57,9 @@ fn upgrade_module(def: RawModuleDefV8, extra_errors: &mut Vec) tables, reducers, types, + // v8 doesn't have procedures, which are all we use the `misc_exports` for at this time (pgoldman 2025-07-23). 
misc_exports: Default::default(), + row_level_security: vec![], // v8 doesn't have row-level security } } @@ -526,7 +528,7 @@ mod tests { assert_eq!(delivery_def.columns[2].ty, AlgebraicType::U64); assert_eq!(delivery_def.schedule.as_ref().unwrap().at_column, 1.into()); assert_eq!( - &delivery_def.schedule.as_ref().unwrap().reducer_name[..], + &delivery_def.schedule.as_ref().unwrap().function_name[..], "check_deliveries" ); assert_eq!(delivery_def.primary_key, Some(ColId(2))); diff --git a/crates/schema/src/def/validate/v9.rs b/crates/schema/src/def/validate/v9.rs index 6adbbf01b9a..a8ae7447fb0 100644 --- a/crates/schema/src/def/validate/v9.rs +++ b/crates/schema/src/def/validate/v9.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::def::*; use crate::error::{RawColumnName, ValidationError}; use crate::type_for_generate::{ClientCodegenError, ProductTypeDef, TypespaceForGenerateBuilder}; @@ -5,6 +7,7 @@ use crate::{def::validate::Result, error::TypeLocation}; use spacetimedb_data_structures::error_stream::{CollectAllErrors, CombineErrors}; use spacetimedb_data_structures::map::HashSet; use spacetimedb_lib::db::default_element_ordering::{product_type_has_default_ordering, sum_type_has_default_ordering}; +use spacetimedb_lib::db::raw_def::v9::RawProcedureDefV9; use spacetimedb_lib::ProductType; use spacetimedb_primitives::col_list; use spacetimedb_sats::{bsatn::de::Deserializer, de::DeserializeSeed, WithTypespace}; @@ -51,6 +54,28 @@ pub fn validate(def: RawModuleDefV9) -> Result { }) .collect_all_errors(); + let (procedures, non_procedure_misc_exports) = + misc_exports + .into_iter() + .partition::, _>(|misc_export| { + matches!(misc_export, RawMiscModuleExportV9::Procedure(_)) + }); + + let procedures = procedures + .into_iter() + .map(|procedure| { + let RawMiscModuleExportV9::Procedure(procedure) = procedure else { + unreachable!("Already partitioned procedures separate from other `RawMiscModuleExportV9` variants"); + }; + procedure + }) + .map(|procedure| { + 
validator + .validate_procedure_def(procedure) + .map(|procedure_def| (procedure_def.name.clone(), procedure_def)) + }) + .collect_all_errors(); + let tables = tables .into_iter() .map(|table| { @@ -76,15 +101,18 @@ pub fn validate(def: RawModuleDefV9) -> Result { }) .collect_all_errors::>(); - let tables_types_reducers = (tables, types, reducers) - .combine_errors() - .and_then(|(mut tables, types, reducers)| { - let sched_exists = check_scheduled_reducers_exist(&tables, &reducers); - let default_values_work = proccess_misc_exports(misc_exports, &validator, &mut tables); - (sched_exists, default_values_work).combine_errors()?; - - Ok((tables, types, reducers)) - }); + let tables_types_reducers_procedures = + (tables, types, reducers, procedures) + .combine_errors() + .and_then(|(mut tables, types, reducers, procedures)| { + ( + check_scheduled_functions_exist(&mut tables, &reducers, &procedures), + check_non_procedure_misc_exports(non_procedure_misc_exports, &validator, &mut tables), + check_function_names_are_unique(&reducers, &procedures), + ) + .combine_errors()?; + Ok((tables, types, reducers, procedures)) + }); let ModuleValidator { stored_in_table_def, @@ -93,7 +121,8 @@ pub fn validate(def: RawModuleDefV9) -> Result { .. } = validator; - let (tables, types, reducers) = (tables_types_reducers).map_err(|errors| errors.sort_deduplicate())?; + let (tables, types, reducers, procedures) = + (tables_types_reducers_procedures).map_err(|errors| errors.sort_deduplicate())?; let typespace_for_generate = typespace_for_generate.finish(); @@ -107,6 +136,7 @@ pub fn validate(def: RawModuleDefV9) -> Result { refmap, row_level_security_raw, lifecycle_reducers, + procedures, }) } @@ -286,26 +316,19 @@ impl ModuleValidator<'_> { }) } - /// Validate a reducer definition. 
- fn validate_reducer_def(&mut self, reducer_def: RawReducerDefV9, reducer_id: ReducerId) -> Result { - let RawReducerDefV9 { - name, - params, - lifecycle, - } = reducer_def; - - let params_for_generate: Result<_> = params + fn params_for_generate<'a>( + &mut self, + params: &'a ProductType, + make_type_location: impl Fn(usize, Option>) -> TypeLocation<'a>, + ) -> Result> { + params .elements .iter() .enumerate() .map(|(position, param)| { // Note: this does not allocate, since `TypeLocation` is defined using `Cow`. // We only allocate if an error is returned. - let location = TypeLocation::ReducerArg { - reducer_name: (&*name).into(), - position, - arg_name: param.name().map(Into::into), - }; + let location = make_type_location(position, param.name().map(Into::into)); let param_name = param .name() .ok_or_else(|| { @@ -319,10 +342,27 @@ impl ModuleValidator<'_> { let ty_use = self.validate_for_type_use(&location, &param.algebraic_type); (param_name, ty_use).combine_errors() }) - .collect_all_errors(); + .collect_all_errors() + } - // reducers don't live in the global namespace. - let name = identifier(name); + /// Validate a reducer definition. + fn validate_reducer_def(&mut self, reducer_def: RawReducerDefV9, reducer_id: ReducerId) -> Result { + let RawReducerDefV9 { + name, + params, + lifecycle, + } = reducer_def; + + let params_for_generate: Result<_> = + self.params_for_generate(&params, |position, arg_name| TypeLocation::ReducerArg { + reducer_name: (&*name).into(), + position, + arg_name, + }); + + // Reducers share the "function namespace" with procedures. + // Uniqueness is validated in a later pass, in `check_function_names_are_unique`. 
+ let name = identifier(name.clone()); let lifecycle = lifecycle .map(|lifecycle| match &mut self.lifecycle_reducers[lifecycle] { @@ -347,6 +387,45 @@ impl ModuleValidator<'_> { }) } + fn validate_procedure_def(&mut self, procedure_def: RawProcedureDefV9) -> Result { + let RawProcedureDefV9 { + name, + params, + return_type, + } = procedure_def; + + let params_for_generate = self.params_for_generate(&params, |position, arg_name| TypeLocation::ProcedureArg { + procedure_name: Cow::Borrowed(&name), + position, + arg_name, + }); + + let return_type_for_generate = self.validate_for_type_use( + &TypeLocation::ProcedureReturn { + procedure_name: Cow::Borrowed(&name), + }, + &return_type, + ); + + // Procedures share the "function namespace" with reducers. + // Uniqueness is validated in a later pass, in `check_function_names_are_unique`. + let name = identifier(name); + + let (name, params_for_generate, return_type_for_generate) = + (name, params_for_generate, return_type_for_generate).combine_errors()?; + + Ok(ProcedureDef { + name, + params, + params_for_generate: ProductTypeDef { + elements: params_for_generate, + recursive: false, // A ProductTypeDef not stored in a Typespace cannot be recursive. + }, + return_type, + return_type_for_generate, + }) + } + fn validate_column_default_value( &self, tables: &HashMap, @@ -717,7 +796,8 @@ impl TableValidator<'_, '_> { /// Validate a schedule definition. fn validate_schedule_def(&mut self, schedule: RawScheduleDefV9, primary_key: Option) -> Result { let RawScheduleDefV9 { - reducer_name, + // Despite the field name, a `RawScheduleDefV9` may refer to either a reducer or a procedure. 
+ reducer_name: function_name, scheduled_at_column, name, } = schedule; @@ -749,15 +829,20 @@ impl TableValidator<'_, '_> { }); let name = self.add_to_global_namespace(name); - let reducer_name = identifier(reducer_name); + let function_name = identifier(function_name); - let (name, (at_column, id_column), reducer_name) = (name, at_id, reducer_name).combine_errors()?; + let (name, (at_column, id_column), function_name) = (name, at_id, function_name).combine_errors()?; Ok(ScheduleDef { name, at_column, id_column, - reducer_name, + function_name, + + // Fill this in as a placeholder now. + // It will be populated with the correct `FunctionKind` later, + // in `check_scheduled_functions_exist`. + function_kind: FunctionKind::Unknown, }) } @@ -897,32 +982,43 @@ fn identifier(name: Box) -> Result { Identifier::new(name).map_err(|error| ValidationError::IdentifierError { error }.into()) } -fn check_scheduled_reducers_exist( - tables: &IdentifierMap, +/// Check that every [`ScheduleDef`]'s `function_name` refers to a real reducer or procedure +/// and that the function's arguments are appropriate for the table, +/// then record the scheduled function's [`FunctionKind`] in the [`ScheduleDef`]. 
+fn check_scheduled_functions_exist( + tables: &mut IdentifierMap, reducers: &IndexMap, + procedures: &IndexMap, ) -> Result<()> { + let validate_params = + |params_from_function: &ProductType, table_row_type_ref: AlgebraicTypeRef, function_name: &str| { + if params_from_function.elements.len() == 1 + && params_from_function.elements[0].algebraic_type == table_row_type_ref.into() + { + Ok(()) + } else { + Err(ValidationError::IncorrectScheduledFunctionParams { + function_name: function_name.into(), + function_kind: FunctionKind::Reducer, + expected: AlgebraicType::product([AlgebraicType::Ref(table_row_type_ref)]).into(), + actual: params_from_function.clone().into(), + }) + } + }; tables - .values() + .values_mut() .map(|table| -> Result<()> { - if let Some(schedule) = &table.schedule { - let reducer = reducers.get(&schedule.reducer_name); - if let Some(reducer) = reducer { - if reducer.params.elements.len() == 1 - && reducer.params.elements[0].algebraic_type == table.product_type_ref.into() - { - Ok(()) - } else { - Err(ValidationError::IncorrectScheduledReducerParams { - reducer: (&*schedule.reducer_name).into(), - expected: AlgebraicType::product([AlgebraicType::Ref(table.product_type_ref)]).into(), - actual: reducer.params.clone().into(), - } - .into()) - } + if let Some(schedule) = &mut table.schedule { + if let Some(reducer) = reducers.get(&schedule.function_name) { + schedule.function_kind = FunctionKind::Reducer; + validate_params(&reducer.params, table.product_type_ref, &reducer.name).map_err(Into::into) + } else if let Some(procedure) = procedures.get(&schedule.function_name) { + schedule.function_kind = FunctionKind::Procedure; + validate_params(&procedure.params, table.product_type_ref, &procedure.name).map_err(Into::into) } else { - Err(ValidationError::MissingScheduledReducer { + Err(ValidationError::MissingScheduledFunction { schedule: schedule.name.clone(), - reducer: schedule.reducer_name.clone(), + function: schedule.function_name.clone(), } 
.into()) } @@ -933,7 +1029,24 @@ fn check_scheduled_reducers_exist( .collect_all_errors() } -fn proccess_misc_exports( +fn check_function_names_are_unique( + reducers: &IndexMap, + procedures: &IndexMap, +) -> Result<()> { + let names = reducers.keys().collect::>(); + procedures + .keys() + .map(|name| -> Result<()> { + if names.contains(name) { + Err(ValidationError::DuplicateFunctionName { name: name.clone() }.into()) + } else { + Ok(()) + } + }) + .collect_all_errors() +} + +fn check_non_procedure_misc_exports( misc_exports: Vec, validator: &ModuleValidator, tables: &mut IdentifierMap, @@ -993,7 +1106,8 @@ mod tests { }; use crate::def::{validate::Result, ModuleDef}; use crate::def::{ - BTreeAlgorithm, ConstraintData, ConstraintDef, DirectAlgorithm, IndexDef, SequenceDef, UniqueConstraintData, + BTreeAlgorithm, ConstraintData, ConstraintDef, DirectAlgorithm, FunctionKind, IndexDef, SequenceDef, + UniqueConstraintData, }; use crate::error::*; use crate::type_for_generate::ClientCodegenError; @@ -1210,9 +1324,13 @@ mod tests { assert_eq!(delivery_def.columns[2].ty, AlgebraicType::U64); assert_eq!(delivery_def.schedule.as_ref().unwrap().at_column, 1.into()); assert_eq!( - &delivery_def.schedule.as_ref().unwrap().reducer_name[..], + &delivery_def.schedule.as_ref().unwrap().function_name[..], "check_deliveries" ); + assert_eq!( + delivery_def.schedule.as_ref().unwrap().function_kind, + FunctionKind::Reducer + ); assert_eq!(delivery_def.primary_key, Some(ColId(2))); assert_eq!(def.typespace.get(product_type_ref), Some(&product_type)); @@ -1660,9 +1778,9 @@ mod tests { .finish(); let result: Result = builder.finish().try_into(); - expect_error_matching!(result, ValidationError::MissingScheduledReducer { schedule, reducer } => { + expect_error_matching!(result, ValidationError::MissingScheduledFunction { schedule, function} => { &schedule[..] 
== "Deliveries_sched" && - reducer == &expect_identifier("check_deliveries") + function == &expect_identifier("check_deliveries") }); } @@ -1688,8 +1806,9 @@ mod tests { builder.add_reducer("check_deliveries", ProductType::from([("a", AlgebraicType::U64)]), None); let result: Result = builder.finish().try_into(); - expect_error_matching!(result, ValidationError::IncorrectScheduledReducerParams { reducer, expected, actual } => { - &reducer[..] == "check_deliveries" && + expect_error_matching!(result, ValidationError::IncorrectScheduledFunctionParams {function_name, function_kind, expected, actual } => { + &function_name[..] == "check_deliveries" && + *function_kind == FunctionKind::Reducer && expected.0 == AlgebraicType::product([AlgebraicType::Ref(deliveries_product_type)]) && actual.0 == ProductType::from([("a", AlgebraicType::U64)]).into() }); diff --git a/crates/schema/src/error.rs b/crates/schema/src/error.rs index e7526ad095d..5fe11dffe20 100644 --- a/crates/schema/src/error.rs +++ b/crates/schema/src/error.rs @@ -7,7 +7,7 @@ use spacetimedb_sats::{bsatn::DecodeError, AlgebraicType, AlgebraicTypeRef}; use std::borrow::Cow; use std::fmt; -use crate::def::ScopedTypeName; +use crate::def::{FunctionKind, ScopedTypeName}; use crate::identifier::Identifier; use crate::type_for_generate::ClientCodegenError; @@ -108,11 +108,12 @@ pub enum ValidationError { MissingPrimaryKeyUniqueConstraint { column: RawColumnName }, #[error("Table {table} should have a type definition for its product_type_element, but does not")] TableTypeNameMismatch { table: Identifier }, - #[error("Schedule {schedule} refers to a scheduled reducer {reducer} that does not exist")] - MissingScheduledReducer { schedule: Box, reducer: Identifier }, - #[error("Scheduled reducer {reducer} expected to have type {expected}, but has type {actual}")] - IncorrectScheduledReducerParams { - reducer: RawIdentifier, + #[error("Schedule {schedule} refers to a scheduled reducer or procedure {function} that does not 
exist")] + MissingScheduledFunction { schedule: Box, function: Identifier }, + #[error("Scheduled {function_kind} {function_name} expected to have type {expected}, but has type {actual}")] + IncorrectScheduledFunctionParams { + function_name: RawIdentifier, + function_kind: FunctionKind, expected: PrettyAlgebraicType, actual: PrettyAlgebraicType, }, @@ -120,6 +121,22 @@ pub enum ValidationError { TableNameReserved { table: Identifier }, #[error("Row-level security invalid: `{error}`, query: `{sql}")] InvalidRowLevelQuery { sql: String, error: String }, + #[error("Name {name} is used for multiple reducers and/or procedures")] + DuplicateFunctionName { name: Identifier }, + #[error("Procedure {procedure} lists non-existent reducer {reducer} as its `on_abort` handler")] + MissingOnAbortHandler { procedure: Identifier, reducer: Identifier }, + #[error("Procedure {procedure} lists another procedure {other_procedure} as its `on_abort` handler, but `on_abort` handlers must be reducers")] + OnAbortHandlerIsProcedure { + procedure: Identifier, + other_procedure: Identifier, + }, + #[error("Expected reducer {reducer_name} to have type {expected} because it is the `on_abort` handler for procedure {procedure_name}, but it has type {actual}")] + IncorrectOnAbortReducerParams { + reducer_name: RawIdentifier, + procedure_name: RawIdentifier, + expected: PrettyAlgebraicType, + actual: PrettyAlgebraicType, + }, #[error("Failed to deserialize default value for table {table} column {col_id}: {err}")] ColumnDefaultValueMalformed { table: RawIdentifier, @@ -173,6 +190,14 @@ pub enum TypeLocation<'a> { position: usize, arg_name: Option>, }, + /// A procedure argument. + ProcedureArg { + procedure_name: Cow<'a, str>, + position: usize, + arg_name: Option>, + }, + /// A procedure return type. + ProcedureReturn { procedure_name: Cow<'a, str> }, /// A type in the typespace. InTypespace { /// The reference to the type within the typespace. 
@@ -193,6 +218,18 @@ impl TypeLocation<'_> { position, arg_name: arg_name.map(|s| s.to_string().into()), }, + TypeLocation::ProcedureArg { + procedure_name, + position, + arg_name, + } => TypeLocation::ProcedureArg { + procedure_name: procedure_name.to_string().into(), + position, + arg_name: arg_name.map(|s| s.to_string().into()), + }, + Self::ProcedureReturn { procedure_name } => TypeLocation::ProcedureReturn { + procedure_name: procedure_name.to_string().into(), + }, // needed to convince rustc this is allowed. TypeLocation::InTypespace { ref_ } => TypeLocation::InTypespace { ref_ }, } @@ -213,6 +250,20 @@ impl fmt::Display for TypeLocation<'_> { } Ok(()) } + TypeLocation::ProcedureArg { + procedure_name, + position, + arg_name, + } => { + write!(f, "procedure `{procedure_name}` argument {position}")?; + if let Some(arg_name) = arg_name { + write!(f, " (`{arg_name}`)")?; + } + Ok(()) + } + TypeLocation::ProcedureReturn { procedure_name } => { + write!(f, "procedure `{procedure_name}` return value") + } TypeLocation::InTypespace { ref_ } => { write!(f, "typespace ref `{ref_}`") } diff --git a/crates/schema/src/schema.rs b/crates/schema/src/schema.rs index f9d104795f2..149cb08eb5b 100644 --- a/crates/schema/src/schema.rs +++ b/crates/schema/src/schema.rs @@ -922,20 +922,20 @@ pub struct ScheduleSchema { /// The name of the schedule. pub schedule_name: Box, - /// The name of the reducer to call. - pub reducer_name: Box, + /// The name of the reducer or procedure to call. + pub function_name: Box, /// The column containing the `ScheduleAt` enum. 
pub at_column: ColId, } impl ScheduleSchema { - pub fn for_test(name: impl Into>, reducer: impl Into>, at: impl Into) -> Self { + pub fn for_test(name: impl Into>, function: impl Into>, at: impl Into) -> Self { Self { table_id: TableId::SENTINEL, schedule_id: ScheduleId::SENTINEL, schedule_name: name.into(), - reducer_name: reducer.into(), + function_name: function.into(), at_column: at.into(), } } @@ -955,7 +955,7 @@ impl Schema for ScheduleSchema { table_id: parent_id, schedule_id: id, schedule_name: (*def.name).into(), - reducer_name: (*def.reducer_name).into(), + function_name: (*def.function_name).into(), at_column: def.at_column, // Ignore def.at_column and id_column. Those are recovered at runtime. } @@ -964,9 +964,9 @@ impl Schema for ScheduleSchema { fn check_compatible(&self, _module_def: &ModuleDef, def: &Self::Def) -> Result<(), anyhow::Error> { ensure_eq!(&self.schedule_name[..], &def.name[..], "Schedule name mismatch"); ensure_eq!( - &self.reducer_name[..], - &def.reducer_name[..], - "Schedule reducer name mismatch" + &self.function_name[..], + &def.function_name[..], + "Schedule function name mismatch" ); Ok(()) } diff --git a/modules/module-test/src/lib.rs b/modules/module-test/src/lib.rs index 7986af92841..e5bbf0fc670 100644 --- a/modules/module-test/src/lib.rs +++ b/modules/module-test/src/lib.rs @@ -1,10 +1,10 @@ #![allow(clippy::disallowed_names)] -use spacetimedb::log; use spacetimedb::spacetimedb_lib::db::raw_def::v9::TableAccess; use spacetimedb::spacetimedb_lib::{self, bsatn}; use spacetimedb::{ - duration, table, ConnectionId, Deserialize, Identity, ReducerContext, SpacetimeType, Table, Timestamp, + duration, table, ConnectionId, Deserialize, Identity, ReducerContext, SpacetimeType, Table, TimeDuration, Timestamp, }; +use spacetimedb::{log, ProcedureContext}; pub type TestAlias = TestA; @@ -437,3 +437,53 @@ fn assert_caller_identity_is_module_identity(ctx: &ReducerContext) { log::info!("Called by the owner {owner}"); } } + 
+#[spacetimedb::procedure] +fn this_is_a_procedure(_ctx: &mut ProcedureContext) { + panic!("nah") +} + +#[derive(SpacetimeType)] +pub struct MyProcedureResult { + foo: String, + bar: i32, +} + +#[spacetimedb::procedure] +fn this_procedure_returns_something(ctx: &mut ProcedureContext) -> MyProcedureResult { + MyProcedureResult { + foo: format!("The time is {}", ctx.timestamp), + bar: 100, + } +} + +#[spacetimedb::procedure] +fn this_procedure_sleeps(ctx: &mut ProcedureContext) -> String { + let before = ctx.timestamp; + let until = before + TimeDuration::from_micros(1000000); + ctx.sleep_until(until); + let after = ctx.timestamp; + format!("Started at {before}, requested to sleep until {until}, woke at {after}") +} + +#[spacetimedb::table( + name = scheduled_procedure_arg, + scheduled(scheduled_procedure) +)] +pub struct ScheduledProcedureArg { + #[primary_key] + #[auto_inc] + scheduled_id: u64, + scheduled_at: spacetimedb::ScheduleAt, + arg: String, +} + +#[spacetimedb::procedure] +fn scheduled_procedure(ctx: &mut ProcedureContext, arg: ScheduledProcedureArg) { + log::info!( + "Scheduled procedure called by {} at {}: {}", + ctx.sender, + ctx.timestamp, + arg.arg + ); +} diff --git a/sdks/rust/src/db_connection.rs b/sdks/rust/src/db_connection.rs index aa5c3f6fb27..0e4a29579e8 100644 --- a/sdks/rust/src/db_connection.rs +++ b/sdks/rust/src/db_connection.rs @@ -1184,7 +1184,8 @@ async fn parse_loop( error: e.error.to_string(), }, ws::ServerMessage::SubscribeApplied(_) => unreachable!("Rust client SDK never sends `SubscribeSingle`, but received a `SubscribeApplied` from the host... huh?"), - ws::ServerMessage::UnsubscribeApplied(_) => unreachable!("Rust client SDK never sends `UnsubscribeSingle`, but received a `UnsubscribeApplied` from the host... huh?") + ws::ServerMessage::UnsubscribeApplied(_) => unreachable!("Rust client SDK never sends `UnsubscribeSingle`, but received a `UnsubscribeApplied` from the host... 
huh?"), + ws::ServerMessage::ProcedureResult(_) => unimplemented!("Rust client SDK procedure support") }) .expect("Failed to send ParsedMessage to main thread"); }