diff --git a/migrations/20250212120000_4.sql b/migrations/20250212120000_4.sql new file mode 100644 index 0000000..d2a052b --- /dev/null +++ b/migrations/20250212120000_4.sql @@ -0,0 +1,11 @@ +-- Force anything running this migration to use the right search path. +set local search_path to underway; + +alter table underway.task +add column if not exists lease_token uuid; + +alter table underway.task +add column if not exists lease_expires_at timestamp with time zone; + +alter table underway.task_attempt +add column if not exists lease_token uuid; diff --git a/src/job.rs b/src/job.rs index 3287b58..57b9896 100644 --- a/src/job.rs +++ b/src/job.rs @@ -652,9 +652,7 @@ use uuid::Uuid; use crate::{ queue::{Error as QueueError, InProgressTask, Queue}, scheduler::{Error as SchedulerError, Scheduler, ZonedSchedule}, - task::{ - Error as TaskError, Result as TaskResult, RetryPolicy, State as TaskState, Task, TaskId, - }, + task::{Error as TaskError, Result as TaskResult, RetryPolicy, State as TaskState, Task}, worker::{Error as WorkerError, Worker}, }; @@ -917,7 +915,9 @@ impl EnqueuedJob { retry_policy as "retry_policy: RetryPolicy", timeout, heartbeat, - concurrency_key + concurrency_key, + 0::int as "attempt_number!", + lease_token as "lease_token?" from underway.task where input->>'job_id' = $1 and state = $2 diff --git a/src/queue.rs b/src/queue.rs index ff58e45..67a9ff4 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -806,7 +806,8 @@ impl Queue { // "in-progress". let mut tx = self.pool.begin().await?; - let in_progress_task = sqlx::query_as!( + let lease_token = Uuid::new_v4(); + let mut in_progress_task = sqlx::query_as!( InProgressTask, r#" with available_task as ( @@ -820,7 +821,7 @@ impl Queue { or ( state = $3 -- Has heartbeat stalled? - and last_heartbeat_at < now() - heartbeat + and (lease_expires_at < now() or lease_expires_at is null) -- Are there remaining retries? and (retry_policy).max_attempts > ( select count(*) @@ -841,7 +842,9 @@ impl Queue { update underway.task t set state = $3, last_attempt_at = now(), - last_heartbeat_at = now() + last_heartbeat_at = now(), + lease_token = $4, + lease_expires_at = now() + heartbeat from available_task where t.task_queue_name = $1 and t.id = available_task.id @@ -852,16 +855,19 @@ impl Queue { t.timeout, t.heartbeat, t.retry_policy as "retry_policy: RetryPolicy", - t.concurrency_key + t.concurrency_key, + 0::int as "attempt_number!", + t.lease_token as "lease_token?" "#, self.name, TaskState::Pending as TaskState, TaskState::InProgress as TaskState, + lease_token, ) .fetch_optional(&mut *tx) .await?; - if let Some(in_progress_task) = &in_progress_task { + if let Some(in_progress_task) = &mut in_progress_task { let task_id = in_progress_task.id; tracing::Span::current().record("task.id", task_id.as_hyphenated().to_string()); @@ -886,7 +892,7 @@ impl Queue { .await?; // Insert a new task attempt row - sqlx::query!( + let attempt_number = sqlx::query_scalar!( r#" with next_attempt as ( select coalesce(max(attempt_number) + 1, 1) as attempt_number @@ -898,21 +904,27 @@ impl Queue { task_id, task_queue_name, state, - attempt_number + attempt_number, + lease_token ) values ( $1, $2, $3, - (select attempt_number from next_attempt) + (select attempt_number from next_attempt), + $4 ) + returning attempt_number "#, task_id as TaskId, self.name, - TaskState::InProgress as TaskState + TaskState::InProgress as TaskState, + in_progress_task.lease_token ) - .execute(&mut *tx) + .fetch_one(&mut *tx) .await?; + + in_progress_task.attempt_number = attempt_number; } tx.commit().await?; @@ -1188,10 +1200,21 @@ pub struct InProgressTask { pub(crate) heartbeat: sqlx::postgres::types::PgInterval, pub(crate) retry_policy: RetryPolicy, pub(crate) concurrency_key: Option, + pub(crate) attempt_number: i32, + pub(crate) lease_token: Option, } impl InProgressTask { pub(crate) async fn mark_succeeded(&self, conn: &mut PgConnection) -> Result { + let Some(lease_token) = self.lease_token else { + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping success update without lease token" + ); + return Ok(()); + }; + // Update the task attempt row. let result = sqlx::query!( r#" @@ -1201,24 +1224,27 @@ impl InProgressTask { completed_at = now() where task_id = $1 and task_queue_name = $2 - and attempt_number = ( - select attempt_number - from underway.task_attempt - where task_id = $1 - and task_queue_name = $2 - order by attempt_number desc - limit 1 - ) + and attempt_number = $4 + and state = $5 + and lease_token = $6 "#, self.id as TaskId, self.queue_name, - TaskState::Succeeded as TaskState + TaskState::Succeeded as TaskState, + self.attempt_number, + TaskState::InProgress as TaskState, + lease_token ) .execute(&mut *conn) .await?; if result.rows_affected() == 0 { - return Err(Error::TaskNotFound(self.id)); + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping success update for non-current task attempt" + ); + return Ok(()); } // Update the task row. @@ -1227,23 +1253,38 @@ impl InProgressTask { update underway.task set state = $2, updated_at = now(), - completed_at = now() + completed_at = now(), + lease_token = null, + lease_expires_at = null where id = $1 + and task_queue_name = $3 + and state = $4 + and lease_token = $5 "#, self.id as TaskId, - TaskState::Succeeded as TaskState + TaskState::Succeeded as TaskState, + self.queue_name, + TaskState::InProgress as TaskState, + lease_token ) .execute(&mut *conn) .await?; if result.rows_affected() == 0 { - return Err(Error::TaskNotFound(self.id)); + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping success update for non-current task attempt" + ); + return Ok(()); } Ok(()) } pub(crate) async fn mark_cancelled(&self, conn: &mut PgConnection) -> Result { + let lease_token = self.lease_token; + // Update task attempt row if one exists. // // N.B.: A task may be cancelled before an attempt row has been created. @@ -1255,20 +1296,19 @@ impl InProgressTask { completed_at = now() where task_id = $1 and task_queue_name = $2 - and attempt_number = ( - select attempt_number - from underway.task_attempt - where task_id = $1 - and task_queue_name = $2 - and state < $4 - order by attempt_number desc - limit 1 + and attempt_number = $4 + and state < $5 + and ( + (lease_token = $6) + or (lease_token is null and $6 is null) ) "#, self.id as TaskId, self.queue_name, TaskState::Cancelled as TaskState, - TaskState::Succeeded as TaskState + self.attempt_number, + TaskState::Succeeded as TaskState, + lease_token ) .execute(&mut *conn) .await?; @@ -1279,18 +1319,33 @@ impl InProgressTask { update underway.task set state = $2, updated_at = now(), - completed_at = now() - where id = $1 and state < $3 + completed_at = now(), + lease_token = null, + lease_expires_at = null + where id = $1 + and task_queue_name = $3 + and state < $4 + and ( + (lease_token = $5) + or (lease_token is null and $5 is null) + ) "#, self.id as TaskId, TaskState::Cancelled as TaskState, - TaskState::Succeeded as TaskState + self.queue_name, + TaskState::Succeeded as TaskState, + lease_token ) .execute(&mut *conn) .await?; if result.rows_affected() == 0 { - return Err(Error::TaskNotFound(self.id)); + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping cancel update for non-current task attempt" + ); + return Ok(false); } Ok(result.rows_affected() > 0) @@ -1302,6 +1357,15 @@ impl InProgressTask { err )] pub(crate) async fn retry_after(&self, conn: &mut PgConnection, delay: Span) -> Result { + let Some(lease_token) = self.lease_token else { + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping retry update without lease token" + ); + return Ok(()); + }; + let result = sqlx::query!( r#" update underway.task_attempt @@ -1310,43 +1374,58 @@ impl InProgressTask { completed_at = now() where task_id = $1 and task_queue_name = $2 - and attempt_number = ( - select attempt_number - from underway.task_attempt - where task_id = $1 - and task_queue_name = $2 - order by attempt_number desc - limit 1 - ) + and attempt_number = $4 + and lease_token = $5 "#, self.id as TaskId, self.queue_name, - TaskState::Failed as TaskState + TaskState::Failed as TaskState, + self.attempt_number, + lease_token ) .execute(&mut *conn) .await?; if result.rows_affected() == 0 { - return Err(Error::TaskNotFound(self.id)); + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping retry update for non-current task attempt" + ); + return Ok(()); } + let delay_duration = StdDuration::try_from(delay)?; + let delay_interval = sqlx::postgres::types::PgInterval::try_from(delay_duration) + .map_err(sqlx::Error::Decode)?; let result = sqlx::query!( r#" update underway.task set state = $3, delay = $2, - updated_at = now() + updated_at = now(), + lease_token = null, + lease_expires_at = null where id = $1 + and task_queue_name = $4 + and lease_token = $5 "#, self.id as TaskId, - StdDuration::try_from(delay)? as _, - TaskState::Pending as TaskState + delay_interval, + TaskState::Pending as TaskState, + self.queue_name, + lease_token ) .execute(&mut *conn) .await?; if result.rows_affected() == 0 { - return Err(Error::TaskNotFound(self.id)); + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping retry update for non-current task attempt" + ); + return Ok(()); } Ok(()) @@ -1357,6 +1436,15 @@ impl InProgressTask { conn: &mut PgConnection, error: &TaskError, ) -> Result { + let Some(lease_token) = self.lease_token else { + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping failure update without lease token" + ); + return Ok(()); + }; + let result = sqlx::query!( r#" update underway.task_attempt @@ -1365,25 +1453,28 @@ impl InProgressTask { error_message = $4 where task_id = $1 and task_queue_name = $2 - and attempt_number = ( - select attempt_number - from underway.task_attempt - where task_id = $1 - and task_queue_name = $2 - order by attempt_number desc - limit 1 - ) + and attempt_number = $5 + and state = $6 + and lease_token = $7 "#, self.id as TaskId, self.queue_name, TaskState::Failed as TaskState, error.to_string(), + self.attempt_number, + TaskState::InProgress as TaskState, + lease_token ) .execute(&mut *conn) .await?; if result.rows_affected() == 0 { - return Err(Error::TaskNotFound(self.id)); + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping failure update for non-current task attempt" + ); + return Ok(()); } let result = sqlx::query!( @@ -1392,21 +1483,39 @@ impl InProgressTask { set updated_at = now() where id = $1 and task_queue_name = $2 + and state = $3 + and lease_token = $4 "#, self.id as TaskId, self.queue_name, + TaskState::InProgress as TaskState, + lease_token ) .execute(&mut *conn) .await?; if result.rows_affected() == 0 { - return Err(Error::TaskNotFound(self.id)); + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping failure update for non-current task attempt" + ); + return Ok(()); } Ok(()) } pub(crate) async fn mark_failed(&self, conn: &mut PgConnection) -> Result { + let Some(lease_token) = self.lease_token else { + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping failed update without lease token" + ); + return Ok(()); + }; + // Update the task attempt row. let result = sqlx::query!( r#" @@ -1416,24 +1525,25 @@ impl InProgressTask { completed_at = now() where task_id = $1 and task_queue_name = $2 - and attempt_number = ( - select attempt_number - from underway.task_attempt - where task_id = $1 - and task_queue_name = $2 - order by attempt_number desc - limit 1 - ) + and attempt_number = $4 + and lease_token = $5 "#, self.id as TaskId, self.queue_name, - TaskState::Failed as TaskState + TaskState::Failed as TaskState, + self.attempt_number, + lease_token ) .execute(&mut *conn) .await?; if result.rows_affected() == 0 { - return Err(Error::TaskNotFound(self.id)); + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping failed update for non-current task attempt" + ); + return Ok(()); } // Update the task row. @@ -1442,17 +1552,30 @@ impl InProgressTask { update underway.task set state = $2, updated_at = now(), - completed_at = now() + completed_at = now(), + lease_token = null, + lease_expires_at = null where id = $1 + and task_queue_name = $3 + and state = $4 + and lease_token = $5 "#, self.id as TaskId, - TaskState::Failed as TaskState + TaskState::Failed as TaskState, + self.queue_name, + TaskState::InProgress as TaskState, + lease_token ) .execute(&mut *conn) .await?; if result.rows_affected() == 0 { - return Err(Error::TaskNotFound(self.id)); + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping failed update for non-current task attempt" + ); + return Ok(()); } Ok(()) @@ -1482,20 +1605,44 @@ impl InProgressTask { where E: PgExecutor<'a>, { - sqlx::query!( + let Some(lease_token) = self.lease_token else { + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping heartbeat without lease token" + ); + return Ok(()); + }; + + let result = sqlx::query!( r#" update underway.task set updated_at = now(), - last_heartbeat_at = now() + last_heartbeat_at = now(), + lease_expires_at = now() + heartbeat where id = $1 and task_queue_name = $2 + and state = $3 + and ( + lease_token = $4 + ) "#, self.id as TaskId, - self.queue_name + self.queue_name, + TaskState::InProgress as TaskState, + lease_token ) .execute(executor) .await?; + if result.rows_affected() == 0 { + tracing::warn!( + task.id = %self.id.as_hyphenated(), + attempt_number = self.attempt_number, + "Skipping heartbeat for non-current task attempt" + ); + } + Ok(()) } @@ -3309,7 +3456,7 @@ mod tests { sqlx::query!( r#" update underway.task - set last_heartbeat_at = now() - interval '30 seconds' + set lease_expires_at = now() - interval '30 seconds' where id = $1 "#, task_id as TaskId @@ -3339,4 +3486,68 @@ mod tests { Ok(()) } + + #[sqlx::test] + async fn stale_attempt_cannot_finalize_task(pool: PgPool) -> sqlx::Result<(), Error> { + let queue = Queue::builder() + .name("stale_attempt_cannot_finalize_task") + .pool(pool.clone()) + .build() + .await?; + + let task_id = queue.enqueue(&pool, &TestTask, &json!("{}")).await?; + + let mut conn = pool.acquire().await?; + let in_progress_task = queue.dequeue().await?.expect("A task should be dequeued"); + + sqlx::query!( + r#" + update underway.task + set last_heartbeat_at = now() - interval '30 seconds' + where id = $1 + "#, + task_id as TaskId + ) + .execute(&pool) + .await?; + + let _new_in_progress_task = queue + .dequeue() + .await? + .expect("A stale task should be dequeued"); + + // Attempt to finalize the stale attempt; this should have no effect. + in_progress_task.mark_succeeded(&mut conn).await?; + + let task_row = sqlx::query!( + r#" + select state as "state: TaskState" + from underway.task + where id = $1 + "#, + task_id as TaskId + ) + .fetch_one(&pool) + .await?; + + assert_eq!(task_row.state, TaskState::InProgress); + + let attempt_rows = sqlx::query!( + r#" + select attempt_number, state as "state: TaskState" + from underway.task_attempt + where task_id = $1 + order by attempt_number + "#, + task_id as TaskId + ) + .fetch_all(&pool) + .await?; + + assert_eq!(attempt_rows.len(), 2); + assert_eq!(attempt_rows[0].state, TaskState::Failed); + assert_eq!(attempt_rows[1].state, TaskState::InProgress); + + Ok(()) + } }