diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 1ad49418be5..85ed92d26a9 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -71,7 +71,7 @@ jobs:
         run: python -m pip install psycopg2-binary xmltodict
       - name: Run smoketests
        # Note: clear_database and replication only work in private
-        run: python -m smoketests ${{ matrix.smoketest_args }} -x clear_database replication
+        run: python -m smoketests ${{ matrix.smoketest_args }} -x clear_database replication teams
       - name: Stop containers (Linux)
        if: always() && runner.os == 'Linux'
        run: docker compose down
diff --git a/Cargo.lock b/Cargo.lock
index e6258738d43..fe75a402c43 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7026,6 +7026,7 @@ dependencies = [
  "itertools 0.12.1",
  "mimalloc",
  "percent-encoding",
+ "pretty_assertions",
  "regex",
  "reqwest 0.12.24",
  "rolldown",
@@ -7050,6 +7051,7 @@ dependencies = [
  "tar",
  "tempfile",
  "termcolor",
+ "termtree",
  "thiserror 1.0.69",
  "tikv-jemalloc-ctl",
  "tikv-jemallocator",
@@ -7107,6 +7109,7 @@ dependencies = [
  "spacetimedb-paths",
  "spacetimedb-schema",
  "tempfile",
+ "thiserror 1.0.69",
  "tokio",
  "tokio-stream",
  "tokio-tungstenite",
@@ -7859,6 +7862,7 @@ name = "spacetimedb-testing"
 version = "1.6.0"
 dependencies = [
  "anyhow",
+ "bytes",
  "clap 4.5.50",
  "duct",
  "env_logger 0.10.2",
@@ -8391,6 +8395,12 @@ dependencies = [
  "windows-sys 0.60.2",
 ]
 
+[[package]]
+name = "termtree"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
+
 [[package]]
 name = "test-client"
 version = "1.6.0"
diff --git a/Cargo.toml b/Cargo.toml
index 30f1ddde62f..adc06df0e21 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -266,6 +266,7 @@ tar = "0.4"
 tempdir = "0.3.7"
 tempfile = "3.20"
 termcolor = "1.2.0"
+termtree = "0.5.1"
 thin-vec = "0.2.13"
 thiserror = "1.0.37"
 tokio = { version = "1.37", features = ["full"] }
diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml
index 365f049b25c..7c0740bd340 100644
--- a/crates/cli/Cargo.toml
+++ b/crates/cli/Cargo.toml
@@ -64,6 +64,7 @@ tabled.workspace = true
 tar.workspace = true
 tempfile.workspace = true
 termcolor.workspace = true
+termtree.workspace = true
 thiserror.workspace = true
 tokio.workspace = true
 tokio-tungstenite.workspace = true
@@ -77,6 +78,9 @@ clap-markdown.workspace = true
 rolldown.workspace = true
 rolldown_utils.workspace = true
 
+[dev-dependencies]
+pretty_assertions.workspace = true
+
 [target.'cfg(not(target_env = "msvc"))'.dependencies]
 tikv-jemallocator = { workspace = true }
 tikv-jemalloc-ctl = { workspace = true }
diff --git a/crates/cli/src/subcommands/delete.rs b/crates/cli/src/subcommands/delete.rs
index e0d5c756137..b05e5d205cd 100644
--- a/crates/cli/src/subcommands/delete.rs
+++ b/crates/cli/src/subcommands/delete.rs
@@ -1,7 +1,15 @@
+use std::io;
+
 use crate::common_args;
 use crate::config::Config;
-use crate::util::{add_auth_header_opt, database_identity, get_auth_header};
+use crate::util::{add_auth_header_opt, database_identity, get_auth_header, y_or_n, AuthHeader};
 use clap::{Arg, ArgMatches};
+use http::StatusCode;
+use itertools::Itertools as _;
+use reqwest::Response;
+use spacetimedb_client_api_messages::http::{DatabaseDeleteConfirmationResponse, DatabaseTree, DatabaseTreeNode};
+use spacetimedb_lib::Hash;
+use tokio::io::AsyncWriteExt as _;
 
 pub fn cli() -> clap::Command {
     clap::Command::new("delete")
@@ -22,11 +30,143 @@ pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     let force = args.get_flag("force");
 
     let identity = database_identity(&config, database, server).await?;
-
-    let builder = reqwest::Client::new().delete(format!("{}/v1/database/{}", config.get_host_url(server)?, identity));
+    let host_url = config.get_host_url(server)?;
+    let request_path = format!("{host_url}/v1/database/{identity}");
     let auth_header = get_auth_header(&mut config, false, server, !force).await?;
-    let builder = add_auth_header_opt(builder, &auth_header);
-    builder.send().await?.error_for_status()?;
+    let client = reqwest::Client::new();
+
+    let response = send_request(&client, &request_path, &auth_header, None).await?;
+    match response.status() {
+        StatusCode::PRECONDITION_REQUIRED => {
+            let confirm = response.json::<DatabaseDeleteConfirmationResponse>().await?;
+            println!("WARNING: Deleting the database {identity} will also delete its children!");
+            if !force {
+                print_database_tree_info(&confirm.database_tree).await?;
+            }
+            if y_or_n(force, "Do you want to proceed deleting above databases?")? {
+                send_request(&client, &request_path, &auth_header, Some(confirm.confirmation_token))
+                    .await?
+                    .error_for_status()?;
+            } else {
+                println!("Aborting");
+            }
+
+            Ok(())
+        }
+        StatusCode::OK => Ok(()),
+        _ => response.error_for_status().map(drop).map_err(Into::into),
+    }
+}
+
+async fn send_request(
+    client: &reqwest::Client,
+    request_path: &str,
+    auth: &AuthHeader,
+    confirmation_token: Option<Hash>,
+) -> Result<Response, reqwest::Error> {
+    let mut builder = client.delete(request_path);
+    builder = add_auth_header_opt(builder, auth);
+    if let Some(token) = confirmation_token {
+        builder = builder.query(&[("token", token)]);
+    }
+    builder.send().await
+}
+
+async fn print_database_tree_info(tree: &DatabaseTree) -> io::Result<()> {
+    tokio::io::stdout()
+        .write_all(as_termtree(tree).to_string().as_bytes())
+        .await
+}
+
+fn as_termtree(tree: &DatabaseTree) -> termtree::Tree<String> {
+    let mut stack: Vec<(&DatabaseTree, bool)> = vec![];
+    stack.push((tree, false));
+
+    let mut built: Vec<termtree::Tree<String>> = <_>::default();
+
+    while let Some((node, visited)) = stack.pop() {
+        if visited {
+            let mut term_node = termtree::Tree::new(fmt_tree_node(&node.root));
+            term_node.leaves = built.drain(built.len() - node.children.len()..).collect();
+            term_node.leaves.reverse();
+            built.push(term_node);
+        } else {
+            stack.push((node, true));
+            stack.extend(node.children.iter().rev().map(|child| (child, false)));
+        }
+    }
+
+    built
+        .pop()
+        .expect("database tree contains a root and we pushed it last")
+}
+
+fn fmt_tree_node(node: &DatabaseTreeNode) -> String {
+    format!(
+        "{}{}",
+        node.database_identity,
+        if node.database_names.is_empty() {
+            <_>::default()
+        } else {
+            format!(": {}", node.database_names.iter().join(", "))
+        }
+    )
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use spacetimedb_client_api_messages::http::{DatabaseTree, DatabaseTreeNode};
+    use spacetimedb_lib::{sats::u256, Identity};
-
-    Ok(())
+    #[test]
+    fn render_termtree() {
+        let tree = DatabaseTree {
+            root: DatabaseTreeNode {
+                database_identity: Identity::ONE,
+                database_names: ["parent".into()].into(),
+            },
+            children: vec![
+                DatabaseTree {
+                    root: DatabaseTreeNode {
+                        database_identity: Identity::from_u256(u256::new(2)),
+                        database_names: ["child".into()].into(),
+                    },
+                    children: vec![
+                        DatabaseTree {
+                            root: DatabaseTreeNode {
+                                database_identity: Identity::from_u256(u256::new(3)),
+                                database_names: ["grandchild".into()].into(),
+                            },
+                            children: vec![],
+                        },
+                        DatabaseTree {
+                            root: DatabaseTreeNode {
+                                database_identity: Identity::from_u256(u256::new(5)),
+                                database_names: [].into(),
+                            },
+                            children: vec![],
+                        },
+                    ],
+                },
+                DatabaseTree {
+                    root: DatabaseTreeNode {
+                        database_identity: Identity::from_u256(u256::new(4)),
+                        database_names: ["sibling".into(), "bro".into()].into(),
+                    },
+                    children: vec![],
+                },
+            ],
+        };
+        pretty_assertions::assert_eq!(
+            "\
+0000000000000000000000000000000000000000000000000000000000000001: parent
+├── 0000000000000000000000000000000000000000000000000000000000000004: bro, sibling
+└── 0000000000000000000000000000000000000000000000000000000000000002: child
+    ├── 0000000000000000000000000000000000000000000000000000000000000005
+    └── 0000000000000000000000000000000000000000000000000000000000000003: grandchild
+",
+            &as_termtree(&tree).to_string()
+        );
+    }
 }
diff --git a/crates/cli/src/subcommands/publish.rs b/crates/cli/src/subcommands/publish.rs
index 8b0b357492f..91433d0c250 100644
--- a/crates/cli/src/subcommands/publish.rs
+++ b/crates/cli/src/subcommands/publish.rs
@@ -1,15 +1,16 @@
+use anyhow::{ensure, Context};
 use clap::Arg;
 use clap::ArgAction::{Set, SetTrue};
 use clap::ArgMatches;
 use reqwest::{StatusCode, Url};
 use spacetimedb_client_api_messages::name::{is_identity, parse_database_name, PublishResult};
-use spacetimedb_client_api_messages::name::{PrePublishResult, PrettyPrintStyle, PublishOp};
+use spacetimedb_client_api_messages::name::{DatabaseNameError, PrePublishResult, PrettyPrintStyle, PublishOp};
 use std::path::PathBuf;
 use std::{env, fs};
 
 use crate::config::Config;
 use crate::util::{add_auth_header_opt, get_auth_header, AuthHeader, ResponseExt};
-use crate::util::{decode_identity, unauth_error_context, y_or_n};
+use crate::util::{decode_identity, y_or_n};
 use crate::{build, common_args};
 
 pub fn cli() -> clap::Command {
@@ -75,6 +76,17 @@ pub fn cli() -> clap::Command {
         .arg(
             common_args::anonymous()
         )
+        .arg(
+            Arg::new("parent")
+                .help("Domain or identity of a parent for this database")
+                .long("parent")
+                .long_help(
+"A valid domain or identity of an existing database that should be the parent of this database.
+
+If a parent is given, the new database inherits the team permissions from the parent.
+A parent can only be set when a database is created, not when it is updated."
+                )
+        )
         .arg(
             Arg::new("name|identity")
                 .help("A valid domain or identity for this database")
@@ -106,6 +118,7 @@ pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     let build_options = args.get_one::<String>("build_options").unwrap();
     let num_replicas = args.get_one::("num_replicas");
     let break_clients_flag = args.get_flag("break_clients");
+    let parent = args.get_one::<String>("parent");
 
     // If the user didn't specify an identity and we didn't specify an anonymous identity, then
     // we want to use the default identity
     // easily create a new identity with an email
     let auth_header = get_auth_header(&mut config, anon_identity, server, !force).await?;
+
+    let (name_or_identity, parent) =
+        validate_name_and_parent(name_or_identity.map(String::as_str), parent.map(String::as_str))?;
+
     if !path_to_project.exists() {
         return Err(anyhow::anyhow!(
             "Project path does not exist: {}",
@@ -152,14 +168,11 @@ pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     );
 
     let client = reqwest::Client::new();
-    // If a domain or identity was provided, we should locally make sure it looks correct and
+    // If a name was given, ensure to percent-encode it.
+    // We also use PUT with a name or identity, and POST otherwise.
     let mut builder = if let Some(name_or_identity) = name_or_identity {
-        if !is_identity(name_or_identity) {
-            parse_database_name(name_or_identity)?;
-        }
         let encode_set = const { &percent_encoding::NON_ALPHANUMERIC.remove(b'_').remove(b'-') };
         let domain = percent_encoding::percent_encode(name_or_identity.as_bytes(), encode_set);
-
         let mut builder = client.put(format!("{database_host}/v1/database/{domain}"));
 
         if !clear_database {
@@ -174,7 +187,7 @@ pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
                 break_clients_flag,
             )
             .await?;
-        };
+        }
 
         builder
     } else {
@@ -204,6 +217,9 @@ pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
         eprintln!("WARNING: Use of unstable option `--num-replicas`.\n");
         builder = builder.query(&[("num_replicas", *n)]);
     }
+    if let Some(parent) = parent {
+        builder = builder.query(&[("parent", parent)]);
+    }
 
     println!("Publishing module...");
@@ -220,18 +236,6 @@ pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     }
 
     let res = builder.body(program_bytes).send().await?;
-    if res.status() == StatusCode::UNAUTHORIZED && !anon_identity {
-        // If we're not in the `anon_identity` case, then we have already forced the user to log in above (using `get_auth_header`), so this should be safe to unwrap.
-        let token = config.spacetimedb_token().unwrap();
-        let identity = decode_identity(token)?;
-        let err = res.text().await?;
-        return unauth_error_context(
-            Err(anyhow::anyhow!(err)),
-            &identity,
-            config.server_nick_or_host(server)?,
-        );
-    }
-
     let response: PublishResult = res.json_or_error().await?;
     match response {
         PublishResult::Success {
@@ -270,6 +274,47 @@ pub async fn exec(mut config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+fn validate_name_or_identity(name_or_identity: &str) -> Result<(), DatabaseNameError> {
+    if is_identity(name_or_identity) {
+        Ok(())
+    } else {
+        parse_database_name(name_or_identity).map(drop)
+    }
+}
+
+fn invalid_parent_name(name: &str) -> String {
+    format!("invalid parent database name `{name}`")
+}
+
+fn validate_name_and_parent<'a>(
+    name: Option<&'a str>,
+    parent: Option<&'a str>,
+) -> anyhow::Result<(Option<&'a str>, Option<&'a str>)> {
+    if let Some(parent) = parent.as_ref() {
+        validate_name_or_identity(parent).with_context(|| invalid_parent_name(parent))?;
+    }
+
+    match name {
+        Some(name) => match name.split_once('/') {
+            Some((parent_alt, child)) => {
+                ensure!(
+                    parent.is_none() || parent.is_some_and(|parent| parent == parent_alt),
+                    "cannot specify both --parent and /"
+                );
+                validate_name_or_identity(parent_alt).with_context(|| invalid_parent_name(parent_alt))?;
+                validate_name_or_identity(child)?;
+
+                Ok((Some(child), Some(parent_alt)))
+            }
+            None => {
+                validate_name_or_identity(name)?;
+                Ok((Some(name), parent))
+            }
+        },
+        None => Ok((None, parent)),
+    }
+}
+
 /// Determine the pretty print style based on the NO_COLOR environment variable.
 ///
 /// See: https://no-color.org
@@ -293,16 +338,7 @@ async fn apply_pre_publish_if_needed(
     auth_header: &AuthHeader,
     break_clients_flag: bool,
 ) -> Result {
-    if let Some(pre) = call_pre_publish(
-        client,
-        base_url,
-        &domain.to_string(),
-        host_type,
-        program_bytes,
-        auth_header,
-    )
-    .await?
-    {
+    if let Some(pre) = call_pre_publish(client, base_url, domain, host_type, program_bytes, auth_header).await? {
        println!("{}", pre.migrate_plan);
 
        if pre.break_clients
@@ -359,3 +395,67 @@ async fn call_pre_publish(
     let pre_publish_result: PrePublishResult = res.json_or_error().await?;
     Ok(Some(pre_publish_result))
 }
+
+#[cfg(test)]
+mod tests {
+    use pretty_assertions::assert_matches;
+    use spacetimedb_lib::Identity;
+
+    use super::*;
+
+    #[test]
+    fn validate_none_arguments_returns_none_values() {
+        assert_matches!(validate_name_and_parent(None, None), Ok((None, None)));
+        assert_matches!(validate_name_and_parent(Some("foo"), None), Ok((Some(_), None)));
+        assert_matches!(validate_name_and_parent(None, Some("foo")), Ok((None, Some(_))));
+    }
+
+    #[test]
+    fn validate_valid_arguments_returns_arguments() {
+        let name = "child";
+        let parent = "parent";
+        let result = (Some(name), Some(parent));
+        assert_matches!(
+            validate_name_and_parent(Some(name), Some(parent)),
+            Ok(val) if val == result
+        );
+    }
+
+    #[test]
+    fn validate_parent_and_path_name_returns_error_unless_parent_equal() {
+        assert_matches!(
+            validate_name_and_parent(Some("parent/child"), Some("parent")),
+            Ok((Some("child"), Some("parent")))
+        );
+        assert_matches!(validate_name_and_parent(Some("parent/child"), Some("cousin")), Err(_));
+    }
+
+    #[test]
+    fn validate_more_than_two_path_segments_are_an_error() {
+        assert_matches!(validate_name_and_parent(Some("proc/net/tcp"), None), Err(_));
+        assert_matches!(validate_name_and_parent(Some("proc//net"), None), Err(_));
+    }
+
+    #[test]
+    fn validate_trailing_slash_is_an_error() {
+        assert_matches!(validate_name_and_parent(Some("foo//"), None), Err(_));
+        assert_matches!(validate_name_and_parent(Some("foo/bar/"), None), Err(_));
+    }
+
+    #[test]
+    fn validate_parent_cant_have_slash() {
+        assert_matches!(validate_name_and_parent(Some("child"), Some("par/ent")), Err(_));
+        assert_matches!(validate_name_and_parent(Some("child"), Some("parent/")), Err(_));
+    }
+
+    #[test]
+    fn validate_name_or_parent_can_be_identities() {
+        let parent = Identity::ZERO.to_string();
+        let child = Identity::ONE.to_string();
+
+        assert_matches!(
+            validate_name_and_parent(Some(&child), Some(&parent)),
+            Ok(res) if res == (Some(&child), Some(&parent))
+        );
+    }
+}
diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs
index 74b483c54e6..df398ee30b8 100644
--- a/crates/cli/src/util.rs
+++ b/crates/cli/src/util.rs
@@ -281,15 +281,6 @@ pub fn y_or_n(force: bool, prompt: &str) -> anyhow::Result<bool> {
     Ok(input == "y" || input == "yes")
 }
 
-pub fn unauth_error_context(res: anyhow::Result, identity: &str, server: &str) -> anyhow::Result {
-    res.with_context(|| {
-        format!(
-            "Identity {identity} is not valid for server {server}.
-Please log back in with `spacetime logout` and then `spacetime login`."
-        )
-    })
-}
-
 pub fn decode_identity(token: &String) -> anyhow::Result<String> {
     // Here, we manually extract and decode the claims from the json web token.
     // We do this without using the `jsonwebtoken` crate because it doesn't seem to have a way to skip signature verification.
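Note on the delete flow above: it is a two-step handshake. The first DELETE to `/v1/database/{identity}` may answer 428 Precondition Required with the affected subtree and a confirmation token, and a second DELETE echoes that token back as a query parameter. Below is a minimal sketch of a standalone client driving the same handshake. It assumes only the endpoint shape and types introduced in this patch (`DatabaseDeleteConfirmationResponse`, the `token` query parameter); the function name, the `host` argument, and the printing are illustrative, and auth headers are omitted for brevity.

use reqwest::{Client, StatusCode};
use spacetimedb_client_api_messages::http::DatabaseDeleteConfirmationResponse;

// Sketch only: drive the confirmation-token delete flow outside the CLI.
// `delete_with_confirmation` and `host` are hypothetical; the endpoint,
// status code, and types come from the patch above.
async fn delete_with_confirmation(host: &str, identity: &str) -> anyhow::Result<()> {
    let client = Client::new();
    let url = format!("{host}/v1/database/{identity}");

    // First attempt, without a token. A database that has children answers 428
    // and returns the subtree that would be deleted plus a confirmation token.
    let res = client.delete(&url).send().await?;
    if res.status() == StatusCode::PRECONDITION_REQUIRED {
        let confirm: DatabaseDeleteConfirmationResponse = res.json().await?;
        for node in confirm.database_tree.iter() {
            println!("would delete {}", node.database_identity);
        }
        // Second attempt: echo the token back to acknowledge the cascading delete.
        client
            .delete(&url)
            .query(&[("token", confirm.confirmation_token)])
            .send()
            .await?
            .error_for_status()?;
    } else {
        res.error_for_status()?;
    }
    Ok(())
}

This mirrors what `send_request` in crates/cli/src/subcommands/delete.rs does, with `y_or_n` supplying the interactive confirmation in the CLI.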
diff --git a/crates/client-api-messages/src/http.rs b/crates/client-api-messages/src/http.rs index fe966bf5dfc..02cc75968e6 100644 --- a/crates/client-api-messages/src/http.rs +++ b/crates/client-api-messages/src/http.rs @@ -1,6 +1,9 @@ +use std::collections::BTreeSet; +use std::iter; + use serde::{Deserialize, Serialize}; use spacetimedb_lib::metrics::ExecutionMetrics; -use spacetimedb_lib::ProductType; +use spacetimedb_lib::{Hash, Identity, ProductType}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SqlStmtResult { @@ -27,3 +30,34 @@ impl SqlStmtStats { } } } + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct DatabaseTree { + pub root: DatabaseTreeNode, + pub children: Vec, +} + +impl DatabaseTree { + pub fn iter(&self) -> impl Iterator + '_ { + let mut stack = vec![self]; + iter::from_fn(move || { + let node = stack.pop()?; + for child in node.children.iter().rev() { + stack.push(child); + } + Some(&node.root) + }) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct DatabaseTreeNode { + pub database_identity: Identity, + pub database_names: BTreeSet, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct DatabaseDeleteConfirmationResponse { + pub database_tree: DatabaseTree, + pub confirmation_token: Hash, +} diff --git a/crates/client-api/Cargo.toml b/crates/client-api/Cargo.toml index be32c90e9c3..59c9e636d94 100644 --- a/crates/client-api/Cargo.toml +++ b/crates/client-api/Cargo.toml @@ -53,6 +53,7 @@ scopeguard.workspace = true serde_with.workspace = true async-stream.workspace = true humantime.workspace = true +thiserror.workspace = true [target.'cfg(not(target_env = "msvc"))'.dependencies] jemalloc_pprof.workspace = true diff --git a/crates/client-api/src/lib.rs b/crates/client-api/src/lib.rs index fd426bb154c..b78db0ae985 100644 --- a/crates/client-api/src/lib.rs +++ b/crates/client-api/src/lib.rs @@ -1,8 +1,11 @@ +use std::future::Future; use std::num::NonZeroU8; use std::sync::Arc; +use anyhow::anyhow; use async_trait::async_trait; use axum::response::ErrorResponse; +use bytes::Bytes; use http::StatusCode; use spacetimedb::client::ClientActorIndex; @@ -16,6 +19,7 @@ use spacetimedb_client_api_messages::name::{DomainName, InsertDomainResult, Regi use spacetimedb_lib::{ProductTypeElement, ProductValue}; use spacetimedb_paths::server::ModuleLogsDir; use spacetimedb_schema::auto_migrate::{MigrationPolicy, PrettyPrintStyle}; +use thiserror::Error; use tokio::sync::watch; pub mod auth; @@ -162,13 +166,22 @@ pub struct DatabaseDef { /// The [`Identity`] the database shall have. pub database_identity: Identity, /// The compiled program of the database module. - pub program_bytes: Vec, + pub program_bytes: Bytes, /// The desired number of replicas the database shall have. /// /// If `None`, the edition default is used. pub num_replicas: Option, /// The host type of the supplied program. pub host_type: HostType, + pub parent: Option, +} + +/// Parameters for resetting a database via [`ControlStateDelegate::clear_database`]. +pub struct DatabaseResetDef { + pub database_identity: Identity, + pub program_bytes: Option, + pub num_replicas: Option, + pub host_type: Option, } /// API of the SpacetimeDB control plane. @@ -240,6 +253,10 @@ pub trait ControlStateWriteAccess: Send + Sync { async fn delete_database(&self, caller_identity: &Identity, database_identity: &Identity) -> anyhow::Result<()>; + /// Remove all data from a database, and reset it according to the + /// given [DatabaseResetDef]. 
+ async fn reset_database(&self, caller_identity: &Identity, spec: DatabaseResetDef) -> anyhow::Result<()>; + // Energy async fn add_energy(&self, identity: &Identity, amount: EnergyQuanta) -> anyhow::Result<()>; async fn withdraw_energy(&self, identity: &Identity, amount: EnergyQuanta) -> anyhow::Result<()>; @@ -339,6 +356,10 @@ impl ControlStateWriteAccess for Arc { (**self).delete_database(caller_identity, database_identity).await } + async fn reset_database(&self, caller_identity: &Identity, spec: DatabaseResetDef) -> anyhow::Result<()> { + (**self).reset_database(caller_identity, spec).await + } + async fn add_energy(&self, identity: &Identity, amount: EnergyQuanta) -> anyhow::Result<()> { (**self).add_energy(identity, amount).await } @@ -395,6 +416,74 @@ impl NodeDelegate for Arc { } } +#[derive(Debug, Error)] +pub enum Unauthorized { + #[error("{subject} is not authorized to perform {action:?}")] + Unauthorized { + subject: Identity, + action: Action, + #[source] + source: Option, + }, + #[error("authorization failed due to internal error")] + InternalError(#[from] anyhow::Error), +} + +impl Unauthorized { + pub fn into_response(self) -> ErrorResponse { + match self { + unauthorized @ Self::Unauthorized { .. } => { + (StatusCode::UNAUTHORIZED, format!("{:#}", anyhow!(unauthorized))).into() + } + Self::InternalError(e) => log_and_500(e), + } + } +} + +#[derive(Debug)] +pub enum Action { + CreateDatabase { parent: Option }, + UpdateDatabase, + ResetDatabase, + DeleteDatabase, + RenameDatabase, + ViewModuleLogs, +} + +pub trait Authorization { + fn authorize_action( + &self, + subject: Identity, + database: Identity, + action: Action, + ) -> impl Future> + Send; + + fn authorize_sql( + &self, + subject: Identity, + database: Identity, + ) -> impl Future> + Send; +} + +impl Authorization for Arc { + fn authorize_action( + &self, + subject: Identity, + database: Identity, + action: Action, + ) -> impl Future> + Send { + (**self).authorize_action(subject, database, action) + } + + fn authorize_sql( + &self, + subject: Identity, + database: Identity, + ) -> impl Future> + Send { + (**self).authorize_sql(subject, database) + } +} + pub fn log_and_500(e: impl std::fmt::Display) -> ErrorResponse { log::error!("internal error: {e:#}"); (StatusCode::INTERNAL_SERVER_ERROR, format!("{e:#}")).into() diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index 7be200d5c26..8370c6546ce 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; +use std::env; use std::num::NonZeroU8; use std::str::FromStr; use std::time::Duration; @@ -8,7 +10,9 @@ use crate::auth::{ }; use crate::routes::subscribe::generate_random_connection_id; pub use crate::util::{ByteStringBody, NameOrIdentity}; -use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate}; +use crate::{ + log_and_500, Action, Authorization, ControlStateDelegate, DatabaseDef, DatabaseResetDef, NodeDelegate, Unauthorized, +}; use axum::body::{Body, Bytes}; use axum::extract::{Path, Query, State}; use axum::response::{ErrorResponse, IntoResponse}; @@ -20,24 +24,46 @@ use http::StatusCode; use serde::Deserialize; use spacetimedb::database_logger::DatabaseLogger; use spacetimedb::host::module_host::ClientConnectedError; -use spacetimedb::host::ReducerCallError; use spacetimedb::host::ReducerOutcome; -use spacetimedb::host::UpdateDatabaseResult; use spacetimedb::host::{FunctionArgs, MigratePlanResult}; +use 
spacetimedb::host::{ReducerCallError, UpdateDatabaseResult}; use spacetimedb::identity::Identity; use spacetimedb::messages::control_db::{Database, HostType}; +use spacetimedb_client_api_messages::http::SqlStmtResult; use spacetimedb_client_api_messages::name::{ self, DatabaseName, DomainName, MigrationPolicy, PrePublishResult, PrettyPrintStyle, PublishOp, PublishResult, }; use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9; -use spacetimedb_lib::identity::AuthCtx; -use spacetimedb_lib::{sats, ProductValue, Timestamp}; +use spacetimedb_lib::{sats, Hash, ProductValue, Timestamp}; use spacetimedb_schema::auto_migrate::{ MigrationPolicy as SchemaMigrationPolicy, MigrationToken, PrettyPrintStyle as AutoMigratePrettyPrintStyle, }; use super::subscribe::{handle_websocket, HasWebSocketOptions}; +fn require_spacetime_auth_for_creation() -> bool { + env::var("TEMP_REQUIRE_SPACETIME_AUTH").is_ok_and(|v| !v.is_empty()) +} + +// A hacky function to let us restrict database creation on maincloud. +fn allow_creation(auth: &SpacetimeAuth) -> Result<(), ErrorResponse> { + if !require_spacetime_auth_for_creation() { + return Ok(()); + } + if auth.claims.issuer.trim_end_matches('/') == "https://auth.spacetimedb.com" { + Ok(()) + } else { + log::trace!( + "Rejecting creation request because auth issuer is {}", + auth.claims.issuer + ); + Err(( + StatusCode::UNAUTHORIZED, + "To create a database, you must be logged in with a SpacetimeDB account.", + ) + .into()) + } +} #[derive(Deserialize)] pub struct CallParams { name_or_identity: NameOrIdentity, @@ -295,7 +321,7 @@ pub async fn logs( Extension(auth): Extension, ) -> axum::response::Result where - S: ControlStateDelegate + NodeDelegate, + S: ControlStateDelegate + NodeDelegate + Authorization, { // You should not be able to read the logs from a database that you do not own // so, unless you are the owner, this will fail. @@ -305,17 +331,10 @@ where .await? .ok_or(NO_SUCH_DATABASE)?; - if database.owner_identity != auth.claims.identity { - return Err(( - StatusCode::BAD_REQUEST, - format!( - "Identity does not own database, expected: {} got: {}", - database.owner_identity.to_hex(), - auth.claims.identity.to_hex() - ), - ) - .into()); - } + worker_ctx + .authorize_action(auth.claims.identity, database.database_identity, Action::ViewModuleLogs) + .await + .map_err(Unauthorized::into_response)?; let replica = worker_ctx .get_leader_replica_by_database(database.id) @@ -402,7 +421,7 @@ pub async fn sql_direct( sql: String, ) -> axum::response::Result>> where - S: NodeDelegate + ControlStateDelegate, + S: NodeDelegate + ControlStateDelegate + Authorization, { // Anyone is authorized to execute SQL queries. The SQL engine will determine // which queries this identity is allowed to execute against the database. @@ -412,8 +431,10 @@ where .await? 
.ok_or(NO_SUCH_DATABASE)?; - let auth = AuthCtx::new(database.owner_identity, caller_identity); - log::debug!("auth: {auth:?}"); + let auth = worker_ctx + .authorize_sql(caller_identity, database.database_identity) + .await + .map_err(Unauthorized::into_response)?; let host = worker_ctx .leader(database.id) @@ -432,7 +453,7 @@ pub async fn sql( body: String, ) -> axum::response::Result where - S: NodeDelegate + ControlStateDelegate, + S: NodeDelegate + ControlStateDelegate + Authorization, { let json = sql_direct(worker_ctx, name_or_identity, params, auth.claims.identity, body).await?; @@ -483,6 +504,57 @@ pub async fn get_names( Ok(axum::Json(response)) } +#[derive(Deserialize)] +pub struct ResetDatabaseParams { + name_or_identity: NameOrIdentity, +} + +#[derive(Deserialize)] +pub struct ResetDatabaseQueryParams { + num_replicas: Option, + #[serde(default)] + host_type: HostType, +} + +pub async fn reset( + State(ctx): State, + Path(ResetDatabaseParams { name_or_identity }): Path, + Query(ResetDatabaseQueryParams { + num_replicas, + host_type, + }): Query, + Extension(auth): Extension, + program_bytes: Option, +) -> axum::response::Result> { + let database_identity = name_or_identity.resolve(&ctx).await?; + let database = worker_ctx_find_database(&ctx, &database_identity) + .await? + .ok_or(NO_SUCH_DATABASE)?; + + ctx.authorize_action(auth.claims.identity, database.database_identity, Action::ResetDatabase) + .await + .map_err(Unauthorized::into_response)?; + + let num_replicas = num_replicas.map(validate_replication_factor).transpose()?.flatten(); + ctx.reset_database( + &auth.claims.identity, + DatabaseResetDef { + database_identity, + program_bytes, + num_replicas, + host_type: Some(host_type), + }, + ) + .await + .map_err(log_and_500)?; + + Ok(axum::Json(PublishResult::Success { + domain: name_or_identity.name().cloned(), + database_identity, + op: PublishOp::Updated, + })) +} + #[derive(Deserialize)] pub struct PublishDatabaseParams { name_or_identity: Option, @@ -497,41 +569,15 @@ pub struct PublishDatabaseQueryParams { /// /// Users obtain such a hash via the `/database/:name_or_identity/pre-publish POST` route. /// This is a safeguard to require explicit approval for updates which will break clients. - token: Option, + token: Option, #[serde(default)] policy: MigrationPolicy, #[serde(default)] host_type: HostType, + parent: Option, } -use spacetimedb_client_api_messages::http::SqlStmtResult; -use std::env; - -fn require_spacetime_auth_for_creation() -> bool { - env::var("TEMP_REQUIRE_SPACETIME_AUTH").is_ok_and(|v| !v.is_empty()) -} - -// A hacky function to let us restrict database creation on maincloud. 
-fn allow_creation(auth: &SpacetimeAuth) -> Result<(), ErrorResponse> { - if !require_spacetime_auth_for_creation() { - return Ok(()); - } - if auth.claims.issuer.trim_end_matches('/') == "https://auth.spacetimedb.com" { - Ok(()) - } else { - log::trace!( - "Rejecting creation request because auth issuer is {}", - auth.claims.issuer - ); - Err(( - StatusCode::UNAUTHORIZED, - "To create a database, you must be logged in with a SpacetimeDB account.", - ) - .into()) - } -} - -pub async fn publish( +pub async fn publish( State(ctx): State, Path(PublishDatabaseParams { name_or_identity }): Path, Query(PublishDatabaseQueryParams { @@ -540,144 +586,215 @@ pub async fn publish( token, policy, host_type, + parent, }): Query, Extension(auth): Extension, - body: Bytes, + program_bytes: Bytes, ) -> axum::response::Result> { - // You should not be able to publish to a database that you do not own - // so, unless you are the owner, this will fail. - - let (database_identity, db_name) = match &name_or_identity { - Some(noa) => match noa.try_resolve(&ctx).await.map_err(log_and_500)? { - Ok(resolved) => (resolved, noa.name()), - Err(name) => { - // `name_or_identity` was a `NameOrIdentity::Name`, but no record - // exists yet. Create it now with a fresh identity. - allow_creation(&auth)?; - let database_auth = SpacetimeAuth::alloc(&ctx).await?; - let database_identity = database_auth.claims.identity; - let tld: name::Tld = name.clone().into(); - let tld = match ctx - .register_tld(&auth.claims.identity, tld) - .await - .map_err(log_and_500)? - { - name::RegisterTldResult::Success { domain } - | name::RegisterTldResult::AlreadyRegistered { domain } => domain, - name::RegisterTldResult::Unauthorized { .. } => { - return Err(( - StatusCode::UNAUTHORIZED, - axum::Json(PublishResult::PermissionDenied { name: name.clone() }), - ) - .into()) - } - }; - let res = ctx - .create_dns_record(&auth.claims.identity, &tld.into(), &database_identity) - .await - .map_err(log_and_500)?; - match res { - name::InsertDomainResult::Success { .. } => {} - name::InsertDomainResult::TldNotRegistered { .. } - | name::InsertDomainResult::PermissionDenied { .. } => { - return Err(log_and_500("impossible: we just registered the tld")) - } - name::InsertDomainResult::OtherError(e) => return Err(log_and_500(e)), - } - (database_identity, Some(name)) + // If `clear`, check that the database exists and delegate to `reset`. + // If it doesn't exist, ignore the `clear` parameter. + // TODO: Replace with actual redirect at the next possible version bump. + if clear { + let name_or_identity = name_or_identity + .as_ref() + .ok_or_else(|| bad_request("Clear database requires database name or identity".into()))?; + if let Ok(identity) = name_or_identity.try_resolve(&ctx).await.map_err(log_and_500)? 
{ + if ctx.get_database_by_identity(&identity).map_err(log_and_500)?.is_some() { + return reset( + State(ctx), + Path(ResetDatabaseParams { + name_or_identity: name_or_identity.clone(), + }), + Query(ResetDatabaseQueryParams { + num_replicas, + host_type, + }), + Extension(auth), + Some(program_bytes), + ) + .await; } - }, - None => { - let database_auth = SpacetimeAuth::alloc(&ctx).await?; - let database_identity = database_auth.claims.identity; - (database_identity, None) } - }; + } - let policy: SchemaMigrationPolicy = match policy { - MigrationPolicy::BreakClients => { - if let Some(token) = token { - Ok(SchemaMigrationPolicy::BreakClients(token)) - } else { - Err(( - StatusCode::BAD_REQUEST, - "Migration policy is set to `BreakClients`, but no migration token was provided.", - )) - } - } + let (database_identity, db_name) = get_or_create_identity_and_name(&ctx, &auth, name_or_identity.as_ref()).await?; + let maybe_parent_database_identity = match parent.as_ref() { + None => None, + Some(parent) => parent.resolve(&ctx).await.map(Some)?, + }; - MigrationPolicy::Compatible => Ok(SchemaMigrationPolicy::Compatible), - }?; + // Check that the replication factor looks somewhat sane. + let num_replicas = num_replicas.map(validate_replication_factor).transpose()?.flatten(); log::trace!("Publishing to the identity: {}", database_identity.to_hex()); - let op = { - let exists = ctx - .get_database_by_identity(&database_identity) - .map_err(log_and_500)? - .is_some(); - if !exists { + // Check if the database already exists. + let existing = ctx.get_database_by_identity(&database_identity).map_err(log_and_500)?; + match existing.as_ref() { + // If not, check that the we caller is sufficiently authenticated. + None => { allow_creation(&auth)?; + if let Some(parent) = maybe_parent_database_identity { + ctx.authorize_action( + auth.claims.identity, + database_identity, + Action::CreateDatabase { parent: Some(parent) }, + ) + .await + .map_err(Unauthorized::into_response)?; + } } - - if clear && exists { - ctx.delete_database(&auth.claims.identity, &database_identity) + // If yes, authorize via ctx. + Some(database) => { + ctx.authorize_action(auth.claims.identity, database.database_identity, Action::UpdateDatabase) .await - .map_err(log_and_500)?; + .map_err(Unauthorized::into_response)?; } + } - if exists { - PublishOp::Updated - } else { - PublishOp::Created - } + // Indicate in the response whether we created or updated the database. + let publish_op = if existing.is_some() { + PublishOp::Updated + } else { + PublishOp::Created + }; + // If a parent is given, resolve to an existing database. + let parent = if let Some(name_or_identity) = parent { + let identity = name_or_identity + .resolve(&ctx) + .await + .map_err(|_| bad_request(format!("Parent database {name_or_identity} not found").into()))?; + Some(identity) + } else { + None }; - let num_replicas = num_replicas - .map(|n| { - let n = u8::try_from(n).map_err(|_| (StatusCode::BAD_REQUEST, "Replication factor {n} out of bounds"))?; - Ok::<_, ErrorResponse>(NonZeroU8::new(n)) - }) - .transpose()? 
- .flatten(); - + let schema_migration_policy = schema_migration_policy(policy, token)?; let maybe_updated = ctx .publish_database( &auth.claims.identity, DatabaseDef { database_identity, - program_bytes: body.into(), + program_bytes, num_replicas, host_type, + parent, }, - policy, + schema_migration_policy, ) .await .map_err(log_and_500)?; - if let Some(updated) = maybe_updated { - match updated { - UpdateDatabaseResult::AutoMigrateError(errs) => { - return Err((StatusCode::BAD_REQUEST, format!("Database update rejected: {errs}")).into()); - } - UpdateDatabaseResult::ErrorExecutingMigration(err) => { - return Err(( - StatusCode::BAD_REQUEST, - format!("Failed to create or update the database: {err}"), - ) - .into()); - } + match maybe_updated { + Some(UpdateDatabaseResult::AutoMigrateError(errs)) => { + Err(bad_request(format!("Database update rejected: {errs}").into())) + } + Some(UpdateDatabaseResult::ErrorExecutingMigration(err)) => Err(bad_request( + format!("Failed to create or update the database: {err}").into(), + )), + None + | Some( UpdateDatabaseResult::NoUpdateNeeded | UpdateDatabaseResult::UpdatePerformed - | UpdateDatabaseResult::UpdatePerformedWithClientDisconnect => {} + | UpdateDatabaseResult::UpdatePerformedWithClientDisconnect, + ) => Ok(axum::Json(PublishResult::Success { + domain: db_name.cloned(), + database_identity, + op: publish_op, + })), + } +} + +/// Try to resolve `name_or_identity` to an [Identity] and [DatabaseName]. +/// +/// - If the database exists and has a name registered for it, return that. +/// - If the database does not exist, but `name_or_identity` is a name, +/// try to register the name and return alongside a newly allocated [Identity] +/// - Otherwise, if the database does not exist and `name_or_identity` is `None`, +/// allocate a fresh [Identity] and no name. +/// +async fn get_or_create_identity_and_name<'a>( + ctx: &(impl ControlStateDelegate + NodeDelegate), + auth: &SpacetimeAuth, + name_or_identity: Option<&'a NameOrIdentity>, +) -> axum::response::Result<(Identity, Option<&'a DatabaseName>)> { + match name_or_identity { + Some(noi) => match noi.try_resolve(ctx).await.map_err(log_and_500)? { + Ok(resolved) => Ok((resolved, noi.name())), + Err(name) => { + // `name_or_identity` was a `NameOrIdentity::Name`, but no record + // exists yet. Create it now with a fresh identity. + allow_creation(auth)?; + let database_auth = SpacetimeAuth::alloc(ctx).await?; + let database_identity = database_auth.claims.identity; + create_name(ctx, auth, &database_identity, name).await?; + Ok((database_identity, Some(name))) + } + }, + None => { + let database_auth = SpacetimeAuth::alloc(ctx).await?; + let database_identity = database_auth.claims.identity; + Ok((database_identity, None)) } } +} - Ok(axum::Json(PublishResult::Success { - domain: db_name.cloned(), - database_identity, - op, - })) +/// Try to register `name` for database `database_identity`. +async fn create_name( + ctx: &(impl NodeDelegate + ControlStateDelegate), + auth: &SpacetimeAuth, + database_identity: &Identity, + name: &DatabaseName, +) -> axum::response::Result<()> { + let tld: name::Tld = name.clone().into(); + let tld = match ctx + .register_tld(&auth.claims.identity, tld) + .await + .map_err(log_and_500)? + { + name::RegisterTldResult::Success { domain } | name::RegisterTldResult::AlreadyRegistered { domain } => domain, + name::RegisterTldResult::Unauthorized { .. 
} => { + return Err(( + StatusCode::UNAUTHORIZED, + axum::Json(PublishResult::PermissionDenied { name: name.clone() }), + ) + .into()) + } + }; + let res = ctx + .create_dns_record(&auth.claims.identity, &tld.into(), database_identity) + .await + .map_err(log_and_500)?; + match res { + name::InsertDomainResult::Success { .. } => Ok(()), + name::InsertDomainResult::TldNotRegistered { .. } | name::InsertDomainResult::PermissionDenied { .. } => { + Err(log_and_500("impossible: we just registered the tld")) + } + name::InsertDomainResult::OtherError(e) => Err(log_and_500(e)), + } +} + +fn schema_migration_policy( + policy: MigrationPolicy, + token: Option, +) -> axum::response::Result { + const MISSING_TOKEN: &str = "Migration policy is set to `BreakClients`, but no migration token was provided."; + + match policy { + MigrationPolicy::BreakClients => token + .map(SchemaMigrationPolicy::BreakClients) + .ok_or_else(|| bad_request(MISSING_TOKEN.into())), + MigrationPolicy::Compatible => Ok(SchemaMigrationPolicy::Compatible), + } +} + +fn validate_replication_factor(n: usize) -> Result, ErrorResponse> { + let n = u8::try_from(n).map_err(|_| bad_request(format!("Replication factor {n} out of bounds").into()))?; + Ok(NonZeroU8::new(n)) +} + +fn bad_request(message: Cow<'static, str>) -> ErrorResponse { + (StatusCode::BAD_REQUEST, message).into() } #[derive(serde::Deserialize)] @@ -693,12 +810,12 @@ pub struct PrePublishQueryParams { host_type: HostType, } -pub async fn pre_publish( +pub async fn pre_publish( State(ctx): State, Path(PrePublishParams { name_or_identity }): Path, Query(PrePublishQueryParams { style, host_type }): Query, Extension(auth): Extension, - body: Bytes, + program_bytes: Bytes, ) -> axum::response::Result> { // User should not be able to print migration plans for a database that they do not own let database_identity = resolve_and_authenticate(&ctx, &name_or_identity, &auth).await?; @@ -711,9 +828,10 @@ pub async fn pre_publish( .migrate_plan( DatabaseDef { database_identity, - program_bytes: body.into(), + program_bytes, num_replicas: None, host_type, + parent: None, }, style, ) @@ -751,44 +869,41 @@ pub async fn pre_publish( /// Resolves the [`NameOrIdentity`] to a database identity and checks if the /// `auth` identity owns the database. -async fn resolve_and_authenticate( +async fn resolve_and_authenticate( ctx: &S, name_or_identity: &NameOrIdentity, auth: &SpacetimeAuth, ) -> axum::response::Result { let database_identity = name_or_identity.resolve(ctx).await?; - let database = worker_ctx_find_database(ctx, &database_identity) .await? .ok_or(NO_SUCH_DATABASE)?; - if database.owner_identity != auth.claims.identity { - return Err(( - StatusCode::UNAUTHORIZED, - format!( - "Identity does not own database, expected: {} got: {}", - database.owner_identity.to_hex(), - auth.claims.identity.to_hex() - ), - ) - .into()); - } + ctx.authorize_action(auth.claims.identity, database.database_identity, Action::UpdateDatabase) + .await + .map_err(Unauthorized::into_response)?; Ok(database_identity) } #[derive(Deserialize)] pub struct DeleteDatabaseParams { - name_or_identity: NameOrIdentity, + pub name_or_identity: NameOrIdentity, } -pub async fn delete_database( +pub async fn delete_database( State(ctx): State, Path(DeleteDatabaseParams { name_or_identity }): Path, Extension(auth): Extension, ) -> axum::response::Result { let database_identity = name_or_identity.resolve(&ctx).await?; + let Some(_database) = worker_ctx_find_database(&ctx, &database_identity).await? 
else { + return Ok(()); + }; + ctx.authorize_action(auth.claims.identity, database_identity, Action::DeleteDatabase) + .await + .map_err(Unauthorized::into_response)?; ctx.delete_database(&auth.claims.identity, &database_identity) .await .map_err(log_and_500)?; @@ -831,7 +946,7 @@ pub struct SetNamesParams { name_or_identity: NameOrIdentity, } -pub async fn set_names( +pub async fn set_names( State(ctx): State, Path(SetNamesParams { name_or_identity }): Path, Extension(auth): Extension, @@ -854,14 +969,18 @@ pub async fn set_names( )); }; - if database.owner_identity != auth.claims.identity { - return Ok(( - StatusCode::UNAUTHORIZED, - axum::Json(name::SetDomainsResult::NotYourDatabase { - database: database.database_identity, - }), - )); - } + ctx.authorize_action(auth.claims.identity, database.database_identity, Action::RenameDatabase) + .await + .map_err(|e| match e { + Unauthorized::Unauthorized { .. } => ( + StatusCode::UNAUTHORIZED, + axum::Json(name::SetDomainsResult::NotYourDatabase { + database: database.database_identity, + }), + ) + .into(), + Unauthorized::InternalError(e) => log_and_500(e), + })?; for name in &validated_names { if ctx.lookup_identity(name.as_str()).unwrap().is_some() { @@ -948,13 +1067,15 @@ pub struct DatabaseRoutes { pub sql_post: MethodRouter, /// POST: /database/:name_or_identity/pre-publish pub pre_publish: MethodRouter, + /// PUT: /database/:name_or_identity/reset + pub db_reset: MethodRouter, /// GET: /database/: name_or_identity/unstable/timestamp pub timestamp_get: MethodRouter, } impl Default for DatabaseRoutes where - S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions + Clone + 'static, + S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions + Authorization + Clone + 'static, { fn default() -> Self { use axum::routing::{delete, get, post, put}; @@ -973,6 +1094,7 @@ where logs_get: get(logs::), sql_post: post(sql::), pre_publish: post(pre_publish::), + db_reset: put(reset::), timestamp_get: get(get_timestamp::), } } @@ -980,7 +1102,7 @@ where impl DatabaseRoutes where - S: NodeDelegate + ControlStateDelegate + Clone + 'static, + S: NodeDelegate + ControlStateDelegate + Authorization + Clone + 'static, { pub fn into_router(self, ctx: S) -> axum::Router { let db_router = axum::Router::::new() @@ -997,7 +1119,8 @@ where .route("/logs", self.logs_get) .route("/sql", self.sql_post) .route("/unstable/timestamp", self.timestamp_get) - .route("/pre_publish", self.pre_publish); + .route("/pre_publish", self.pre_publish) + .route("/reset", self.db_reset); axum::Router::new() .route("/", self.root_post) diff --git a/crates/client-api/src/routes/identity.rs b/crates/client-api/src/routes/identity.rs index be9adde55f9..c8dc44d2964 100644 --- a/crates/client-api/src/routes/identity.rs +++ b/crates/client-api/src/routes/identity.rs @@ -2,6 +2,7 @@ use std::time::Duration; use axum::extract::{Path, State}; use axum::response::IntoResponse; +use axum::routing::MethodRouter; use http::header::CONTENT_TYPE; use http::StatusCode; use serde::{Deserialize, Serialize}; @@ -64,12 +65,12 @@ impl<'de> serde::Deserialize<'de> for IdentityForUrl { #[derive(Deserialize)] pub struct GetDatabasesParams { - identity: IdentityForUrl, + pub identity: IdentityForUrl, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GetDatabasesResponse { - identities: Vec, + pub identities: Vec, } pub async fn get_databases( @@ -135,15 +136,46 @@ pub async fn get_public_key(State(ctx): State) -> axum::resp )) } -pub fn router() -> axum::Router +/// A struct to allow 
customization of the `/identity` routes. +pub struct IdentityRoutes { + /// POST /identity + pub create_post: MethodRouter, + /// GET /identity/public-key + pub public_key_get: MethodRouter, + /// POST /identity/websocket-tocken + pub websocket_token_post: MethodRouter, + /// GET /identity/:identity/verify + pub verify_get: MethodRouter, + /// GET /identity/:identity/databases + pub databases_get: MethodRouter, +} + +impl Default for IdentityRoutes +where + S: NodeDelegate + ControlStateDelegate + Clone + 'static, +{ + fn default() -> Self { + use axum::routing::{get, post}; + Self { + create_post: post(create_identity::), + public_key_get: get(get_public_key::), + websocket_token_post: post(create_websocket_token::), + verify_get: get(validate_token), + databases_get: get(get_databases::), + } + } +} + +impl IdentityRoutes where S: NodeDelegate + ControlStateDelegate + Clone + 'static, { - use axum::routing::{get, post}; - axum::Router::new() - .route("/", post(create_identity::)) - .route("/public-key", get(get_public_key::)) - .route("/websocket-token", post(create_websocket_token::)) - .route("/:identity/verify", get(validate_token)) - .route("/:identity/databases", get(get_databases::)) + pub fn into_router(self) -> axum::Router { + axum::Router::new() + .route("/", self.create_post) + .route("/public-key", self.public_key_get) + .route("/websocket-token", self.websocket_token_post) + .route("/:identity/verify", self.verify_get) + .route("/:identity/databases", self.databases_get) + } } diff --git a/crates/client-api/src/routes/mod.rs b/crates/client-api/src/routes/mod.rs index f0930eefb4c..08e1e73cb77 100644 --- a/crates/client-api/src/routes/mod.rs +++ b/crates/client-api/src/routes/mod.rs @@ -1,8 +1,7 @@ -use database::DatabaseRoutes; use http::header; use tower_http::cors; -use crate::{ControlStateDelegate, NodeDelegate}; +use crate::{Authorization, ControlStateDelegate, NodeDelegate}; pub mod database; pub mod energy; @@ -13,19 +12,26 @@ pub mod metrics; pub mod prometheus; pub mod subscribe; +use self::{database::DatabaseRoutes, identity::IdentityRoutes}; + /// This API call is just designed to allow clients to determine whether or not they can /// establish a connection to SpacetimeDB. This API call doesn't actually do anything. 
pub async fn ping(_auth: crate::auth::SpacetimeAuthHeader) {} #[allow(clippy::let_and_return)] -pub fn router(ctx: &S, database_routes: DatabaseRoutes, extra: axum::Router) -> axum::Router +pub fn router( + ctx: &S, + database_routes: DatabaseRoutes, + identity_routes: IdentityRoutes, + extra: axum::Router, +) -> axum::Router where - S: NodeDelegate + ControlStateDelegate + Clone + 'static, + S: NodeDelegate + ControlStateDelegate + Authorization + Clone + 'static, { use axum::routing::get; let router = axum::Router::new() .nest("/database", database_routes.into_router(ctx.clone())) - .nest("/identity", identity::router()) + .nest("/identity", identity_routes.into_router()) .nest("/energy", energy::router()) .nest("/prometheus", prometheus::router()) .nest("/metrics", metrics::router()) diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index 457fb0bf96b..1a53b578786 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -49,7 +49,7 @@ use crate::util::websocket::{ CloseCode, CloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, WebSocketUpgrade, WsError, }; use crate::util::{NameOrIdentity, XForwardedFor}; -use crate::{log_and_500, ControlStateDelegate, NodeDelegate}; +use crate::{log_and_500, Authorization, ControlStateDelegate, NodeDelegate, Unauthorized}; #[allow(clippy::declare_interior_mutable_const)] pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::TEXT_PROTOCOL); @@ -106,7 +106,7 @@ pub async fn handle_websocket( ws: WebSocketUpgrade, ) -> axum::response::Result where - S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions, + S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions + Authorization, { if connection_id.is_some() { // TODO: Bump this up to `log::warn!` after removing the client SDKs' uses of that parameter. @@ -125,6 +125,10 @@ where } let db_identity = name_or_identity.resolve(&ctx).await?; + let sql_auth = ctx + .authorize_sql(auth.claims.identity, db_identity) + .await + .map_err(Unauthorized::into_response)?; let (res, ws_upgrade, protocol) = ws.select_protocol([(BIN_PROTOCOL, Protocol::Binary), (TEXT_PROTOCOL, Protocol::Text)]); @@ -218,6 +222,7 @@ where let client = ClientConnection::spawn( client_id, auth.into(), + sql_auth, client_config, leader.replica_id, module_rx, diff --git a/crates/client-api/src/util.rs b/crates/client-api/src/util.rs index c38bf33c0ae..509b891e483 100644 --- a/crates/client-api/src/util.rs +++ b/crates/client-api/src/util.rs @@ -89,7 +89,7 @@ impl NameOrIdentity { /// Otherwise, if `self` is a [`NameOrIdentity::Name`], the [`Identity`] is /// looked up by that name in the SpacetimeDB DNS and returned. /// - /// Errors are returned if [`NameOrIdentity::Name`] the DNS lookup fails. + /// Errors are returned if the DNS lookup fails. /// /// An `Ok` result is itself a [`Result`], which is `Err(DatabaseName)` if the /// given [`NameOrIdentity::Name`] is not registered in the SpacetimeDB DNS, @@ -111,7 +111,7 @@ impl NameOrIdentity { self.try_resolve(ctx) .await .map_err(log_and_500)? 
- .map_err(|_| StatusCode::NOT_FOUND.into()) + .map_err(|name| (StatusCode::NOT_FOUND, format!("Could not resolve database `{name}`")).into()) } } diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index ccc63497bc2..1f2770c426f 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -28,7 +28,7 @@ use spacetimedb_client_api_messages::websocket::{ UnsubscribeMulti, }; use spacetimedb_durability::{DurableOffset, TxOffset}; -use spacetimedb_lib::identity::RequestId; +use spacetimedb_lib::identity::{AuthCtx, RequestId}; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::Identity; use tokio::sync::mpsc::error::{SendError, TrySendError}; @@ -423,6 +423,7 @@ pub struct ClientConnection { sender: Arc, pub replica_id: u64, module_rx: watch::Receiver, + auth: AuthCtx, } impl Deref for ClientConnection { @@ -674,9 +675,11 @@ impl ClientConnection { /// to verify that the database at `module_rx` approves of this connection, /// and should not invoke this method if that call returns an error, /// and pass the returned [`Connected`] as `_proof_of_client_connected_call`. + #[allow(clippy::too_many_arguments)] pub async fn spawn( id: ClientActorId, auth: ConnectionAuthCtx, + sql_auth: AuthCtx, config: ClientConfig, replica_id: u64, mut module_rx: watch::Receiver, @@ -734,6 +737,7 @@ impl ClientConnection { sender, replica_id, module_rx, + auth: sql_auth, }; let actor_fut = actor(this.clone(), receiver); @@ -749,10 +753,12 @@ impl ClientConnection { replica_id: u64, module_rx: watch::Receiver, ) -> Self { + let auth = AuthCtx::new(module_rx.borrow().database_info().database_identity, id.identity); Self { sender: Arc::new(ClientConnectionSender::dummy(id, config, module_rx.clone())), replica_id, module_rx, + auth, } } @@ -842,9 +848,13 @@ impl ClientConnection { let me = self.clone(); self.module() .on_module_thread("subscribe_single", move || { - me.module() - .subscriptions() - .add_single_subscription(me.sender, subscription, timer, None) + me.module().subscriptions().add_single_subscription( + me.sender, + me.auth.clone(), + subscription, + timer, + None, + ) }) .await? } @@ -854,7 +864,7 @@ impl ClientConnection { asyncify(move || { me.module() .subscriptions() - .remove_single_subscription(me.sender, request, timer) + .remove_single_subscription(me.sender, me.auth.clone(), request, timer) }) .await } @@ -869,7 +879,7 @@ impl ClientConnection { .on_module_thread("subscribe_multi", move || { me.module() .subscriptions() - .add_multi_subscription(me.sender, request, timer, None) + .add_multi_subscription(me.sender, me.auth.clone(), request, timer, None) }) .await? } @@ -884,7 +894,7 @@ impl ClientConnection { .on_module_thread("unsubscribe_multi", move || { me.module() .subscriptions() - .remove_multi_subscription(me.sender, request, timer) + .remove_multi_subscription(me.sender, me.auth.clone(), request, timer) }) .await? 
} @@ -894,7 +904,7 @@ impl ClientConnection { asyncify(move || { me.module() .subscriptions() - .add_legacy_subscriber(me.sender, subscription, timer, None) + .add_legacy_subscriber(me.sender, me.auth.clone(), subscription, timer, None) }) .await } @@ -907,7 +917,7 @@ impl ClientConnection { ) -> Result<(), anyhow::Error> { self.module() .one_off_query::( - self.id.identity, + self.auth.clone(), query.to_owned(), self.sender.clone(), message_id.to_owned(), @@ -925,7 +935,7 @@ impl ClientConnection { ) -> Result<(), anyhow::Error> { self.module() .one_off_query::( - self.id.identity, + self.auth.clone(), query.to_owned(), self.sender.clone(), message_id.to_owned(), diff --git a/crates/core/src/db/relational_db.rs b/crates/core/src/db/relational_db.rs index 639ae8f0ba4..7faaa3fc517 100644 --- a/crates/core/src/db/relational_db.rs +++ b/crates/core/src/db/relational_db.rs @@ -624,6 +624,10 @@ impl RelationalDB { self.database_identity } + pub fn owner_identity(&self) -> Identity { + self.owner_identity + } + /// The number of bytes on disk occupied by the durability layer. /// /// If this is an in-memory instance, `Ok(0)` is returned. diff --git a/crates/core/src/host/host_controller.rs b/crates/core/src/host/host_controller.rs index 72d9c6d2e18..b6643d578d3 100644 --- a/crates/core/src/host/host_controller.rs +++ b/crates/core/src/host/host_controller.rs @@ -542,12 +542,7 @@ async fn make_replica_ctx( send_worker_queue.clone(), ))); let downgraded = Arc::downgrade(&subscriptions); - let subscriptions = ModuleSubscriptions::new( - relational_db.clone(), - subscriptions, - send_worker_queue, - database.owner_identity, - ); + let subscriptions = ModuleSubscriptions::new(relational_db.clone(), subscriptions, send_worker_queue); // If an error occurs when evaluating a subscription, // we mark each client that was affected, diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 899f374cff1..377bd1871b7 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -1329,7 +1329,7 @@ impl ModuleHost { #[tracing::instrument(level = "trace", skip_all)] pub async fn one_off_query( &self, - caller_identity: Identity, + auth: AuthCtx, query: String, client: Arc, message_id: Vec, @@ -1340,7 +1340,6 @@ impl ModuleHost { let replica_ctx = self.replica_ctx(); let db = replica_ctx.relational_db.clone(); let subscriptions = replica_ctx.subscriptions.clone(); - let auth = AuthCtx::new(replica_ctx.owner_identity, caller_identity); log::debug!("One-off query: {query}"); let metrics = self .on_module_thread("one_off_query", move || { diff --git a/crates/core/src/sql/ast.rs b/crates/core/src/sql/ast.rs index 898a30c7279..8b76d80a67a 100644 --- a/crates/core/src/sql/ast.rs +++ b/crates/core/src/sql/ast.rs @@ -6,7 +6,6 @@ use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; use spacetimedb_datastore::system_tables::{StRowLevelSecurityFields, ST_ROW_LEVEL_SECURITY_ID}; use spacetimedb_expr::check::SchemaView; use spacetimedb_expr::statement::compile_sql_stmt; -use spacetimedb_lib::db::auth::StAccess; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_primitives::{ColId, TableId}; use spacetimedb_sats::{AlgebraicType, AlgebraicValue}; @@ -492,22 +491,20 @@ impl Deref for SchemaViewer<'_, T> { impl SchemaView for SchemaViewer<'_, T> { fn table_id(&self, name: &str) -> Option { - let AuthCtx { owner, caller } = self.auth; // Get the schema from the in-memory state instead of fetching from the database for speed self.tx 
.table_id_from_name(name) .ok() .flatten() .and_then(|table_id| self.schema_for_table(table_id)) - .filter(|schema| schema.table_access == StAccess::Public || caller == owner) + .filter(|schema| self.auth.has_read_access(schema.table_access)) .map(|schema| schema.table_id) } fn schema_for_table(&self, table_id: TableId) -> Option> { - let AuthCtx { owner, caller } = self.auth; self.tx .get_schema(table_id) - .filter(|schema| schema.table_access == StAccess::Public || caller == owner) + .filter(|schema| self.auth.has_read_access(schema.table_access)) .cloned() } diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index 089148cdd2e..1fac67ef175 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -122,7 +122,7 @@ pub fn execute_sql( let mut tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::Sql); let mut updates = Vec::with_capacity(ast.len()); let res = execute( - &mut DbProgram::new(db, &mut (&mut tx).into(), auth), + &mut DbProgram::new(db, &mut (&mut tx).into(), auth.clone()), ast, sql, &mut updates, @@ -130,7 +130,7 @@ pub fn execute_sql( if res.is_ok() && !updates.is_empty() { let event = ModuleEvent { timestamp: Timestamp::now(), - caller_identity: auth.caller, + caller_identity: auth.caller(), caller_connection_id: None, function_call: ModuleFunctionCall { reducer: String::new(), @@ -249,7 +249,7 @@ pub fn run( } Statement::DML(stmt) => { // An extra layer of auth is required for DML - if auth.caller != auth.owner { + if !auth.has_write_access() { return Err(anyhow!("Only owners are authorized to run SQL DML statements").into()); } @@ -287,7 +287,7 @@ pub fn run( None, ModuleEvent { timestamp: Timestamp::now(), - caller_identity: auth.caller, + caller_identity: auth.caller(), caller_connection_id: None, function_call: ModuleFunctionCall { reducer: String::new(), @@ -510,7 +510,7 @@ pub(crate) mod tests { expected: impl IntoIterator, ) { assert_eq!( - run(db, sql, *auth, None, &mut vec![]) + run(db, sql, auth.clone(), None, &mut vec![]) .unwrap() .rows .into_iter() @@ -1270,19 +1270,25 @@ pub(crate) mod tests { let run = |db, sql, auth, subs| run(db, sql, auth, subs, &mut vec![]); // No row limit, both queries pass. - assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok()); - assert!(run(&db, "SELECT * FROM T", external_auth, None).is_ok()); + assert!(run(&db, "SELECT * FROM T", internal_auth.clone(), None).is_ok()); + assert!(run(&db, "SELECT * FROM T", external_auth.clone(), None).is_ok()); // Set row limit. - assert!(run(&db, "SET row_limit = 4", internal_auth, None).is_ok()); + assert!(run(&db, "SET row_limit = 4", internal_auth.clone(), None).is_ok()); // External query fails. - assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok()); - assert!(run(&db, "SELECT * FROM T", external_auth, None).is_err()); + assert!(run(&db, "SELECT * FROM T", internal_auth.clone(), None).is_ok()); + assert!(run(&db, "SELECT * FROM T", external_auth.clone(), None).is_err()); // Increase row limit. - assert!(run(&db, "DELETE FROM st_var WHERE name = 'row_limit'", internal_auth, None).is_ok()); - assert!(run(&db, "SET row_limit = 5", internal_auth, None).is_ok()); + assert!(run( + &db, + "DELETE FROM st_var WHERE name = 'row_limit'", + internal_auth.clone(), + None + ) + .is_ok()); + assert!(run(&db, "SET row_limit = 5", internal_auth.clone(), None).is_ok()); // Both queries pass. 
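// Illustrative sketch, not part of the patch: the test above exercises the `row_limit`
// st_var, whose enforcement is now gated on the new `SqlPermission::ExceedRowLimit`
// instead of a direct `caller != owner` comparison. With the legacy permission set that
// `AuthCtx::new` installs, only the owner is granted that permission, which is why
// `external_auth` trips the limit while `internal_auth` does not. The helper below is a
// simplified stand-in for `check_row_limit` in crates/core/src/vm.rs.
use spacetimedb_lib::identity::AuthCtx;

/// Returns `true` if a query should be rejected for exceeding the configured row limit.
fn rejected_by_row_limit(auth: &AuthCtx, estimated_rows: u64, row_limit: Option<u64>) -> bool {
    match row_limit {
        // Permission sets granting `ExceedRowLimit` (the owner, under the legacy rules)
        // bypass the check entirely.
        Some(limit) if !auth.exceed_row_limit() => estimated_rows > limit,
        _ => false,
    }
}
// The assertions that follow check this behaviour end to end through the SQL pipeline.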
assert!(run(&db, "SELECT * FROM T", internal_auth, None).is_ok()); @@ -1333,10 +1339,10 @@ pub(crate) mod tests { ..ExecutionMetrics::default() }; - check(&db, "INSERT INTO T (a) VALUES (5)", internal_auth, ins)?; - check(&db, "UPDATE T SET a = 2", internal_auth, upd)?; + check(&db, "INSERT INTO T (a) VALUES (5)", internal_auth.clone(), ins)?; + check(&db, "UPDATE T SET a = 2", internal_auth.clone(), upd)?; assert_eq!( - run(&db, "SELECT * FROM T", internal_auth, None)?.rows, + run(&db, "SELECT * FROM T", internal_auth.clone(), None)?.rows, vec![product!(2u8)] ); check(&db, "DELETE FROM T", internal_auth, del)?; diff --git a/crates/core/src/subscription/execution_unit.rs b/crates/core/src/subscription/execution_unit.rs index 75eb0a0442c..794495d38c7 100644 --- a/crates/core/src/subscription/execution_unit.rs +++ b/crates/core/src/subscription/execution_unit.rs @@ -10,6 +10,7 @@ use crate::util::slow::SlowQueryLogger; use crate::vm::{build_query, TxMode}; use spacetimedb_client_api_messages::websocket::{Compression, QueryUpdate, RowListLen as _, SingleQueryUpdate}; use spacetimedb_datastore::locking_tx_datastore::TxId; +use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::Identity; use spacetimedb_primitives::TableId; use spacetimedb_sats::{u256, ProductValue}; @@ -335,7 +336,7 @@ impl ExecutionUnit { } impl AuthAccess for ExecutionUnit { - fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> { - self.eval_plan.check_auth(owner, caller) + fn check_auth(&self, auth: &AuthCtx) -> Result<(), AuthError> { + self.eval_plan.check_auth(auth) } } diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 41eb62e60ba..ce2b2a93d78 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -50,7 +50,6 @@ pub struct ModuleSubscriptions { /// You will deadlock otherwise. 
subscriptions: Subscriptions, broadcast_queue: BroadcastQueue, - owner_identity: Identity, stats: Arc, } @@ -179,7 +178,6 @@ impl ModuleSubscriptions { relational_db: Arc, subscriptions: Subscriptions, broadcast_queue: BroadcastQueue, - owner_identity: Identity, ) -> Self { let db = &relational_db.database_identity(); let stats = Arc::new(SubscriptionGauges::new(db)); @@ -188,7 +186,6 @@ impl ModuleSubscriptions { relational_db, subscriptions, broadcast_queue, - owner_identity, stats, } } @@ -209,7 +206,6 @@ impl ModuleSubscriptions { db, SubscriptionManager::for_test_without_metrics_arc_rwlock(), send_worker_queue, - Identity::ZERO, ) } @@ -307,6 +303,7 @@ impl ModuleSubscriptions { pub fn add_single_subscription( &self, sender: Arc, + auth: AuthCtx, request: SubscribeSingle, timer: Instant, _assert: Option, @@ -329,9 +326,8 @@ impl ModuleSubscriptions { }; let sql = request.query; - let auth = AuthCtx::new(self.owner_identity, sender.id.identity); - let hash = QueryHash::from_string(&sql, auth.caller, false); - let hash_with_param = QueryHash::from_string(&sql, auth.caller, true); + let hash = QueryHash::from_string(&sql, auth.caller(), false); + let hash_with_param = QueryHash::from_string(&sql, auth.caller(), true); let (tx, tx_offset) = self.begin_tx(Workload::Subscribe); @@ -398,6 +394,7 @@ impl ModuleSubscriptions { pub fn remove_single_subscription( &self, sender: Arc, + auth: AuthCtx, request: Unsubscribe, timer: Instant, ) -> Result, DBError> { @@ -435,7 +432,6 @@ impl ModuleSubscriptions { }; let (tx, tx_offset) = self.begin_tx(Workload::Unsubscribe); - let auth = AuthCtx::new(self.owner_identity, sender.id.identity); let (table_rows, metrics) = return_on_err_with_sql!( self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth, TableUpdateType::Unsubscribe), query.sql(), @@ -471,6 +467,7 @@ impl ModuleSubscriptions { pub fn remove_multi_subscription( &self, sender: Arc, + auth: AuthCtx, request: UnsubscribeMulti, timer: Instant, ) -> Result, DBError> { @@ -518,7 +515,7 @@ impl ModuleSubscriptions { sender.clone(), &removed_queries, &tx, - &AuthCtx::new(self.owner_identity, sender.id.identity), + &auth, TableUpdateType::Unsubscribe, ), send_err_msg, @@ -567,6 +564,7 @@ impl ModuleSubscriptions { fn compile_queries( &self, sender: Identity, + auth: AuthCtx, queries: &[Box], num_queries: usize, metrics: &SubscriptionMetrics, @@ -586,8 +584,6 @@ impl ModuleSubscriptions { query_hashes.push((sql, hash, hash_with_param)); } - let auth = AuthCtx::new(self.owner_identity, sender); - // We always get the db lock before the subscription lock to avoid deadlocks. 
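// Illustrative sketch, not part of the patch: the lock-ordering invariant stated in the
// comment above, with simplified stand-in types. Only the acquisition order matters:
// the datastore transaction is taken first (`begin_tx` in the real code), and only then
// any subscription state.
use std::sync::{Mutex, RwLock};

struct Datastore;       // stand-in for the relational datastore
struct SubscriptionMap; // stand-in for the subscription manager

struct LockOrder {
    db: Mutex<Datastore>,
    subscriptions: RwLock<SubscriptionMap>,
}

impl LockOrder {
    fn locked_in_order(&self) {
        // 1. Datastore first.
        let _tx = self.db.lock().unwrap();
        // 2. Only then the subscription state.
        let _subs = self.subscriptions.write().unwrap();
        // ... evaluate queries and update subscriptions while both are held ...
    }
}
// The real acquisition follows: `begin_tx(Workload::Subscribe)` before any subscription lock.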
let (tx, _tx_offset) = self.begin_tx(Workload::Subscribe); @@ -651,6 +647,7 @@ impl ModuleSubscriptions { pub fn add_multi_subscription( &self, sender: Arc, + auth: AuthCtx, request: SubscribeMulti, timer: Instant, _assert: Option, @@ -683,6 +680,7 @@ impl ModuleSubscriptions { let (queries, auth, tx, compile_timer) = return_on_err!( self.compile_queries( sender.id.identity, + auth, &request.query_strings, num_queries, &subscription_metrics @@ -765,6 +763,7 @@ impl ModuleSubscriptions { pub fn add_legacy_subscriber( &self, sender: Arc, + auth: AuthCtx, subscription: Subscribe, timer: Instant, _assert: Option, @@ -778,6 +777,7 @@ impl ModuleSubscriptions { let (queries, auth, tx, compile_timer) = self.compile_queries( sender.id.identity, + auth, &subscription.query_strings, num_queries, &subscription_metrics, @@ -1069,14 +1069,14 @@ mod tests { db.clone(), SubscriptionManager::for_test_without_metrics_arc_rwlock(), send_worker_queue, - owner, ); + let auth = AuthCtx::new(owner, sender.auth.claims.identity); let subscribe = Subscribe { query_strings: [sql.into()].into(), request_id: 0, }; - module_subscriptions.add_legacy_subscriber(sender, subscribe, Instant::now(), assert)?; + module_subscriptions.add_legacy_subscriber(sender, auth, subscribe, Instant::now(), assert)?; Ok(()) } @@ -1315,25 +1315,27 @@ mod tests { /// Subscribe to a query as a client fn subscribe_single( subs: &ModuleSubscriptions, + auth: AuthCtx, sql: &'static str, sender: Arc, counter: &mut u32, ) -> anyhow::Result<()> { *counter += 1; - subs.add_single_subscription(sender, single_subscribe(sql, *counter), Instant::now(), None)?; + subs.add_single_subscription(sender, auth, single_subscribe(sql, *counter), Instant::now(), None)?; Ok(()) } /// Subscribe to a set of queries as a client fn subscribe_multi( subs: &ModuleSubscriptions, + auth: AuthCtx, queries: &[&'static str], sender: Arc, counter: &mut u32, ) -> anyhow::Result { *counter += 1; let metrics = subs - .add_multi_subscription(sender, multi_subscribe(queries, *counter), Instant::now(), None) + .add_multi_subscription(sender, auth, multi_subscribe(queries, *counter), Instant::now(), None) .map(|metrics| metrics.unwrap_or_default())?; Ok(metrics) } @@ -1341,20 +1343,22 @@ mod tests { /// Unsubscribe from a single query fn unsubscribe_single( subs: &ModuleSubscriptions, + auth: AuthCtx, sender: Arc, query_id: u32, ) -> anyhow::Result<()> { - subs.remove_single_subscription(sender, single_unsubscribe(query_id), Instant::now())?; + subs.remove_single_subscription(sender, auth, single_unsubscribe(query_id), Instant::now())?; Ok(()) } /// Unsubscribe from a set of queries fn unsubscribe_multi( subs: &ModuleSubscriptions, + auth: AuthCtx, sender: Arc, query_id: u32, ) -> anyhow::Result<()> { - subs.remove_multi_subscription(sender, multi_unsubscribe(query_id), Instant::now())?; + subs.remove_multi_subscription(sender, auth, multi_unsubscribe(query_id), Instant::now())?; Ok(()) } @@ -1536,13 +1540,14 @@ mod tests { let client_id = client_id_from_u8(1); let (tx, mut rx) = client_connection(client_id, &db); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); db.create_table_for_test("t", &[("x", AlgebraicType::U8)], &[])?; // Subscribe to an invalid query (r is not in scope) let sql = "select r.* from t"; - subscribe_single(&subs, sql, tx, &mut 0)?; + subscribe_single(&subs, auth, sql, tx, &mut 0)?; check_subscription_err(sql, rx.recv().await); @@ -1557,13 +1562,14 @@ mod tests { let 
client_id = client_id_from_u8(1); let (tx, mut rx) = client_connection(client_id, &db); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); db.create_table_for_test("t", &[("x", AlgebraicType::U8)], &[])?; // Subscribe to an invalid query (r is not in scope) let sql = "select r.* from t"; - subscribe_multi(&subs, &[sql], tx, &mut 0)?; + subscribe_multi(&subs, auth, &[sql], tx, &mut 0)?; check_subscription_err(sql, rx.recv().await); @@ -1578,6 +1584,7 @@ mod tests { let client_id = client_id_from_u8(1); let (tx, mut rx) = client_connection(client_id, &db); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); // Create a table `t` with an index on `id` @@ -1596,7 +1603,7 @@ mod tests { // Subscribe to `t` let sql = "select * from t where id = 1"; - subscribe_single(&subs, sql, tx.clone(), &mut query_id)?; + subscribe_single(&subs, auth.clone(), sql, tx.clone(), &mut query_id)?; // The initial subscription should succeed assert!(matches!( @@ -1611,7 +1618,7 @@ mod tests { with_auto_commit(&db, |tx| db.drop_index(tx, index_id))?; // Unsubscribe from `t` - unsubscribe_single(&subs, tx, query_id)?; + unsubscribe_single(&subs, auth, tx, query_id)?; // Why does the unsubscribe fail? // This relies on some knowledge of the underlying implementation. @@ -1633,6 +1640,7 @@ mod tests { let client_id = client_id_from_u8(1); let (tx, mut rx) = client_connection(client_id, &db); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); // Create a table `t` with an index on `id` @@ -1653,7 +1661,7 @@ mod tests { // Subscribe to `t` let sql = "select * from t where id = 1"; - subscribe_multi(&subs, &[sql], tx.clone(), &mut query_id)?; + subscribe_multi(&subs, auth.clone(), &[sql], tx.clone(), &mut query_id)?; // The initial subscription should succeed assert!(matches!( @@ -1668,7 +1676,7 @@ mod tests { with_auto_commit(&db, |tx| db.drop_index(tx, index_id))?; // Unsubscribe from `t` - unsubscribe_multi(&subs, tx, query_id)?; + unsubscribe_multi(&subs, auth, tx, query_id)?; // Why does the unsubscribe fail? // This relies on some knowledge of the underlying implementation. @@ -1688,6 +1696,7 @@ mod tests { let client_id = client_id_from_u8(1); let (tx, mut rx) = client_connection(client_id, &db); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); // Create two tables `t` and `s` with indexes on their `id` columns @@ -1703,7 +1712,7 @@ mod tests { }) })?; let sql = "select t.* from t join s on t.id = s.id"; - subscribe_single(&subs, sql, tx, &mut 0)?; + subscribe_single(&subs, auth, sql, tx, &mut 0)?; // The initial subscription should succeed assert!(matches!( @@ -1752,6 +1761,9 @@ mod tests { let (tx_for_a, mut rx_for_a) = client_connection(client_id_for_a, &db); let (tx_for_b, mut rx_for_b) = client_connection(client_id_for_b, &db); + let auth_for_a = AuthCtx::new(db.owner_identity(), client_id_for_a.identity); + let auth_for_b = AuthCtx::new(db.owner_identity(), client_id_for_b.identity); + let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let schema = [("identity", AlgebraicType::identity())]; @@ -1764,12 +1776,14 @@ mod tests { // Each client should receive different rows. 
subscribe_multi( &subs, + auth_for_a, &["select * from t where identity = :sender"], tx_for_a, &mut query_ids, )?; subscribe_multi( &subs, + auth_for_b, &["select * from t where identity = :sender"], tx_for_b, &mut query_ids, @@ -1819,6 +1833,9 @@ mod tests { let (tx_for_a, mut rx_for_a) = client_connection(client_id_for_a, &db); let (tx_for_b, mut rx_for_b) = client_connection(client_id_for_b, &db); + let auth_for_a = AuthCtx::new(db.owner_identity(), client_id_for_a.identity); + let auth_for_b = AuthCtx::new(db.owner_identity(), client_id_for_b.identity); + let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let schema = [("id", AlgebraicType::identity())]; @@ -1843,8 +1860,8 @@ mod tests { // Have each client subscribe to `w`. // Because `w` is gated using parameterized RLS rules, // each client should receive different rows. - subscribe_multi(&subs, &["select * from w"], tx_for_a, &mut query_ids)?; - subscribe_multi(&subs, &["select * from w"], tx_for_b, &mut query_ids)?; + subscribe_multi(&subs, auth_for_a, &["select * from w"], tx_for_a, &mut query_ids)?; + subscribe_multi(&subs, auth_for_b, &["select * from w"], tx_for_b, &mut query_ids)?; // Wait for both subscriptions assert!(matches!( @@ -1883,9 +1900,15 @@ mod tests { async fn test_rls_for_owner() -> anyhow::Result<()> { let db = relational_db()?; + let client_id_for_a = client_id_from_u8(0); + let client_id_for_b = client_id_from_u8(1); + // Establish a connection for owner and client - let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(0), &db); - let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(1), &db); + let (tx_for_a, mut rx_for_a) = client_connection(client_id_for_a, &db); + let (tx_for_b, mut rx_for_b) = client_connection(client_id_for_b, &db); + + let auth_for_a = AuthCtx::new(db.owner_identity(), client_id_for_a.identity); + let auth_for_b = AuthCtx::new(db.owner_identity(), client_id_for_b.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); @@ -1898,8 +1921,8 @@ mod tests { let mut query_ids = 0; // Have owner and client subscribe to `t` - subscribe_multi(&subs, &["select * from t"], tx_for_a, &mut query_ids)?; - subscribe_multi(&subs, &["select * from t"], tx_for_b, &mut query_ids)?; + subscribe_multi(&subs, auth_for_a, &["select * from t"], tx_for_a, &mut query_ids)?; + subscribe_multi(&subs, auth_for_b, &["select * from t"], tx_for_b, &mut query_ids)?; // Wait for both subscriptions assert_matches!( @@ -1961,9 +1984,12 @@ mod tests { async fn test_no_empty_updates() -> anyhow::Result<()> { let db = relational_db()?; + let client_id = client_id_from_u8(1); + // Establish a client connection - let (tx, mut rx) = client_connection(client_id_from_u8(1), &db); + let (tx, mut rx) = client_connection(client_id, &db); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let schema = [("x", AlgebraicType::U8)]; @@ -1971,7 +1997,7 @@ mod tests { let t_id = db.create_table_for_test("t", &schema, &[])?; // Subscribe to rows of `t` where `x` is 0 - subscribe_multi(&subs, &["select * from t where x = 0"], tx, &mut 0)?; + subscribe_multi(&subs, auth, &["select * from t where x = 0"], tx, &mut 0)?; // Wait to receive the initial subscription message assert!(matches!(rx.recv().await, Some(SerializableMessage::Subscription(_)))); @@ -2011,9 +2037,11 @@ mod tests { async fn test_no_compression_for_subscribe() -> anyhow::Result<()> { let db = relational_db()?; + let 
client_id = client_id_from_u8(1); // Establish a client connection with compression - let (tx, mut rx) = client_connection_with_compression(client_id_from_u8(1), &db, Compression::Brotli); + let (tx, mut rx) = client_connection_with_compression(client_id, &db, Compression::Brotli); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let table_id = db.create_table_for_test("t", &[("x", AlgebraicType::U64)], &[])?; @@ -2029,7 +2057,7 @@ mod tests { commit_tx(&db, &subs, [], inserts)?; // Subscribe to the entire table - subscribe_multi(&subs, &["select * from t"], tx, &mut 0)?; + subscribe_multi(&subs, auth, &["select * from t"], tx, &mut 0)?; // Assert the table updates within this message are all be uncompressed match rx.recv().await { @@ -2057,14 +2085,16 @@ mod tests { let db = relational_db()?; // Establish a client connection - let (tx, mut rx) = client_connection(client_id_from_u8(1), &db); + let client_id = client_id_from_u8(1); + let (tx, mut rx) = client_connection(client_id, &db); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let schema = [("x", AlgebraicType::U8), ("y", AlgebraicType::U8)]; let t_id = db.create_table_for_test("t", &schema, &[])?; // Subscribe to `t` - subscribe_multi(&subs, &["select * from t"], tx, &mut 0)?; + subscribe_multi(&subs, auth, &["select * from t"], tx, &mut 0)?; // Wait to receive the initial subscription message assert_matches!(rx.recv().await, Some(SerializableMessage::Subscription(_))); @@ -2077,7 +2107,7 @@ mod tests { run( &db, "INSERT INTO t (x, y) VALUES (0, 1)", - auth, + auth.clone(), Some(&subs), &mut vec![], )?; @@ -2085,7 +2115,13 @@ mod tests { // Client should receive insert assert_tx_update_for_table(rx.recv(), t_id, &schema, [product![0_u8, 1_u8]], []).await; - run(&db, "UPDATE t SET y=2 WHERE x=0", auth, Some(&subs), &mut vec![])?; + run( + &db, + "UPDATE t SET y=2 WHERE x=0", + auth.clone(), + Some(&subs), + &mut vec![], + )?; // Client should receive update assert_tx_update_for_table(rx.recv(), t_id, &schema, [product![0_u8, 2_u8]], [product![0_u8, 1_u8]]).await; @@ -2105,8 +2141,10 @@ mod tests { let db = relational_db()?; // Establish a client connection with compression - let (tx, mut rx) = client_connection_with_compression(client_id_from_u8(1), &db, Compression::Brotli); + let client_id = client_id_from_u8(1); + let (tx, mut rx) = client_connection_with_compression(client_id, &db, Compression::Brotli); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let table_id = db.create_table_for_test("t", &[("x", AlgebraicType::U64)], &[])?; @@ -2118,7 +2156,7 @@ mod tests { } // Subscribe to the entire table - subscribe_multi(&subs, &["select * from t"], tx, &mut 0)?; + subscribe_multi(&subs, auth, &["select * from t"], tx, &mut 0)?; // Wait to receive the initial subscription message assert!(matches!(rx.recv().await, Some(SerializableMessage::Subscription(_)))); @@ -2156,8 +2194,10 @@ mod tests { let db = relational_db()?; // Establish a client connection - let (sender, mut rx) = client_connection(client_id_from_u8(1), &db); + let client_id = client_id_from_u8(1); + let (sender, mut rx) = client_connection(client_id, &db); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let 
p_schema = [("id", AlgebraicType::U64), ("signed_in", AlgebraicType::Bool)]; @@ -2170,7 +2210,7 @@ mod tests { let p_id = db.create_table_for_test("p", &p_schema, &[0.into()])?; let l_id = db.create_table_for_test("l", &l_schema, &[0.into()])?; - subscribe_multi(&subs, queries, sender, &mut 0)?; + subscribe_multi(&subs, auth, queries, sender, &mut 0)?; assert!(matches!(rx.recv().await, Some(SerializableMessage::Subscription(_)))); @@ -2269,10 +2309,14 @@ mod tests { async fn test_query_pruning() -> anyhow::Result<()> { let db = relational_db()?; + let client_id_a = client_id_from_u8(1); + let client_id_b = client_id_from_u8(2); // Establish a connection for each client - let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1), &db); - let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2), &db); + let (tx_for_a, mut rx_for_a) = client_connection(client_id_a, &db); + let (tx_for_b, mut rx_for_b) = client_connection(client_id_b, &db); + let auth_a = AuthCtx::new(db.owner_identity(), client_id_a.identity); + let auth_b = AuthCtx::new(db.owner_identity(), client_id_b.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let u_id = db.create_table_for_test( @@ -2312,6 +2356,7 @@ mod tests { // Returns (i: 0, a: 1, b: 1) subscribe_multi( &subs, + auth_a.clone(), &[ "select u.* from u join v on u.i = v.i where v.x = 4", "select u.* from u join v on u.i = v.i where v.x = 6", @@ -2323,6 +2368,7 @@ mod tests { // Returns (i: 1, a: 2, b: 2) subscribe_multi( &subs, + auth_b.clone(), &[ "select u.* from u join v on u.i = v.i where v.x = 5", "select u.* from u join v on u.i = v.i where v.x = 7", @@ -2411,8 +2457,10 @@ mod tests { async fn test_join_pruning() -> anyhow::Result<()> { let db = relational_db()?; - let (tx, mut rx) = client_connection(client_id_from_u8(1), &db); + let client_id = client_id_from_u8(1); + let (tx, mut rx) = client_connection(client_id, &db); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let u_id = db.create_table_for_test_with_the_works( @@ -2459,6 +2507,7 @@ mod tests { subscribe_multi( &subs, + auth, &[ "select u.* from u join v on u.i = v.i where v.x = 1", "select u.* from u join v on u.i = v.i where v.x = 2", @@ -2565,9 +2614,14 @@ mod tests { async fn test_subscribe_distinct_queries_same_plan() -> anyhow::Result<()> { let db = relational_db()?; + let client_id_a = client_id_from_u8(1); + let client_id_b = client_id_from_u8(2); // Establish a connection for each client - let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1), &db); - let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2), &db); + let (tx_for_a, mut rx_for_a) = client_connection(client_id_a, &db); + let (tx_for_b, mut rx_for_b) = client_connection(client_id_b, &db); + + let auth_a = AuthCtx::new(db.owner_identity(), client_id_a.identity); + let auth_b = AuthCtx::new(db.owner_identity(), client_id_b.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); @@ -2603,12 +2657,14 @@ mod tests { // Both clients subscribe to the same query modulo whitespace subscribe_multi( &subs, + auth_a, &["select u.* from u join v on u.i = v.i where v.x = 1"], tx_for_a, &mut query_ids, )?; subscribe_multi( &subs, + auth_b, &["select u.* from u join v on u.i = v.i where v.x = 1"], tx_for_b.clone(), &mut query_ids, @@ -2659,9 +2715,15 @@ mod tests { async fn test_unsubscribe_distinct_queries_same_plan() -> anyhow::Result<()> 
{ let db = relational_db()?; + let client_id_a = client_id_from_u8(1); + let client_id_b = client_id_from_u8(2); + // Establish a connection for each client - let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1), &db); - let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2), &db); + let (tx_for_a, mut rx_for_a) = client_connection(client_id_a, &db); + let (tx_for_b, mut rx_for_b) = client_connection(client_id_b, &db); + + let auth_a = AuthCtx::new(db.owner_identity(), client_id_a.identity); + let auth_b = AuthCtx::new(db.owner_identity(), client_id_b.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); @@ -2696,12 +2758,14 @@ mod tests { subscribe_multi( &subs, + auth_a, &["select u.* from u join v on u.i = v.i where v.x = 1"], tx_for_a, &mut query_ids, )?; subscribe_multi( &subs, + auth_b.clone(), &["select u.* from u join v on u.i = v.i where v.x = 1"], tx_for_b.clone(), &mut query_ids, @@ -2723,7 +2787,7 @@ mod tests { })) ); - unsubscribe_multi(&subs, tx_for_b, query_ids)?; + unsubscribe_multi(&subs, auth_b, tx_for_b, query_ids)?; assert_matches!( rx_for_b.recv().await, @@ -2772,8 +2836,10 @@ mod tests { let db = relational_db()?; // Establish a client connection - let (tx, mut rx) = client_connection(client_id_from_u8(1), &db); + let client_id = client_id_from_u8(1); + let (tx, mut rx) = client_connection(client_id, &db); + let auth = AuthCtx::new(db.owner_identity(), client_id.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let schema = &[("id", AlgebraicType::U64), ("a", AlgebraicType::U64)]; @@ -2788,6 +2854,7 @@ mod tests { // Subscribe to queries that return empty results let metrics = subscribe_multi( &subs, + auth, &[ "select t.* from t where a = 0", "select t.* from t join s on t.id = s.id where s.a = 0", @@ -2896,18 +2963,30 @@ mod tests { async fn test_confirmed_reads() -> anyhow::Result<()> { let (db, durability) = relational_db_with_manual_durability()?; + let client_id_confirmed = client_id_from_u8(1); + let client_id_unconfirmed = client_id_from_u8(2); + let (tx_for_confirmed, mut rx_for_confirmed) = - client_connection_with_confirmed_reads(client_id_from_u8(1), &db, true); + client_connection_with_confirmed_reads(client_id_confirmed, &db, true); let (tx_for_unconfirmed, mut rx_for_unconfirmed) = - client_connection_with_confirmed_reads(client_id_from_u8(2), &db, false); + client_connection_with_confirmed_reads(client_id_unconfirmed, &db, false); + + let auth_confirmed = AuthCtx::new(db.owner_identity(), client_id_confirmed.identity); + let auth_unconfirmed = AuthCtx::new(db.owner_identity(), client_id_unconfirmed.identity); let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let table = db.create_table_for_test("t", &[("x", AlgebraicType::U8)], &[])?; let schema = ProductType::from([AlgebraicType::U8]); // Subscribe both clients. 
- subscribe_multi(&subs, &["select * from t"], tx_for_confirmed, &mut 0)?; - subscribe_multi(&subs, &["select * from t"], tx_for_unconfirmed, &mut 0)?; + subscribe_multi(&subs, auth_confirmed, &["select * from t"], tx_for_confirmed, &mut 0)?; + subscribe_multi( + &subs, + auth_unconfirmed, + &["select * from t"], + tx_for_unconfirmed, + &mut 0, + )?; assert_matches!( rx_for_unconfirmed.recv().await, diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index 8cf32595672..2438dfe6fc5 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -1668,7 +1668,7 @@ mod tests { let auth = AuthCtx::for_testing(); let tx = SchemaViewer::new(&*tx, &auth); let (plans, has_param) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap(); - let hash = QueryHash::from_string(sql, auth.caller, has_param); + let hash = QueryHash::from_string(sql, auth.caller(), has_param); Ok(Arc::new(Plan::new(plans, hash, sql.into()))) }) } diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index 02036f4b299..6235ba413b0 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -87,7 +87,7 @@ pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result

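// Illustrative sketch, not part of the patch: the hunks above and below replace explicit
// `caller == owner || table_access == StAccess::Public` checks with the permission API
// introduced in crates/lib/src/identity.rs further down in this diff. `AuthCtx::new`
// installs the legacy owner-based permission set, so behaviour at existing call sites is
// unchanged, while `AuthCtx::with_permissions` lets a teams-aware edition supply a richer
// set. Import paths and the `SqlPermission` re-export are assumptions based on that hunk,
// and the "viewer" role below is purely hypothetical.
use std::sync::Arc;

use spacetimedb_lib::db::auth::StAccess;
use spacetimedb_lib::identity::{AuthCtx, SqlPermission};
use spacetimedb_lib::Identity;

/// Equivalent of the old `caller == owner || access == StAccess::Public` read filter.
fn legacy_read_check(owner: Identity, caller: Identity, access: StAccess) -> bool {
    AuthCtx::new(owner, caller).has_read_access(access)
}

/// A hypothetical "viewer" role: public reads only, no DML, no row-limit or RLS bypass.
fn viewer_auth(caller: Identity) -> AuthCtx {
    AuthCtx::with_permissions(
        caller,
        Arc::new(|p: SqlPermission| match p {
            // Exhaustive match over `StAccess`, as the `SqlPermission::Read` docs request.
            SqlPermission::Read(StAccess::Public) => true,
            SqlPermission::Read(StAccess::Private) => false,
            SqlPermission::Write | SqlPermission::ExceedRowLimit | SqlPermission::BypassRLS => false,
        }),
    )
}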
> for ExecutionSet { } impl AuthAccess for ExecutionSet { - fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> { - self.exec_units.iter().try_for_each(|eu| eu.check_auth(owner, caller)) + fn check_auth(&self, auth: &AuthCtx) -> Result<(), AuthError> { + self.exec_units.iter().try_for_each(|eu| eu.check_auth(auth)) } } @@ -616,7 +615,7 @@ pub(crate) fn get_all(relational_db: &RelationalDB, tx: &Tx, auth: &AuthCtx) -> .get_all_tables(tx)? .iter() .map(Deref::deref) - .filter(|t| t.table_type == StTableType::User && (auth.is_owner() || t.table_access == StAccess::Public)) + .filter(|t| t.table_type == StTableType::User && auth.has_read_access(t.table_access)) .map(|schema| { let sql = format!("SELECT * FROM {}", schema.table_name); let tx = SchemaViewer::new(tx, auth); @@ -625,12 +624,12 @@ pub(crate) fn get_all(relational_db: &RelationalDB, tx: &Tx, auth: &AuthCtx) -> plans, QueryHash::from_string( &sql, - auth.caller, + auth.caller(), // Note that when generating hashes for queries from owners, // we always treat them as if they were parameterized by :sender. // This is because RLS is not applicable to owners. // Hence owner hashes must never overlap with client hashes. - auth.is_owner() || has_param, + auth.bypass_rls() || has_param, ), sql, ) @@ -652,7 +651,7 @@ pub(crate) fn legacy_get_all( .get_all_tables(tx)? .iter() .map(Deref::deref) - .filter(|t| t.table_type == StTableType::User && (auth.is_owner() || t.table_access == StAccess::Public)) + .filter(|t| t.table_type == StTableType::User && auth.has_read_access(t.table_access)) .map(|src| SupportedQuery { kind: query::Supported::Select, expr: QueryExpr::new(src), diff --git a/crates/core/src/vm.rs b/crates/core/src/vm.rs index 8fa228e41e4..904fe302cb2 100644 --- a/crates/core/src/vm.rs +++ b/crates/core/src/vm.rs @@ -466,7 +466,7 @@ pub fn check_row_limit( row_est: impl Fn(&Query, &TxId) -> u64, auth: &AuthCtx, ) -> Result<(), DBError> { - if auth.caller != auth.owner { + if !auth.exceed_row_limit() { if let Some(limit) = db.row_limit(tx)? { let mut estimate: u64 = 0; for query in queries { @@ -627,7 +627,7 @@ impl<'db, 'tx> DbProgram<'db, 'tx> { impl ProgramVm for DbProgram<'_, '_> { // Safety: For DbProgram with tx = TxMode::Tx variant, all queries must match to CrudCode::Query and no other branch. 
fn eval_query(&mut self, query: CrudExpr, sources: Sources<'_, N>) -> Result { - query.check_auth(self.auth.owner, self.auth.caller)?; + query.check_auth(&self.auth)?; match query { CrudExpr::Query(query) => self._eval_query(&query, sources), diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs index 14aaa34a9ed..8976e22a210 100644 --- a/crates/expr/src/check.rs +++ b/crates/expr/src/check.rs @@ -160,7 +160,7 @@ impl TypeChecker for SubChecker { pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> { let ast = parse_subscription(sql)?; let has_param = ast.has_parameter(); - let ast = ast.resolve_sender(auth.caller); + let ast = ast.resolve_sender(auth.caller()); expect_table_type(SubChecker::type_ast(ast, tx)?).map(|plan| (plan, has_param)) } diff --git a/crates/expr/src/rls.rs b/crates/expr/src/rls.rs index 89bdf7149f8..2c3afaed8ff 100644 --- a/crates/expr/src/rls.rs +++ b/crates/expr/src/rls.rs @@ -18,7 +18,7 @@ pub fn resolve_views_for_sub( has_param: &mut bool, ) -> anyhow::Result> { // RLS does not apply to the database owner - if auth.is_owner() { + if auth.bypass_rls() { return Ok(vec![expr]); } @@ -56,7 +56,7 @@ pub fn resolve_views_for_sub( /// Mainly a wrapper around [resolve_views_for_expr]. pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &AuthCtx) -> anyhow::Result { // RLS does not apply to the database owner - if auth.is_owner() { + if auth.bypass_rls() { return Ok(expr); } // The subscription language is a subset of the sql language. diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index 50e9bdb4c22..45716fa18e2 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -428,7 +428,7 @@ impl TypeChecker for SqlChecker { } pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult { - match parse_sql(sql)?.resolve_sender(auth.caller) { + match parse_sql(sql)?.resolve_sender(auth.caller()) { SqlAst::Select(ast) => Ok(Statement::Select(SqlChecker::type_ast(ast, tx)?)), SqlAst::Insert(insert) => Ok(Statement::DML(DML::Insert(type_insert(insert, tx)?))), SqlAst::Delete(delete) => Ok(Statement::DML(DML::Delete(type_delete(delete, tx)?))), diff --git a/crates/lib/src/identity.rs b/crates/lib/src/identity.rs index 8420e5e6245..3a5b2bf0743 100644 --- a/crates/lib/src/identity.rs +++ b/crates/lib/src/identity.rs @@ -1,37 +1,111 @@ +use crate::db::auth::StAccess; use crate::from_hex_pad; use blake3; use core::mem; use spacetimedb_bindings_macro::{Deserialize, Serialize}; use spacetimedb_sats::hex::HexString; use spacetimedb_sats::{impl_st, u256, AlgebraicType, AlgebraicValue}; +use std::sync::Arc; use std::{fmt, str::FromStr}; pub type RequestId = u32; -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +/// Set of permissions the SQL engine may ask for. +pub enum SqlPermission { + /// Read permissions given the [StAccess] of a table. + /// + /// [StAccess] must be passed in order to allow external implementations + /// to fail compilation should the [StAccess] enum ever gain additional + /// variants. Implementations should always do an exhaustive match thus. + /// + /// [SqlAuthorization::has_sql_permission] must return true if + /// [StAccess::Public]. + Read(StAccess), + /// Write access, i.e. executing DML. + Write, + /// If granted, no row limit checks will be performed for subscription queries. + ExceedRowLimit, + /// RLS does not apply to database owners (for some definition of owner). 
+ /// If the subject qualifies as an owner, the permission should be granted. + BypassRLS, +} + +/// Types than can grant or deny [SqlPermission]s. +pub trait SqlAuthorization { + /// Returns `true` if permission `p` is granted, `false` otherwise. + fn has_sql_permission(&self, p: SqlPermission) -> bool; +} + +impl bool> SqlAuthorization for T { + fn has_sql_permission(&self, p: SqlPermission) -> bool { + self(p) + } +} + +/// [SqlAuthorization] trait object. +pub type SqlPermissions = Arc; + +/// The legacy permissions (sans "teams") grant everything if the owner is +/// equal to the caller. +fn owner_permissions(owner: Identity, caller: Identity) -> SqlPermissions { + let is_owner = owner == caller; + Arc::new(move |p| match p { + SqlPermission::Read(access) => match access { + StAccess::Public => true, + StAccess::Private => is_owner, + }, + _ => is_owner, + }) +} + +/// Authorization for SQL operations (queries, DML, subscription queries). +#[derive(Clone)] pub struct AuthCtx { - pub owner: Identity, - pub caller: Identity, + caller: Identity, + permissions: SqlPermissions, } impl AuthCtx { pub fn new(owner: Identity, caller: Identity) -> Self { - Self { owner, caller } + Self::with_permissions(caller, owner_permissions(owner, caller)) } + + pub fn with_permissions(caller: Identity, permissions: SqlPermissions) -> Self { + Self { caller, permissions } + } + /// For when the owner == caller pub fn for_current(owner: Identity) -> Self { - Self { owner, caller: owner } + Self::new(owner, owner) } - /// Does `owner == caller` - pub fn is_owner(&self) -> bool { - self.owner == self.caller + + pub fn has_permission(&self, p: SqlPermission) -> bool { + self.permissions.has_sql_permission(p) + } + + pub fn has_read_access(&self, table_access: StAccess) -> bool { + self.has_permission(SqlPermission::Read(table_access)) + } + + pub fn has_write_access(&self) -> bool { + self.has_permission(SqlPermission::Write) + } + + pub fn exceed_row_limit(&self) -> bool { + self.has_permission(SqlPermission::ExceedRowLimit) } + + pub fn bypass_rls(&self) -> bool { + self.has_permission(SqlPermission::BypassRLS) + } + + pub fn caller(&self) -> Identity { + self.caller + } + /// WARNING: Use this only for simple test were the `auth` don't matter pub fn for_testing() -> Self { - AuthCtx { - owner: Identity::__dummy(), - caller: Identity::__dummy(), - } + Self::new(Identity::__dummy(), Identity::__dummy()) } } diff --git a/crates/pg/src/pg_server.rs b/crates/pg/src/pg_server.rs index 30fc75f0438..860df156f25 100644 --- a/crates/pg/src/pg_server.rs +++ b/crates/pg/src/pg_server.rs @@ -25,7 +25,7 @@ use pgwire::tokio::process_socket; use spacetimedb_client_api::auth::validate_token; use spacetimedb_client_api::routes::database; use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams}; -use spacetimedb_client_api::{ControlStateReadAccess, ControlStateWriteAccess, NodeDelegate}; +use spacetimedb_client_api::{Authorization, ControlStateReadAccess, ControlStateWriteAccess, NodeDelegate}; use spacetimedb_client_api_messages::http::SqlStmtResult; use spacetimedb_client_api_messages::name::DatabaseName; use spacetimedb_lib::sats::satn::{PsqlClient, TypedSerializer}; @@ -147,7 +147,10 @@ struct PgSpacetimeDB { parameter_provider: DefaultServerParameterProvider, } -impl PgSpacetimeDB { +impl PgSpacetimeDB +where + T: ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Authorization + Clone, +{ async fn exe_sql(&self, query: String) -> PgWireResult> { let params = 
self.cached.lock().await.clone().unwrap(); let db = SqlParams { @@ -294,8 +297,9 @@ impl SimpleQueryHandler - for PgSpacetimeDB +impl SimpleQueryHandler for PgSpacetimeDB +where + T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Authorization + Clone, { async fn do_query(&self, _client: &mut C, query: &str) -> PgWireResult> where @@ -326,8 +330,9 @@ impl PgSpacetimeDBFactory { } } -impl PgWireServerHandlers - for PgSpacetimeDBFactory +impl PgWireServerHandlers for PgSpacetimeDBFactory +where + T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Authorization + Clone, { fn simple_query_handler(&self) -> Arc { self.handler.clone() @@ -340,11 +345,10 @@ impl( - shutdown: Arc, - ctx: T, - tcp: TcpListener, -) { +pub async fn start_pg(shutdown: Arc, ctx: T, tcp: TcpListener) +where + T: ControlStateReadAccess + ControlStateWriteAccess + NodeDelegate + Authorization + Clone + 'static, +{ let factory = Arc::new(PgSpacetimeDBFactory::new(ctx)); log::debug!( diff --git a/crates/schema/src/def/error.rs b/crates/schema/src/def/error.rs index 83bae995fc1..7acdfd3a937 100644 --- a/crates/schema/src/def/error.rs +++ b/crates/schema/src/def/error.rs @@ -77,8 +77,8 @@ pub enum AuthError { IndexPrivate { named: String }, #[error("Sequence `{named}` is private")] SequencePrivate { named: String }, - #[error("Only the database owner can perform the requested operation")] - OwnerRequired, + #[error("Insufficient privileges to perform the requested operation")] + InsuffientPrivileges, #[error("Constraint `{named}` is private")] ConstraintPrivate { named: String }, } diff --git a/crates/sqltest/src/space.rs b/crates/sqltest/src/space.rs index a67e7d21842..e773426b916 100644 --- a/crates/sqltest/src/space.rs +++ b/crates/sqltest/src/space.rs @@ -71,7 +71,7 @@ impl SpaceDb { self.conn.with_read_only(Workload::Sql, |tx| { let ast = compile_sql(&self.conn, &AuthCtx::for_testing(), tx, sql)?; let (subs, _runtime) = ModuleSubscriptions::for_test_new_runtime(Arc::new(self.conn.db.clone())); - let result = execute_sql(&self.conn, sql, ast, self.auth, Some(&subs))?; + let result = execute_sql(&self.conn, sql, ast, self.auth.clone(), Some(&subs))?; //remove comments to see which SQL worked. Can't collect it outside from lack of a hook in the external `sqllogictest` crate... :( //append_file(&std::path::PathBuf::from(".ok.sql"), sql)?; Ok(result) diff --git a/crates/standalone/src/control_db.rs b/crates/standalone/src/control_db.rs index 00152ef3d46..6128cf8614e 100644 --- a/crates/standalone/src/control_db.rs +++ b/crates/standalone/src/control_db.rs @@ -38,6 +38,8 @@ pub enum Error { RecordAlreadyExists(DomainName), #[error("database with identity {0} already exists")] DatabaseAlreadyExists(Identity), + #[error("database with identity {0} does not exist")] + DatabaseNotFound(Identity), #[error("failed to register {0} domain")] DomainRegistrationFailure(DomainName), #[error("failed to decode data")] @@ -377,6 +379,21 @@ impl ControlDb { Ok(id) } + pub(crate) fn update_database(&self, database: Database) -> Result<()> { + let Some(stored_database) = self.get_database_by_identity(&database.database_identity)? 
else { + return Err(Error::DatabaseNotFound(database.database_identity)); + }; + + let tree = self.db.open_tree("database_by_identity")?; + let buf = sled::IVec::from(compat::Database::from(database).to_vec()?); + tree.insert(stored_database.database_identity.to_be_byte_array(), buf.clone())?; + + let tree = self.db.open_tree("database")?; + tree.insert(stored_database.id.to_be_bytes(), buf)?; + + Ok(()) + } + pub fn delete_database(&self, id: u64) -> Result> { let tree = self.db.open_tree("database")?; let tree_by_identity = self.db.open_tree("database_by_identity")?; @@ -430,7 +447,7 @@ impl ControlDb { // if !tree.contains_key(database_id.to_be_bytes())? { // return Err(anyhow::anyhow!("No such database.")); // } - // + //1073741824 let replicas = self .get_replicas()? .iter() diff --git a/crates/standalone/src/lib.rs b/crates/standalone/src/lib.rs index 4ed8ccf5de3..a3cda83da3b 100644 --- a/crates/standalone/src/lib.rs +++ b/crates/standalone/src/lib.rs @@ -5,7 +5,7 @@ pub mod version; use crate::control_db::ControlDb; use crate::subcommands::{extract_schema, start}; -use anyhow::{ensure, Context as _, Ok}; +use anyhow::Context as _; use async_trait::async_trait; use clap::{ArgMatches, Command}; use spacetimedb::client::ClientActorIndex; @@ -14,13 +14,13 @@ use spacetimedb::db; use spacetimedb::db::persistence::LocalPersistenceProvider; use spacetimedb::energy::{EnergyBalance, EnergyQuanta, NullEnergyMonitor}; use spacetimedb::host::{DiskStorage, HostController, MigratePlanResult, UpdateDatabaseResult}; -use spacetimedb::identity::Identity; +use spacetimedb::identity::{AuthCtx, Identity}; use spacetimedb::messages::control_db::{Database, Node, Replica}; use spacetimedb::util::jobs::JobCores; use spacetimedb::worker_metrics::WORKER_METRICS; use spacetimedb_client_api::auth::{self, LOCALHOST}; use spacetimedb_client_api::routes::subscribe::{HasWebSocketOptions, WebSocketOptions}; -use spacetimedb_client_api::{Host, NodeDelegate}; +use spacetimedb_client_api::{ControlStateReadAccess, DatabaseResetDef, Host, NodeDelegate}; use spacetimedb_client_api_messages::name::{DomainName, InsertDomainResult, RegisterTldResult, SetDomainsResult, Tld}; use spacetimedb_datastore::db_metrics::data_size::DATA_SIZE_METRICS; use spacetimedb_datastore::db_metrics::DB_METRICS; @@ -258,13 +258,6 @@ impl spacetimedb_client_api::ControlStateWriteAccess for StandaloneEnv { // The database already exists, so we'll try to update it. // If that fails, we'll keep the old one. Some(database) => { - ensure!( - &database.owner_identity == publisher, - "Permission denied: `{}` does not own database `{}`", - publisher, - spec.database_identity.to_abbreviated_hex() - ); - let database_id = database.id; let database_identity = database.database_identity; @@ -273,7 +266,7 @@ impl spacetimedb_client_api::ControlStateWriteAccess for StandaloneEnv { .await? .ok_or_else(|| anyhow::anyhow!("No leader for database"))?; let update_result = leader - .update(database, spec.host_type, spec.program_bytes.into(), policy) + .update(database, spec.host_type, spec.program_bytes.to_vec().into(), policy) .await?; if update_result.was_successful() { let replicas = self.control_db.get_replicas_by_database(database_id)?; @@ -337,7 +330,13 @@ impl spacetimedb_client_api::ControlStateWriteAccess for StandaloneEnv { .await? 
.ok_or_else(|| anyhow::anyhow!("No leader for database"))?; self.host_controller - .migrate_plan(db, spec.host_type, host.replica_id, spec.program_bytes.into(), style) + .migrate_plan( + db, + spec.host_type, + host.replica_id, + spec.program_bytes.to_vec().into(), + style, + ) .await } None => anyhow::bail!( @@ -347,19 +346,10 @@ } } - async fn delete_database(&self, caller_identity: &Identity, database_identity: &Identity) -> anyhow::Result<()> { + async fn delete_database(&self, _caller_identity: &Identity, database_identity: &Identity) -> anyhow::Result<()> { let Some(database) = self.control_db.get_database_by_identity(database_identity)? else { return Ok(()); }; - anyhow::ensure!( - &database.owner_identity == caller_identity, - // TODO: `PermissionDenied` should be a variant of `Error`, - // so we can match on it and return better error responses - // from HTTP endpoints. - "Permission denied: `{caller_identity}` does not own database `{}`", - database_identity.to_abbreviated_hex() - ); - self.control_db.delete_database(database.id)?; for instance in self.control_db.get_replicas_by_database(database.id)? { @@ -369,6 +359,42 @@ Ok(()) } + async fn reset_database(&self, _caller_identity: &Identity, spec: DatabaseResetDef) -> anyhow::Result<()> { + let mut database = self + .control_db + .get_database_by_identity(&spec.database_identity)? + .with_context(|| format!("Database `{}` does not exist", spec.database_identity))?; + let database_id = database.id; + + if let Some(program) = spec.program_bytes { + let program_bytes = &program[..]; + let program = Program::from_bytes(program_bytes); + let _hash_for_assert = program.hash; + + database.initial_program = program.hash; + if let Some(host_type) = spec.host_type { + database.host_type = host_type; + } + + self.host_controller + .check_module_validity(database.clone(), program) + .await?; + let _stored_hash_for_assert = self.program_store.put(program_bytes).await?; + debug_assert_eq!(_hash_for_assert, _stored_hash_for_assert); + + self.control_db.update_database(database)?; + } + + for instance in self.control_db.get_replicas_by_database(database_id)? { + self.delete_replica(instance.id).await?; + } + // Standalone only supports a single replica. + let num_replicas = 1; + self.schedule_replicas(database_id, num_replicas).await?; + + Ok(()) + } + async fn add_energy(&self, identity: &Identity, amount: EnergyQuanta) -> anyhow::Result<()> { let balance = self .control_db @@ -412,6 +438,42 @@ } } +impl spacetimedb_client_api::Authorization for StandaloneEnv { + async fn authorize_action( + &self, + subject: Identity, + database: Identity, + action: spacetimedb_client_api::Action, + ) -> Result<(), spacetimedb_client_api::Unauthorized> { + let database = self + .get_database_by_identity(&database)? + .with_context(|| format!("database {database} not found")) + .with_context(|| format!("Unable to authorize {subject} to perform {action:?}"))?; + if subject == database.owner_identity { + return Ok(()); + } + + Err(spacetimedb_client_api::Unauthorized::Unauthorized { + subject, + action, + source: None, + }) + } + + async fn authorize_sql( + &self, + subject: Identity, + database: Identity, + ) -> Result { + let database = self + .get_database_by_identity(&database)?
+ .with_context(|| format!("database {database} not found")) + .with_context(|| format!("Unable to authorize {subject} for SQL"))?; + + Ok(AuthCtx::new(database.owner_identity, subject)) + } +} + impl StandaloneEnv { async fn insert_replica(&self, replica: Replica) -> Result<(), anyhow::Error> { let mut new_replica = replica.clone(); diff --git a/crates/standalone/src/subcommands/start.rs b/crates/standalone/src/subcommands/start.rs index 3d9cc44af4f..811c3538117 100644 --- a/crates/standalone/src/subcommands/start.rs +++ b/crates/standalone/src/subcommands/start.rs @@ -1,3 +1,4 @@ +use spacetimedb_client_api::routes::identity::IdentityRoutes; use spacetimedb_pg::pg_server; use std::sync::Arc; @@ -185,7 +186,7 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> { db_routes.db_put = db_routes.db_put.layer(DefaultBodyLimit::disable()); db_routes.pre_publish = db_routes.pre_publish.layer(DefaultBodyLimit::disable()); let extra = axum::Router::new().nest("/health", spacetimedb_client_api::routes::health::router()); - let service = router(&ctx, db_routes, extra).with_state(ctx.clone()); + let service = router(&ctx, db_routes, IdentityRoutes::default(), extra).with_state(ctx.clone()); let tcp = TcpListener::bind(listen_addr).await.context(format!( "failed to bind the SpacetimeDB server to '{listen_addr}', please check that the address is valid and not already in use" diff --git a/crates/testing/Cargo.toml b/crates/testing/Cargo.toml index ff9d70e306e..2d06a8dddce 100644 --- a/crates/testing/Cargo.toml +++ b/crates/testing/Cargo.toml @@ -16,6 +16,7 @@ spacetimedb-paths.workspace = true spacetimedb-schema.workspace = true anyhow.workspace = true +bytes.workspace = true env_logger.workspace = true log.workspace = true clap.workspace = true diff --git a/crates/testing/src/modules.rs b/crates/testing/src/modules.rs index c00d70717fd..326f30bf704 100644 --- a/crates/testing/src/modules.rs +++ b/crates/testing/src/modules.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use std::sync::OnceLock; use std::time::Instant; +use bytes::Bytes; use spacetimedb::config::CertificateAuthority; use spacetimedb::messages::control_db::HostType; use spacetimedb::util::jobs::JobCores; @@ -95,7 +96,7 @@ pub struct CompiledModule { name: String, path: PathBuf, pub(super) host_type: HostType, - program_bytes: OnceLock>, + program_bytes: OnceLock, } #[derive(Debug, PartialEq, Eq)] @@ -124,12 +125,16 @@ impl CompiledModule { &self.path } - pub fn program_bytes(&self) -> &[u8] { - self.program_bytes.get_or_init(|| std::fs::read(&self.path).unwrap()) + pub fn program_bytes(&self) -> Bytes { + self.program_bytes + .get_or_init(|| std::fs::read(&self.path).unwrap().into()) + .clone() } pub async fn extract_schema(&self) -> ModuleDef { - spacetimedb::host::extract_schema(self.program_bytes().into(), self.host_type) + // TODO: extract_schema should accept &[u8] + let boxed_bytes: Box<[u8]> = self.program_bytes()[..].into(); + spacetimedb::host::extract_schema(boxed_bytes, self.host_type) .await .unwrap() } @@ -198,15 +203,14 @@ impl CompiledModule { let db_identity = SpacetimeAuth::alloc(&env).await.unwrap().claims.identity; let connection_id = generate_random_connection_id(); - let program_bytes = self.program_bytes().to_owned(); - env.publish_database( &identity, DatabaseDef { database_identity: db_identity, - program_bytes, + program_bytes: self.program_bytes(), num_replicas: None, host_type: self.host_type, + parent: None, }, MigrationPolicy::Compatible, ) diff --git a/crates/vm/src/expr.rs 
b/crates/vm/src/expr.rs index d9a70d47548..23cc59e12d2 100644 --- a/crates/vm/src/expr.rs +++ b/crates/vm/src/expr.rs @@ -8,7 +8,7 @@ use itertools::Itertools; use smallvec::SmallVec; use spacetimedb_data_structures::map::{HashSet, IntMap}; use spacetimedb_lib::db::auth::{StAccess, StTableType}; -use spacetimedb_lib::Identity; +use spacetimedb_lib::identity::AuthCtx; use spacetimedb_primitives::*; use spacetimedb_sats::satn::Satn; use spacetimedb_sats::{AlgebraicType, AlgebraicValue, ProductValue}; @@ -25,7 +25,7 @@ use std::{fmt, iter, mem}; /// Trait for checking if the `caller` have access to `Self` pub trait AuthAccess { - fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError>; + fn check_auth(&self, auth: &AuthCtx) -> Result<(), AuthError>; } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, From)] @@ -1958,12 +1958,8 @@ impl QueryExpr { } impl AuthAccess for Query { - fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> { - if owner == caller { - return Ok(()); - } - - self.walk_sources(&mut |s| s.check_auth(owner, caller)) + fn check_auth(&self, auth: &AuthCtx) -> Result<(), AuthError> { + self.walk_sources(&mut |s| s.check_auth(auth)) } } @@ -2017,8 +2013,8 @@ impl fmt::Display for Query { } impl AuthAccess for SourceExpr { - fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> { - if owner == caller || self.table_access() == StAccess::Public { + fn check_auth(&self, auth: &AuthCtx) -> Result<(), AuthError> { + if auth.has_read_access(self.table_access()) { return Ok(()); } @@ -2029,26 +2025,24 @@ impl AuthAccess for SourceExpr { } impl AuthAccess for QueryExpr { - fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> { - if owner == caller { - return Ok(()); - } - self.walk_sources(&mut |s| s.check_auth(owner, caller)) + fn check_auth(&self, auth: &AuthCtx) -> Result<(), AuthError> { + self.walk_sources(&mut |s| s.check_auth(auth)) } } impl AuthAccess for CrudExpr { - fn check_auth(&self, owner: Identity, caller: Identity) -> Result<(), AuthError> { - if owner == caller { - return Ok(()); - } + fn check_auth(&self, auth: &AuthCtx) -> Result<(), AuthError> { // Anyone may query, so as long as the tables involved are public. if let CrudExpr::Query(q) = self { - return q.check_auth(owner, caller); + return q.check_auth(auth); } // Mutating operations require `owner == caller`. - Err(AuthError::OwnerRequired) + if !auth.has_write_access() { + return Err(AuthError::InsuffientPrivileges); + } + + Ok(()) } } @@ -2117,7 +2111,7 @@ impl From for CodeResult { mod tests { use super::*; - use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9Builder; + use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, Identity}; use spacetimedb_sats::{product, AlgebraicType, ProductType}; use spacetimedb_schema::{def::ModuleDef, relation::Column, schema::Schema}; use typed_arena::Arena; @@ -2201,16 +2195,19 @@ mod tests { } fn assert_owner_private(auth: &T) { - assert!(auth.check_auth(ALICE, ALICE).is_ok()); + assert!(auth.check_auth(&AuthCtx::new(ALICE, ALICE)).is_ok()); assert!(matches!( - auth.check_auth(ALICE, BOB), + auth.check_auth(&AuthCtx::new(ALICE, BOB)), Err(AuthError::TablePrivate { .. 
}) )); } fn assert_owner_required(auth: T) { - assert!(auth.check_auth(ALICE, ALICE).is_ok()); - assert!(matches!(auth.check_auth(ALICE, BOB), Err(AuthError::OwnerRequired))); + assert!(auth.check_auth(&AuthCtx::new(ALICE, ALICE)).is_ok()); + assert!(matches!( + auth.check_auth(&AuthCtx::new(ALICE, BOB)), + Err(AuthError::InsuffientPrivileges) + )); } fn mem_table(id: TableId, name: &str, fields: &[(u16, AlgebraicType, bool)]) -> SourceExpr { diff --git a/docs/docs/cli-reference.md b/docs/docs/cli-reference.md index e28061f97b4..630c3954c29 100644 --- a/docs/docs/cli-reference.md +++ b/docs/docs/cli-reference.md @@ -92,6 +92,10 @@ Run `spacetime help publish` for more detailed information. * `-j`, `--js-path ` — UNSTABLE: The system path (absolute or relative) to the javascript file we should publish, instead of building the project. * `--break-clients` — Allow breaking changes when publishing to an existing database identity. This will break existing clients. * `--anonymous` — Perform this action with an anonymous identity +* `--parent ` — A valid domain or identity of an existing database that should be the parent of this database. + + If a parent is given, the new database inherits the team permissions from the parent. + A parent can only be set when a database is created, not when it is updated. * `-s`, `--server ` — The nickname, domain name or URL of the server to host the database. * `-y`, `--yes` — Run non-interactively wherever possible. This will answer "yes" to almost all prompts, but will sometimes answer "no" to preserve non-interactivity (e.g. when prompting whether to log in with spacetimedb.com). diff --git a/smoketests/__init__.py b/smoketests/__init__.py index 63486c09b08..2cf3bae62b9 100644 --- a/smoketests/__init__.py +++ b/smoketests/__init__.py @@ -365,7 +365,7 @@ def tearDown(self): if "database_identity" in self.__dict__: try: # TODO: save the credentials in publish_module() - self.spacetime("delete", self.database_identity) + self.spacetime("delete", "--yes", self.database_identity) except Exception: pass @@ -374,7 +374,7 @@ def tearDownClass(cls): if hasattr(cls, "database_identity"): try: # TODO: save the credentials in publish_module() - cls.spacetime("delete", cls.database_identity) + cls.spacetime("delete", "--yes", cls.database_identity) except Exception: pass diff --git a/smoketests/tests/domains.py b/smoketests/tests/domains.py index 61fef422f2f..8a39be046d5 100644 --- a/smoketests/tests/domains.py +++ b/smoketests/tests/domains.py @@ -1,5 +1,4 @@ from .. import Smoketest, random_string -import unittest import json class Domains(Smoketest): @@ -26,15 +25,15 @@ def test_set_name(self): with self.assertRaises(Exception): self.spacetime("logs", orig_name) - @unittest.expectedFailure def test_subdomain_behavior(self): """Test how we treat the / character in published names""" root_name = random_string() self.publish_module(root_name) - id_to_rename = self.database_identity - self.publish_module(f"{root_name}/test") + # TODO: This is valid in editions with the teams feature, but + # smoketests don't know the target's edition. 
+ # self.publish_module(f"{root_name}/test") with self.assertRaises(Exception): self.publish_module(f"{root_name}//test") diff --git a/smoketests/tests/permissions.py b/smoketests/tests/permissions.py index 39ffca6e022..3b1068d0c36 100644 --- a/smoketests/tests/permissions.py +++ b/smoketests/tests/permissions.py @@ -81,7 +81,7 @@ class PrivateTablePermissions(Smoketest): MODULE_CODE = """ use spacetimedb::{ReducerContext, Table}; -#[spacetimedb::table(name = secret)] +#[spacetimedb::table(name = secret, private)] pub struct Secret { answer: u8, } @@ -97,9 +97,9 @@ class PrivateTablePermissions(Smoketest): } #[spacetimedb::reducer] -pub fn do_thing(ctx: &ReducerContext) { +pub fn do_thing(ctx: &ReducerContext, thing: String) { ctx.db.secret().insert(Secret { answer: 20 }); - ctx.db.common_knowledge().insert(CommonKnowledge { thing: "howdy".to_owned() }); + ctx.db.common_knowledge().insert(CommonKnowledge { thing }); } """ @@ -113,7 +113,7 @@ def test_private_table(self): " 42 ", "" ]) - self.assertMultiLineEqual(out, answer) + self.assertMultiLineEqual(str(out), answer) self.reset_config() self.new_identity() @@ -121,12 +121,33 @@ def test_private_table(self): with self.assertRaises(Exception): self.spacetime("sql", self.database_identity, "select * from secret") + # Subscribing to the private table failes. with self.assertRaises(Exception): self.subscribe("SELECT * FROM secret", n=0) + # Subscribing to the public table works. + sub = self.subscribe("SELECT * FROM common_knowledge", n = 1) + self.call("do_thing", "godmorgon") + self.assertEqual(sub(), [ + { + 'common_knowledge': { + 'deletes': [], + 'inserts': [{'thing': 'godmorgon'}] + } + } + ]) + + # Subscribing to both tables returns updates for the public one. sub = self.subscribe("SELECT * FROM *", n=1) - self.call("do_thing", anon=True) - self.assertEqual(sub(), [{'common_knowledge': {'deletes': [], 'inserts': [{'thing': 'howdy'}]}}]) + self.call("do_thing", "howdy", anon=True) + self.assertEqual(sub(), [ + { + 'common_knowledge': { + 'deletes': [], + 'inserts': [{'thing': 'howdy'}] + } + } + ]) class LifecycleReducers(Smoketest): diff --git a/smoketests/tests/teams.py b/smoketests/tests/teams.py new file mode 100644 index 00000000000..0eeea186865 --- /dev/null +++ b/smoketests/tests/teams.py @@ -0,0 +1,335 @@ +import json +import toml + +from .. import Smoketest, parse_sql_result, random_string + +class CreateChildDatabase(Smoketest): + AUTOPUBLISH = False + + def test_create_child_database(self): + """ + Test that the owner can add a child database, + and that deleting the parent also deletes the child. + """ + + parent_name = random_string() + child_name = random_string() + + self.publish_module(parent_name) + parent_identity = self.database_identity + self.publish_module(f"{parent_name}/{child_name}") + child_identity = self.database_identity + + databases = self.query_controldb(parent_identity, child_identity) + self.assertEqual(2, len(databases)) + + self.spacetime("delete", "--yes", parent_name) + + databases = self.query_controldb(parent_identity, child_identity) + self.assertEqual(0, len(databases)) + + def query_controldb(self, parent, child): + res = self.spacetime( + "sql", + "spacetime-control", + f"select * from database where database_identity = 0x{parent} or database_identity = 0x{child}" + ) + return parse_sql_result(str(res)) + +class PermissionsTest(Smoketest): + AUTOPUBLISH = False + + def create_identity(self): + """ + Obtain a fresh identity and token from the server. 
diff --git a/smoketests/tests/teams.py b/smoketests/tests/teams.py
new file mode 100644
index 00000000000..0eeea186865
--- /dev/null
+++ b/smoketests/tests/teams.py
@@ -0,0 +1,335 @@
+import json
+import toml
+
+from .. import Smoketest, parse_sql_result, random_string
+
+class CreateChildDatabase(Smoketest):
+    AUTOPUBLISH = False
+
+    def test_create_child_database(self):
+        """
+        Test that the owner can add a child database,
+        and that deleting the parent also deletes the child.
+        """
+
+        parent_name = random_string()
+        child_name = random_string()
+
+        self.publish_module(parent_name)
+        parent_identity = self.database_identity
+        self.publish_module(f"{parent_name}/{child_name}")
+        child_identity = self.database_identity
+
+        databases = self.query_controldb(parent_identity, child_identity)
+        self.assertEqual(2, len(databases))
+
+        self.spacetime("delete", "--yes", parent_name)
+
+        databases = self.query_controldb(parent_identity, child_identity)
+        self.assertEqual(0, len(databases))
+
+    def query_controldb(self, parent, child):
+        res = self.spacetime(
+            "sql",
+            "spacetime-control",
+            f"select * from database where database_identity = 0x{parent} or database_identity = 0x{child}"
+        )
+        return parse_sql_result(str(res))
+
+class PermissionsTest(Smoketest):
+    AUTOPUBLISH = False
+
+    def create_identity(self):
+        """
+        Obtain a fresh identity and token from the server.
+        Doesn't alter the config.toml for this test instance.
+        """
+        resp = self.api_call("POST", "/v1/identity")
+        return json.loads(resp)
+
+    def create_collaborators(self, database):
+        """
+        Create collaborators for the current database, one for each role.
+        """
+        collaborators = {}
+        roles = ["Owner", "Admin", "Developer", "Viewer"]
+        for role in roles:
+            identity_and_token = self.create_identity()
+            self.call_controldb_reducer(
+                "upsert_collaborator",
+                {"Name": database},
+                [f"0x{identity_and_token['identity']}"],
+                {role: {}}
+            )
+            collaborators[role] = identity_and_token
+        return collaborators
+
+
+    def call_controldb_reducer(self, reducer, *args):
+        """
+        Call a controldb reducer.
+        """
+        self.spacetime("call", "spacetime-control", reducer, *map(json.dumps, args))
+
+    def login_with(self, identity_and_token: dict):
+        self.spacetime("logout")
+        config = toml.load(self.config_path)
+        config['spacetimedb_token'] = identity_and_token['token']
+        with open(self.config_path, 'w') as f:
+            toml.dump(config, f)
+
+    def publish_as(self, role_and_token, module, code, clear = False):
+        print(f"publishing {module} as {role_and_token[0]}:")
+        print(f"{code}")
+        self.login_with(role_and_token[1])
+        self.write_module_code(code)
+        self.publish_module(module, clear = clear)
+        return self.database_identity
+
+    def sql_as(self, role_and_token, database, sql):
+        """
+        Log in as `token` and run an SQL statement against `database`
+        """
+        print(f"running sql as {role_and_token[0]}: {sql}")
+        self.login_with(role_and_token[1])
+        res = self.spacetime("sql", database, sql)
+        return parse_sql_result(str(res))
+
+    def subscribe_as(self, role_and_token, *queries, n):
+        """
+        Log in as `token` and subscribe to the current database using `queries`.
+        """
+        print(f"subscribe as {role_and_token[0]}: {queries}")
+        self.login_with(role_and_token[1])
+        return self.subscribe(*queries, n = n)
+
+ """ + + parent = random_string() + self.publish_module(parent) + + (owner, admin, developer, viewer) = self.create_collaborators(parent).items() + succeed_with = [ + (owner, self.MODULE_CODE_OWNER), + (admin, self.MODULE_CODE_ADMIN), + (developer, self.MODULE_CODE_DEVELOPER) + ] + + for role_and_token, code in succeed_with: + self.publish_as(role_and_token, parent, code) + + with self.assertRaises(Exception): + self.publish_as(viewer, parent, self.MODULE_CODE_VIEWER) + + # Create a child database. + child = random_string() + child_path = f"{parent}/{child}" + + # Developer and viewer should not be able to create a child. + for role_and_token in [developer, viewer]: + with self.assertRaises(Exception): + self.publish_as(role_and_token, child_path, self.MODULE_CODE) + # But admin should succeed. + self.publish_as(admin, child_path, self.MODULE_CODE) + + # Once created, only viewer should be denied updating. + for role_and_token, code in succeed_with: + self.publish_as(role_and_token, child_path, code) + + with self.assertRaises(Exception): + self.publish_as(viewer, child_path, self.MODULE_CODE_VIEWER) + + +class ClearDatabase(PermissionsTest): + def test_permissions_clear(self): + """ + Tests that only owners and admins can clear a database. + """ + + parent = random_string() + self.publish_module(parent) + # First degree owner can clear. + self.publish_module(parent, clear = True) + + (owner, admin, developer, viewer) = self.create_collaborators(parent).items() + + # Owner and admin collaborators can clear. + for role_and_token in [owner, admin]: + self.publish_as(role_and_token, parent, self.MODULE_CODE, clear = True) + + # Others can't. + for role_and_token in [developer, viewer]: + with self.assertRaises(Exception): + self.publish_as(role_and_token, parent, self.MODULE_CODE, clear = True) + + # Same applies to child. + child = random_string() + child_path = f"{parent}/{child}" + + self.publish_as(owner, child_path, self.MODULE_CODE) + + for role_and_token in [owner, admin]: + self.publish_as(role_and_token, parent, self.MODULE_CODE, clear = True) + + for role_and_token in [developer, viewer]: + with self.assertRaises(Exception): + self.publish_as(role_and_token, parent, self.MODULE_CODE, clear = True) + + +class DeleteDatabase(PermissionsTest): + def delete_as(self, role_and_token, database): + print(f"delete {database} as {role_and_token[0]}") + self.login_with(role_and_token[1]) + self.spacetime("delete", "--yes", database) + + def test_permissions_delete(self): + """ + Tests that only owners can delete databases. + """ + + parent = random_string() + self.publish_module(parent) + self.spacetime("delete", "--yes", parent) + + self.publish_module(parent) + + (owner, admin, developer, viewer) = self.create_collaborators(parent).items() + + for role_and_token in [admin, developer, viewer]: + with self.assertRaises(Exception): + self.delete_as(role_and_token, parent) + + child = random_string() + child_path = f"{parent}/{child}" + + # If admin creates a child, they should also be able to delete it, + # because they are the owner of the child. + print("publish and delete as admin") + self.publish_as(admin, child_path, self.MODULE_CODE) + self.delete_as(admin, child) + + # The owner role should be able to delete. + print("publish as admin, delete as owner") + self.publish_as(admin, child_path, self.MODULE_CODE) + self.delete_as(owner, child) + + # Anyone else should be denied if not direct owner. 
+ print("publish as owner, deny deletion by admin, developer, viewer") + self.publish_as(owner, child_path, self.MODULE_CODE) + for role_and_token in [admin, developer, viewer]: + with self.assertRaises(Exception): + self.delete_as(role_and_token, child) + + print("delete child as owner") + self.delete_as(owner, child) + + print("delete parent as owner") + self.delete_as(owner, parent) + + +class PrivateTables(PermissionsTest): + def test_permissions_private_tables(self): + """ + Test that all collaborators can read private tables. + """ + + parent = random_string() + self.publish_module(parent) + + team = self.create_collaborators(parent) + owner = ("Owner", team['Owner']) + + self.sql_as(owner, parent, "insert into person (name) values ('horsti')") + + for role_and_token in team.items(): + rows = self.sql_as(role_and_token, parent, "select * from person") + self.assertEqual(rows, [{ "name": '"horsti"' }]) + + for role_and_token in team.items(): + sub = self.subscribe_as(role_and_token, "select * from person", n = 2) + self.sql_as(owner, parent, "insert into person (name) values ('hansmans')") + self.sql_as(owner, parent, "delete from person where name = 'hansmans'") + res = sub() + self.assertEqual( + res, + [ + { + 'person': { + 'deletes': [], + 'inserts': [{'name': 'hansmans'}] + } + }, + { + 'person': { + 'deletes': [{'name': 'hansmans'}], + 'inserts': [] + } + } + ], + )