diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index a6a6290..7b2874a 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -2,9 +2,9 @@ name: Documentation on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] env: CARGO_TERM_COLOR: always @@ -31,6 +31,6 @@ jobs: - name: Build documentation run: | - cargo doc --all-features --no-deps --document-private-items + cargo doc --features=migrate,tokio-comp,json --no-deps --document-private-items env: RUSTDOCFLAGS: "--cfg docsrs -Dwarnings" diff --git a/.sqlx/query-1356c17313c5f3491e384d1667106a42fe6918792ae9bea8848a6ec1a3895a76.json b/.sqlx/query-1356c17313c5f3491e384d1667106a42fe6918792ae9bea8848a6ec1a3895a76.json new file mode 100644 index 0000000..06b7454 --- /dev/null +++ b/.sqlx/query-1356c17313c5f3491e384d1667106a42fe6918792ae9bea8848a6ec1a3895a76.json @@ -0,0 +1,32 @@ +{ + "db_name": "SQLite", + "query": "SELECT\n id,\n status,\n last_result as result\nFROM\n Jobs\nWHERE\n id IN (\n SELECT\n value\n FROM\n json_each(?)\n )\n AND (\n status = 'Done'\n OR (\n status = 'Failed'\n AND attempts >= max_attempts\n )\n OR status = 'Killed'\n )\n", + "describe": { + "columns": [ + { + "name": "id", + "ordinal": 0, + "type_info": "Text" + }, + { + "name": "status", + "ordinal": 1, + "type_info": "Text" + }, + { + "name": "result", + "ordinal": 2, + "type_info": "Text" + } + ], + "parameters": { + "Right": 1 + }, + "nullable": [ + false, + false, + true + ] + }, + "hash": "1356c17313c5f3491e384d1667106a42fe6918792ae9bea8848a6ec1a3895a76" +} diff --git a/.sqlx/query-1cd760004a2341bbded38a4fa431eaa74232f6f6f3121c7086a0b138195e9b0d.json b/.sqlx/query-1cd760004a2341bbded38a4fa431eaa74232f6f6f3121c7086a0b138195e9b0d.json index 92182f9..6dd5de4 100644 --- a/.sqlx/query-1cd760004a2341bbded38a4fa431eaa74232f6f6f3121c7086a0b138195e9b0d.json +++ b/.sqlx/query-1cd760004a2341bbded38a4fa431eaa74232f6f6f3121c7086a0b138195e9b0d.json @@ -62,6 +62,11 @@ "name": "priority", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "metadata", + "ordinal": 12, + "type_info": "Text" } ], "parameters": { @@ -79,7 +84,8 @@ true, true, true, - false + false, + true ] }, "hash": "1cd760004a2341bbded38a4fa431eaa74232f6f6f3121c7086a0b138195e9b0d" diff --git a/.sqlx/query-40a420986a37c9db2fc35b161092d6063f186242ca1808cfae9d6477c7d8687b.json b/.sqlx/query-40a420986a37c9db2fc35b161092d6063f186242ca1808cfae9d6477c7d8687b.json index bab3c7a..f314bdd 100644 --- a/.sqlx/query-40a420986a37c9db2fc35b161092d6063f186242ca1808cfae9d6477c7d8687b.json +++ b/.sqlx/query-40a420986a37c9db2fc35b161092d6063f186242ca1808cfae9d6477c7d8687b.json @@ -62,6 +62,11 @@ "name": "priority", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "metadata", + "ordinal": 12, + "type_info": "Text" } ], "parameters": { @@ -79,7 +84,8 @@ true, true, true, - false + false, + true ] }, "hash": "40a420986a37c9db2fc35b161092d6063f186242ca1808cfae9d6477c7d8687b" diff --git a/.sqlx/query-61843a18bffdee192cd01f1537f0f03d75403970fc8347d0f017f04d746b2b98.json b/.sqlx/query-61843a18bffdee192cd01f1537f0f03d75403970fc8347d0f017f04d746b2b98.json index 194ba6c..5fa4c2a 100644 --- a/.sqlx/query-61843a18bffdee192cd01f1537f0f03d75403970fc8347d0f017f04d746b2b98.json +++ b/.sqlx/query-61843a18bffdee192cd01f1537f0f03d75403970fc8347d0f017f04d746b2b98.json @@ -62,6 +62,11 @@ "name": "priority", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "metadata", + "ordinal": 12, + "type_info": "Text" } ], "parameters": { @@ -79,7 +84,8 @@ true, true, true, - false + false, + true ] }, "hash": "61843a18bffdee192cd01f1537f0f03d75403970fc8347d0f017f04d746b2b98" diff --git a/.sqlx/query-a43052e877930a522d6a63789935653f4fe06d10361c555f8ade4b65589520bc.json b/.sqlx/query-a43052e877930a522d6a63789935653f4fe06d10361c555f8ade4b65589520bc.json index cb5c3e6..c2bd6c0 100644 --- a/.sqlx/query-a43052e877930a522d6a63789935653f4fe06d10361c555f8ade4b65589520bc.json +++ b/.sqlx/query-a43052e877930a522d6a63789935653f4fe06d10361c555f8ade4b65589520bc.json @@ -62,6 +62,11 @@ "name": "priority", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "metadata", + "ordinal": 12, + "type_info": "Text" } ], "parameters": { @@ -79,7 +84,8 @@ true, true, true, - false + false, + true ] }, "hash": "a43052e877930a522d6a63789935653f4fe06d10361c555f8ade4b65589520bc" diff --git a/.sqlx/query-4ead10dafe0d2d024654403f289c4db57308e5cf88f44efaf23b05c585626465.json b/.sqlx/query-ae9b77d1f5b3f5125a37cdc29f6489b5f84003af0154b50166aaa6472863cd2d.json similarity index 67% rename from .sqlx/query-4ead10dafe0d2d024654403f289c4db57308e5cf88f44efaf23b05c585626465.json rename to .sqlx/query-ae9b77d1f5b3f5125a37cdc29f6489b5f84003af0154b50166aaa6472863cd2d.json index e2ab422..fa62399 100644 --- a/.sqlx/query-4ead10dafe0d2d024654403f289c4db57308e5cf88f44efaf23b05c585626465.json +++ b/.sqlx/query-ae9b77d1f5b3f5125a37cdc29f6489b5f84003af0154b50166aaa6472863cd2d.json @@ -1,12 +1,12 @@ { "db_name": "SQLite", - "query": "INSERT INTO\n Jobs\nVALUES\n (\n ?1,\n ?2,\n ?3,\n 'Pending',\n 0,\n ?4,\n ?5,\n NULL,\n NULL,\n NULL,\n NULL,\n ?6\n )\n", + "query": "INSERT INTO\n Jobs\nVALUES\n (\n ?1,\n ?2,\n ?3,\n 'Pending',\n 0,\n ?4,\n ?5,\n NULL,\n NULL,\n NULL,\n NULL,\n ?6,\n ?7\n )\n", "describe": { "columns": [], "parameters": { - "Right": 6 + "Right": 7 }, "nullable": [] }, - "hash": "4ead10dafe0d2d024654403f289c4db57308e5cf88f44efaf23b05c585626465" + "hash": "ae9b77d1f5b3f5125a37cdc29f6489b5f84003af0154b50166aaa6472863cd2d" } diff --git a/.sqlx/query-d7aefe54cd7388c208fff5b946390f217b575f0ca464a5faddd0fe2d51793983.json b/.sqlx/query-d7aefe54cd7388c208fff5b946390f217b575f0ca464a5faddd0fe2d51793983.json index d5f2525..045d050 100644 --- a/.sqlx/query-d7aefe54cd7388c208fff5b946390f217b575f0ca464a5faddd0fe2d51793983.json +++ b/.sqlx/query-d7aefe54cd7388c208fff5b946390f217b575f0ca464a5faddd0fe2d51793983.json @@ -62,6 +62,11 @@ "name": "priority", "ordinal": 11, "type_info": "Integer" + }, + { + "name": "metadata", + "ordinal": 12, + "type_info": "Text" } ], "parameters": { @@ -79,7 +84,8 @@ true, true, true, - false + false, + true ] }, "hash": "d7aefe54cd7388c208fff5b946390f217b575f0ca464a5faddd0fe2d51793983" diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..7c97836 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,17 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + + +### Added + +- Workflow support + +### Changed + +- Moved from monorepo diff --git a/Cargo.toml b/Cargo.toml index 61082ba..80e7c07 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "apalis-sqlite" -version = "1.0.0-alpha.1" +version = "1.0.0-alpha.2" authors = ["Njuguna Mureithi "] readme = "README.md" edition = "2024" @@ -29,7 +29,8 @@ serde_json = { version = "1" } apalis-core = { version = "1.0.0-alpha.4", default-features = false, features = [ "sleep", "json", -] } +], git = "https://github.com/geofmureithi/apalis.git", branch = "chore/traits-expansion" } +apalis-workflow = { version = "0.1.0-alpha.3", git = "https://github.com/geofmureithi/apalis.git", branch = "chore/traits-expansion" } log = "0.4.21" futures = "0.3.30" tokio = { version = "1", features = ["rt", "net"], optional = true } @@ -45,10 +46,13 @@ bytes = "1.1.0" [dev-dependencies] tokio = { version = "1", features = ["macros", "rt-multi-thread"] } -apalis-core = { version = "1.0.0-alpha.4", features = ["test-utils"] } +apalis-core = { version = "1.0.0-alpha.4", features = [ + "test-utils", +], git = "https://github.com/geofmureithi/apalis.git", branch = "chore/traits-expansion" } +apalis-workflow = { version = "0.1.0-alpha.2", git = "https://github.com/geofmureithi/apalis.git", branch = "chore/traits-expansion" } apalis-sqlite = { path = ".", features = ["migrate", "tokio-comp"] } [package.metadata.docs.rs] # defines the configuration attribute `docsrs` rustdoc-args = ["--cfg", "docsrs"] -all-features = true +features = ["migrate", "tokio-comp", "json"] diff --git a/README.md b/README.md index f0166c3..ff2c9ef 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ async fn main() { Ok(()) } - let worker = apalis_core::worker::builder::WorkerBuilder::new("worker-1") + let worker = WorkerBuilder::new("worker-1") .backend(backend) .build(send_reminder); worker.run().await.unwrap(); @@ -93,13 +93,57 @@ async fn main() { Ok(()) } - let worker = apalis_core::worker::builder::WorkerBuilder::new("worker-2") + let worker = WorkerBuilder::new("worker-2") .backend(backend) .build(send_reminder); worker.run().await.unwrap(); } ``` +### Workflow Example + +```rust,no_run +#[tokio::main] +async fn main() { + let workflow = WorkFlow::new("odd-numbers-workflow") + .then(|a: usize| async move { + Ok::<_, WorkflowError>((0..=a).collect::>()) + }) + .filter_map(|x| async move { + if x % 2 != 0 { Some(x) } else { None } + }) + .filter_map(|x| async move { + if x % 3 != 0 { Some(x) } else { None } + }) + .filter_map(|x| async move { + if x % 5 != 0 { Some(x) } else { None } + }) + .delay_for(Duration::from_millis(1000)) + .then(|a: Vec| async move { + println!("Sum: {}", a.iter().sum::()); + Ok::<(), WorkflowError>(()) + }); + + let pool = SqlitePool::connect(":memory:").await.unwrap(); + SqliteStorage::setup(&pool).await.unwrap(); + let mut sqlite = SqliteStorage::new_in_queue(&pool, "test-workflow"); + + sqlite.push(100usize).await.unwrap(); + + let worker = WorkerBuilder::new("rango-tango") + .backend(sqlite) + .on_event(|ctx, ev| { + println!("On Event = {:?}", ev); + if matches!(ev, Event::Error(_)) { + ctx.stop().unwrap(); + } + }) + .build(workflow); + + worker.run().await.unwrap(); +} +``` + ## Migrations If the `migrate` feature is enabled, you can run built-in migrations with: diff --git a/migrations/json/20251018162501_metadata.sql b/migrations/json/20251018162501_metadata.sql new file mode 100644 index 0000000..63bccc7 --- /dev/null +++ b/migrations/json/20251018162501_metadata.sql @@ -0,0 +1,4 @@ +ALTER TABLE + Jobs +ADD + COLUMN metadata TEXT; diff --git a/queries/backend/fetch_completed_tasks.sql b/queries/backend/fetch_completed_tasks.sql new file mode 100644 index 0000000..bc27097 --- /dev/null +++ b/queries/backend/fetch_completed_tasks.sql @@ -0,0 +1,21 @@ +SELECT + id, + status, + last_result as result +FROM + Jobs +WHERE + id IN ( + SELECT + value + FROM + json_each(?) + ) + AND ( + status = 'Done' + OR ( + status = 'Failed' + AND attempts >= max_attempts + ) + OR status = 'Killed' + ) diff --git a/queries/task/sink.sql b/queries/task/sink.sql index ee42376..6c5e4c0 100644 --- a/queries/task/sink.sql +++ b/queries/task/sink.sql @@ -13,5 +13,6 @@ VALUES NULL, NULL, NULL, - ?6 + ?6, + ?7 ) diff --git a/src/ack.rs b/src/ack.rs index 875e0c1..6a27027 100644 --- a/src/ack.rs +++ b/src/ack.rs @@ -1,16 +1,20 @@ +use std::any::Any; + use apalis_core::{ error::{AbortError, BoxDynError}, task::{Parts, status::Status}, worker::{context::WorkerContext, ext::ack::Acknowledge}, }; +use apalis_workflow::StepResult; use futures::{FutureExt, future::BoxFuture}; use serde::Serialize; +use serde_json::Value; use sqlx::SqlitePool; use tower_layer::Layer; use tower_service::Service; use ulid::Ulid; -use crate::{SqliteTask, context::SqliteContext}; +use crate::{CompactType, SqliteTask, context::SqliteContext}; #[derive(Clone)] pub struct SqliteAck { @@ -22,7 +26,7 @@ impl SqliteAck { } } -impl Acknowledge for SqliteAck { +impl Acknowledge for SqliteAck { type Error = sqlx::Error; type Future = BoxFuture<'static, Result<(), Self::Error>>; fn ack( @@ -33,7 +37,20 @@ impl Acknowledge for SqliteAck { let task_id = parts.task_id; let worker_id = parts.ctx.lock_by().clone(); - let response = serde_json::to_string(&res.as_ref().map_err(|e| e.to_string())); + // Workflows need special handling to serialize the response correctly + let response = match res { + Ok(r) => { + if let Some(res_ref) = (r as &dyn Any).downcast_ref::>() { + let res_deserialized: Result = + serde_json::from_str(&res_ref.0); + serde_json::to_string(&res_deserialized.map_err(|e| e.to_string())) + } else { + serde_json::to_string(&res.as_ref().map_err(|e| e.to_string())) + } + } + _ => serde_json::to_string(&res.as_ref().map_err(|e| e.to_string())), + }; + let status = calculate_status(parts, res); parts.status.store(status.clone()); let attempt = parts.attempt.current() as i32; @@ -151,7 +168,7 @@ where .unwrap(); let parts = &req.parts; let task_id = match &parts.task_id { - Some(id) => id.inner().clone(), + Some(id) => *id.inner(), None => { return async { Err(sqlx::Error::ColumnNotFound("TASK_ID_FOR_LOCK".to_owned()).into()) diff --git a/src/callback.rs b/src/callback.rs index eeecc63..72a7312 100644 --- a/src/callback.rs +++ b/src/callback.rs @@ -50,14 +50,6 @@ pub(crate) extern "C" fn update_hook_callback( let db = CStr::from_ptr(db_name).to_string_lossy().to_string(); let table = CStr::from_ptr(table_name).to_string_lossy().to_string(); - log::debug!( - "DB Event - Operation: {}, DB: {}, Table: {}, RowID: {}", - op_str, - db, - table, - rowid - ); - // Recover sender from raw pointer let tx = &mut *(arg as *mut mpsc::UnboundedSender); diff --git a/src/context.rs b/src/context.rs index c1c22ff..5e14694 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,10 +1,13 @@ -use std::convert::Infallible; +use std::{collections::HashMap, convert::Infallible}; -use apalis_core::task_fn::FromRequest; +use apalis_core::{task::metadata::MetadataExt, task_fn::FromRequest}; -use serde::{Deserialize, Serialize}; +use serde::{ + Deserialize, Serialize, + de::{DeserializeOwned, Error}, +}; -use crate::SqliteTask; +use crate::{CompactType, SqliteTask}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SqliteContext { @@ -15,6 +18,7 @@ pub struct SqliteContext { done_at: Option, priority: i32, queue: Option, + meta: HashMap, } impl Default for SqliteContext { @@ -34,6 +38,7 @@ impl SqliteContext { lock_by: None, priority: 0, queue: None, + meta: HashMap::new(), } } @@ -111,6 +116,15 @@ impl SqliteContext { self.queue = Some(queue); self } + + pub fn meta(&self) -> &HashMap { + &self.meta + } + + pub fn with_meta(mut self, meta: HashMap) -> Self { + self.meta = meta; + self + } } impl FromRequest> for SqliteContext { @@ -119,3 +133,20 @@ impl FromRequest> for SqliteContext { Ok(req.parts.ctx.clone()) } } + +impl MetadataExt for SqliteContext { + type Error = serde_json::Error; + fn extract(&self) -> Result { + self.meta + .get(std::any::type_name::()) + .and_then(|v| serde_json::from_str::(v).ok()) + .ok_or(serde_json::Error::custom("Failed to extract metadata")) + } + fn inject(&mut self, value: T) -> Result<(), Self::Error> { + self.meta.insert( + std::any::type_name::().to_string(), + serde_json::to_string(&value).unwrap(), + ); + Ok(()) + } +} diff --git a/src/fetcher.rs b/src/fetcher.rs index 3490005..28e05d3 100644 --- a/src/fetcher.rs +++ b/src/fetcher.rs @@ -8,7 +8,7 @@ use std::{ use apalis_core::{ backend::{ - codec::{Codec, }, + codec::Codec, poll_strategy::{PollContext, PollStrategyExt}, }, task::Task, @@ -28,6 +28,7 @@ pub async fn fetch_next>( ) -> Result>, sqlx::Error> where D::Error: std::error::Error + Send + Sync + 'static, + Args: 'static, { let job_type = config.queue().to_string(); let buffer_size = config.buffer_size() as i32; diff --git a/src/from_row.rs b/src/from_row.rs index a08d351..38d0489 100644 --- a/src/from_row.rs +++ b/src/from_row.rs @@ -21,11 +21,11 @@ pub(crate) struct TaskRow { pub(crate) lock_by: Option, pub(crate) done_at: Option, pub(crate) priority: Option, - // pub(crate) meta: Option>, + pub(crate) metadata: Option, } impl TaskRow { - pub fn try_into_task, Args>( + pub fn try_into_task, Args: 'static>( self, ) -> Result, sqlx::Error> where @@ -37,12 +37,34 @@ impl TaskRow { .with_max_attempts(self.max_attempts.unwrap_or(25) as i32) .with_last_result(self.last_result) .with_priority(self.priority.unwrap_or(0) as i32) + .with_meta( + self.metadata + .map(|m| serde_json::from_str(&m).unwrap_or_default()) + .unwrap_or_default(), + ) .with_queue( self.job_type .ok_or(sqlx::Error::ColumnNotFound("job_type".to_owned()))?, ) .with_lock_at(self.lock_at); - let args = D::decode(&self.job).map_err(|e| sqlx::Error::Decode(e.into()))?; + + // Optimize for the case where Args and CompactType are the same type + // to avoid unnecessary serialization/deserialization. + // That comes at the cost of using unsafe code, and leaking memory + let args = if std::any::TypeId::of::() == std::any::TypeId::of::() { + // SAFETY: We've verified that Args and CompactType are the same type. + // We use ptr::read to move the value out without calling drop on self.job. + // Then we use mem::forget to prevent self from being dropped (which would + // try to drop self.job again, causing a double free). + unsafe { + let job_ptr = &self.job as *const CompactType as *const Args; + let args = std::ptr::read(job_ptr); + std::mem::forget(self.job); + args + } + } else { + D::decode(&self.job).map_err(|e| sqlx::Error::Decode(e.into()))? + }; let task = TaskBuilder::new(args) .with_ctx(ctx) .with_attempt(Attempt::new_with_value( @@ -78,6 +100,11 @@ impl TaskRow { .with_max_attempts(self.max_attempts.unwrap_or(25) as i32) .with_last_result(self.last_result) .with_priority(self.priority.unwrap_or(0) as i32) + .with_meta( + self.metadata + .map(|m| serde_json::from_str(&m).unwrap_or_default()) + .unwrap_or_default(), + ) .with_queue( self.job_type .ok_or(sqlx::Error::ColumnNotFound("job_type".to_owned()))?, diff --git a/src/lib.rs b/src/lib.rs index 1b2be94..9b49277 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,28 +33,25 @@ //! # use futures::StreamExt; //! # use futures::SinkExt; //! # use apalis_core::worker::builder::WorkerBuilder; +//! # use apalis_core::backend::TaskSink; //! #[tokio::main] //! async fn main() { //! let pool = SqlitePool::connect(":memory:").await.unwrap(); //! SqliteStorage::setup(&pool).await.unwrap(); //! let mut backend = SqliteStorage::new(&pool); //! -//! let mut start = 0; +//! let mut start = 0usize; //! let mut items = stream::repeat_with(move || { //! start += 1; -//! let task = Task::builder(start) -//! .run_after(Duration::from_secs(1)) -//! .with_ctx(SqliteContext::new().with_priority(1)) -//! .build(); -//! Ok(task) +//! start //! }) //! .take(10); -//! backend.send_all(&mut items).await.unwrap(); +//! backend.push_stream(&mut items).await.unwrap(); //! //! async fn send_reminder(item: usize, wrk: WorkerContext) -> Result<(), BoxDynError> { -//! # if item == 10 { -//! # wrk.stop().unwrap(); -//! # } +//! if item == 10 { +//! wrk.stop().unwrap(); +//! } //! Ok(()) //! } //! @@ -128,6 +125,64 @@ //! worker.run().await.unwrap(); //! } //! ``` +//! ### Workflow Example +//! +//! ```rust,no_run +//! # use apalis_sqlite::{SqliteStorage, SqliteContext, Config}; +//! # use apalis_core::task::Task; +//! # use apalis_core::worker::context::WorkerContext; +//! # use sqlx::SqlitePool; +//! # use futures::stream; +//! # use std::time::Duration; +//! # use apalis_core::error::BoxDynError; +//! # use futures::StreamExt; +//! # use futures::SinkExt; +//! # use apalis_core::worker::builder::WorkerBuilder; +//! # use apalis_workflow::WorkFlow; +//! # use apalis_workflow::WorkflowError; +//! # use apalis_core::worker::event::Event; +//! # use apalis_core::backend::WeakTaskSink; +//! # use apalis_core::worker::ext::event_listener::EventListenerExt; +//! #[tokio::main] +//! async fn main() { +//! let workflow = WorkFlow::new("odd-numbers-workflow") +//! .then(|a: usize| async move { +//! Ok::<_, WorkflowError>((0..=a).collect::>()) +//! }) +//! .filter_map(|x| async move { +//! if x % 2 != 0 { Some(x) } else { None } +//! }) +//! .filter_map(|x| async move { +//! if x % 3 != 0 { Some(x) } else { None } +//! }) +//! .filter_map(|x| async move { +//! if x % 5 != 0 { Some(x) } else { None } +//! }) +//! .delay_for(Duration::from_millis(1000)) +//! .then(|a: Vec| async move { +//! println!("Sum: {}", a.iter().sum::()); +//! Ok::<(), WorkflowError>(()) +//! }); +//! +//! let pool = SqlitePool::connect(":memory:").await.unwrap(); +//! SqliteStorage::setup(&pool).await.unwrap(); +//! let mut sqlite = SqliteStorage::new_in_queue(&pool, "test-workflow"); +//! +//! sqlite.push(100usize).await.unwrap(); +//! +//! let worker = WorkerBuilder::new("rango-tango") +//! .backend(sqlite) +//! .on_event(|ctx, ev| { +//! println!("On Event = {:?}", ev); +//! if matches!(ev, Event::Error(_)) { +//! ctx.stop().unwrap(); +//! } +//! }) +//! .build(workflow); +//! +//! worker.run().await.unwrap(); +//! } +//! ``` //! //! ## Migrations //! @@ -145,6 +200,8 @@ //! ## License //! //! Licensed under either of Apache License, Version 2.0 or MIT license at your option. +//! +//! [`SqliteStorageWithHook`]: crate::SqliteStorage use std::{fmt, marker::PhantomData}; use apalis_core::{ @@ -198,6 +255,7 @@ pub use sqlx::SqlitePool; #[cfg(feature = "json")] pub type CompactType = String; +// Bytes not yet supported due to sqlx limitations #[cfg(feature = "bytes")] pub type CompactType = Vec; @@ -295,6 +353,27 @@ impl SqliteStorage { } } + pub fn new_in_queue( + pool: &Pool, + queue: &str, + ) -> SqliteStorage< + T, + JsonCodec, + fetcher::SqliteFetcher>, + > { + let config = Config::new(queue); + SqliteStorage { + pool: pool.clone(), + job_type: PhantomData, + sink: SqliteSink::new(pool, &config), + config, + codec: PhantomData, + fetcher: fetcher::SqliteFetcher { + _marker: PhantomData, + }, + } + } + pub fn new_with_codec( pool: &Pool, config: &Config, @@ -340,7 +419,7 @@ impl SqliteStorage { job_type: PhantomData, config: config.clone(), codec: PhantomData, - sink: SqliteSink::new(&pool, config), + sink: SqliteSink::new(pool, config), fetcher: HookCallbackListener, } } @@ -354,7 +433,7 @@ impl SqliteStorage { job_type: PhantomData, config: config.clone(), codec: PhantomData, - sink: SqliteSink::new(&pool, config), + sink: SqliteSink::new(pool, config), fetcher: HookCallbackListener, } } @@ -530,14 +609,19 @@ where mod tests { use std::time::Duration; + use apalis_workflow::{WorkFlow, WorkflowError}; use chrono::Local; use apalis_core::{ - backend::poll_strategy::{IntervalStrategy, StrategyBuilder}, + backend::{ + WeakTaskSink, + poll_strategy::{IntervalStrategy, StrategyBuilder}, + }, error::BoxDynError, - worker::builder::WorkerBuilder, + task::data::Data, + worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt}, }; - use futures::SinkExt; + use serde::{Deserialize, Serialize}; use super::*; @@ -553,14 +637,10 @@ mod tests { let mut items = stream::repeat_with(move || { start += 1; - let task = Task::builder(start) - .run_after(Duration::from_secs(1)) - .with_ctx(SqliteContext::new().with_priority(1)) - .build(); - Ok(task) + start }) .take(ITEMS); - backend.send_all(&mut items).await.unwrap(); + backend.push_stream(&mut items).await.unwrap(); println!("Starting worker at {}", Local::now()); @@ -623,4 +703,164 @@ mod tests { .build(send_reminder); worker.run().await.unwrap(); } + + #[tokio::test] + async fn test_workflow() { + let workflow = WorkFlow::new("odd-numbers-workflow") + .then(|a: usize| async move { Ok::<_, WorkflowError>((0..=a).collect::>()) }) + .filter_map(|x| async move { if x % 2 != 0 { Some(x) } else { None } }) + .filter_map(|x| async move { if x % 3 != 0 { Some(x) } else { None } }) + .filter_map(|x| async move { if x % 5 != 0 { Some(x) } else { None } }) + .delay_for(Duration::from_millis(1000)) + .then(|a: Vec| async move { + println!("Sum: {}", a.iter().sum::()); + Err::<(), WorkflowError>(WorkflowError::MissingContextError) + }); + + let pool = SqlitePool::connect(":memory:").await.unwrap(); + SqliteStorage::setup(&pool).await.unwrap(); + let mut sqlite = SqliteStorage::new_with_callback( + &pool, + &Config::new("workflow-queue").with_poll_interval( + StrategyBuilder::new() + .apply(IntervalStrategy::new(Duration::from_millis(100))) + .build(), + ), + ); + + sqlite.push(100usize).await.unwrap(); + + let worker = WorkerBuilder::new("rango-tango") + .backend(sqlite) + .on_event(|ctx, ev| { + println!("On Event = {:?}", ev); + if matches!(ev, Event::Error(_)) { + ctx.stop().unwrap(); + } + }) + .build(workflow); + worker.run().await.unwrap(); + } + + #[tokio::test] + async fn test_workflow_complete() { + #[derive(Debug, Serialize, Deserialize, Clone)] + struct PipelineConfig { + min_confidence: f32, + enable_sentiment: bool, + } + + #[derive(Debug, Serialize, Deserialize)] + struct UserInput { + text: String, + } + + #[derive(Debug, Serialize, Deserialize)] + struct Classified { + text: String, + label: String, + confidence: f32, + } + + #[derive(Debug, Serialize, Deserialize)] + struct Summary { + text: String, + sentiment: Option, + } + + let workflow = WorkFlow::new("text-pipeline") + // Step 1: Preprocess input (e.g., tokenize, lowercase) + .then(|input: UserInput, mut worker: WorkerContext| async move { + worker.emit(&Event::Custom(Box::new(format!( + "Preprocessing input: {}", + input.text + )))); + let processed = input.text.to_lowercase(); + Ok::<_, WorkflowError>(processed) + }) + // Step 2: Classify text + .then(|text: String| async move { + let confidence = 0.85; // pretend model confidence + let items = text.split_whitespace().collect::>(); + let results = items + .into_iter() + .map(|x| Classified { + text: x.to_string(), + label: if x.contains("rust") { + "Tech" + } else { + "General" + } + .to_string(), + confidence, + }) + .collect::>(); + Ok::<_, WorkflowError>(results) + }) + // Step 3: Filter out low-confidence predictions + .filter_map( + |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } }, + ) + .filter_map(move |c: Classified, config: Data| { + let cfg = config.enable_sentiment; + async move { + if !cfg { + return Some(Summary { + text: c.text, + sentiment: None, + }); + } + + // pretend we run a sentiment model + let sentiment = if c.text.contains("delightful") { + "positive" + } else { + "neutral" + }; + Some(Summary { + text: c.text, + sentiment: Some(sentiment.to_string()), + }) + } + }) + .then(|a: Vec, mut worker: WorkerContext| async move { + worker.emit(&Event::Custom(Box::new(format!( + "Generated {} summaries", + a.len() + )))); + worker.stop() + }); + + let pool = SqlitePool::connect(":memory:").await.unwrap(); + SqliteStorage::setup(&pool).await.unwrap(); + let mut sqlite = SqliteStorage::new_with_callback(&pool, &Config::new("text-pipeline")); + + let input = UserInput { + text: "Rust makes systems programming delightful!".to_string(), + }; + sqlite.push(input).await.unwrap(); + + let worker = WorkerBuilder::new("rango-tango") + .backend(sqlite) + .data(PipelineConfig { + min_confidence: 0.8, + enable_sentiment: true, + }) + .on_event(|ctx, ev| match ev { + Event::Custom(msg) => { + if let Some(m) = msg.downcast_ref::() { + println!("Custom Message: {}", m); + } + } + Event::Error(_) => { + println!("On Error = {:?}", ev); + ctx.stop().unwrap(); + } + _ => { + println!("On Event = {:?}", ev); + } + }) + .build(workflow); + worker.run().await.unwrap(); + } } diff --git a/src/queries/fetch_by_id.rs b/src/queries/fetch_by_id.rs index 04d5658..4b8aad6 100644 --- a/src/queries/fetch_by_id.rs +++ b/src/queries/fetch_by_id.rs @@ -12,6 +12,7 @@ where Backend, D: Codec, D::Error: std::error::Error + Send + Sync + 'static, + Args: 'static, { fn fetch_by_id( &mut self, diff --git a/src/queries/list_tasks.rs b/src/queries/list_tasks.rs index d86843b..e4d5a75 100644 --- a/src/queries/list_tasks.rs +++ b/src/queries/list_tasks.rs @@ -12,6 +12,7 @@ where Backend, D: Codec, D::Error: std::error::Error + Send + Sync + 'static, + Args: 'static, { fn list_tasks( &self, diff --git a/src/queries/mod.rs b/src/queries/mod.rs index 52ef448..64dfd8f 100644 --- a/src/queries/mod.rs +++ b/src/queries/mod.rs @@ -1,13 +1,14 @@ use apalis_core::backend::StatType; pub mod fetch_by_id; +pub mod keep_alive; pub mod list_queues; pub mod list_tasks; pub mod list_workers; pub mod metrics; pub mod reenqueue_orphaned; pub mod register_worker; -pub mod keep_alive; +pub mod wait_for; fn stat_type_from_string(s: &str) -> StatType { match s { diff --git a/src/queries/reenqueue_orphaned.rs b/src/queries/reenqueue_orphaned.rs index 93788da..a0e7d32 100644 --- a/src/queries/reenqueue_orphaned.rs +++ b/src/queries/reenqueue_orphaned.rs @@ -23,11 +23,11 @@ pub fn reenqueue_orphaned( res.rows_affected() ); } - return Ok(res.rows_affected()); + Ok(res.rows_affected()) } Err(e) => { - log::error!("Failed to re-enqueue orphaned tasks: {}", e); - return Err(e); + log::error!("Failed to re-enqueue orphaned tasks: {e}"); + Err(e) } } } diff --git a/src/queries/wait_for.rs b/src/queries/wait_for.rs new file mode 100644 index 0000000..ddaf54d --- /dev/null +++ b/src/queries/wait_for.rs @@ -0,0 +1,114 @@ +use std::{collections::HashSet, str::FromStr, vec}; + +use apalis_core::{ + backend::{Backend, TaskResult, WaitForCompletion}, + task::{status::Status, task_id::TaskId}, +}; +use futures::{StreamExt, stream::BoxStream}; +use serde::de::DeserializeOwned; +use ulid::Ulid; + +use crate::{CompactType, SqliteStorage}; + +#[derive(Debug)] +struct ResultRow { + pub id: Option, + pub status: Option, + pub result: Option, +} + +impl WaitForCompletion for SqliteStorage +where + SqliteStorage: + Backend, + Result: DeserializeOwned, +{ + type ResultStream = BoxStream<'static, Result, Self::Error>>; + fn wait_for( + &self, + task_ids: impl IntoIterator>, + ) -> Self::ResultStream { + let pool = self.pool.clone(); + let ids: HashSet = task_ids.into_iter().map(|id| id.to_string()).collect(); + + let stream = futures::stream::unfold(ids, move |mut remaining_ids| { + let pool = pool.clone(); + async move { + if remaining_ids.is_empty() { + return None; + } + + let ids_vec: Vec = remaining_ids.iter().cloned().collect(); + let ids_vec = serde_json::to_string(&ids_vec).unwrap(); + let rows = sqlx::query_file_as!( + ResultRow, + "queries/backend/fetch_completed_tasks.sql", + ids_vec + ) + .fetch_all(&pool) + .await + .ok()?; + + if rows.is_empty() { + apalis_core::timer::sleep(std::time::Duration::from_millis(500)).await; + return Some((futures::stream::iter(vec![]), remaining_ids)); + } + + let mut results = Vec::new(); + for row in rows { + let task_id = row.id.clone().unwrap(); + remaining_ids.remove(&task_id); + // Here we would normally decode the output O from the row + // For simplicity, we assume O is String and the output is stored in row.output + let result: Result = + serde_json::from_str(&row.result.unwrap()).unwrap(); + results.push(Ok(TaskResult::new( + TaskId::from_str(&task_id).ok()?, + Status::from_str(&row.status.unwrap()).ok()?, + result, + ))); + } + + Some((futures::stream::iter(results), remaining_ids)) + } + }); + stream.flatten().boxed() + } + + // Implementation of check_status + fn check_status( + &self, + task_ids: impl IntoIterator> + Send, + ) -> impl Future>, Self::Error>> + Send { + let pool = self.pool.clone(); + let ids: Vec = task_ids.into_iter().map(|id| id.to_string()).collect(); + + async move { + let ids = serde_json::to_string(&ids).unwrap(); + let rows = + sqlx::query_file_as!(ResultRow, "queries/backend/fetch_completed_tasks.sql", ids) + .fetch_all(&pool) + .await?; + + let mut results = Vec::new(); + for row in rows { + let task_id = TaskId::from_str(&row.id.unwrap()) + .map_err(|_| sqlx::Error::Protocol("Invalid task ID".into()))?; + + let result: Result = serde_json::from_str(&row.result.unwrap()) + .map_err(|_| sqlx::Error::Protocol("Failed to decode result".into()))?; + + results.push(TaskResult::new( + task_id, + row.status + .unwrap() + .parse() + .map_err(|_| sqlx::Error::Protocol("Invalid status value".into()))?, + result, + )); + } + + Ok(results) + } + } +} diff --git a/src/shared.rs b/src/shared.rs index 4ca3844..719d5d1 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -287,14 +287,10 @@ mod tests { use std::time::Duration; use apalis_core::{ - backend::TaskSink, - error::BoxDynError, - task::{Task, task_id::TaskId}, + backend::TaskSink, error::BoxDynError, task::task_id::TaskId, worker::builder::WorkerBuilder, }; - use crate::context::SqliteContext; - use super::*; #[tokio::test] @@ -309,13 +305,8 @@ mod tests { let mut int_store = store.make_shared().unwrap(); - let task = Task::builder(99u32) - .run_after(Duration::from_secs(2)) - .with_ctx(SqliteContext::new().with_priority(1)) - .build(); - map_store - .send_all(&mut stream::iter(vec![task].into_iter().map(Ok))) + .push(HashMap::::from([("value".to_string(), 42)])) .await .unwrap(); int_store.push(99).await.unwrap(); diff --git a/src/sink.rs b/src/sink.rs index 95453ff..396bad3 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -4,7 +4,6 @@ use std::{ task::{Context, Poll}, }; -use apalis_core::backend::codec::Codec; use futures::{ FutureExt, Sink, future::{BoxFuture, Shared}, @@ -59,6 +58,7 @@ pub async fn push_tasks( Some(ref queue) => queue.to_string(), None => cfg.queue().to_string(), }; + let meta = serde_json::to_string(&task.parts.ctx.meta()).unwrap_or_default(); sqlx::query_file!( "queries/task/sink.sql", args, @@ -67,6 +67,7 @@ pub async fn push_tasks( max_attempts, run_at, priority, + meta ) .execute(&mut *tx) .await?; @@ -88,11 +89,9 @@ impl SqliteSink { } } -impl Sink> for SqliteStorage +impl Sink> for SqliteStorage where Args: Send + Sync + 'static, - Encode: Codec, - Encode::Error: std::error::Error + Send + Sync + 'static, { type Error = sqlx::Error; @@ -100,12 +99,9 @@ where Poll::Ready(Ok(())) } - fn start_send(self: Pin<&mut Self>, item: SqliteTask) -> Result<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, item: SqliteTask) -> Result<(), Self::Error> { // Add the item to the buffer - self.project() - .sink - .buffer - .push(item.try_map(|s| Encode::encode(&s).map_err(|e| sqlx::Error::Encode(e.into())))?); + self.project().sink.buffer.push(item); Ok(()) }