diff --git a/.changeset/real-rats-drop.md b/.changeset/real-rats-drop.md new file mode 100644 index 0000000000..953794afd4 --- /dev/null +++ b/.changeset/real-rats-drop.md @@ -0,0 +1,5 @@ +--- +"@trigger.dev/sdk": patch +--- + +Add onCancel lifecycle hook diff --git a/apps/webapp/app/components/runs/v3/RunIcon.tsx b/apps/webapp/app/components/runs/v3/RunIcon.tsx index c03b32731d..8a1924b3ef 100644 --- a/apps/webapp/app/components/runs/v3/RunIcon.tsx +++ b/apps/webapp/app/components/runs/v3/RunIcon.tsx @@ -97,6 +97,7 @@ export function RunIcon({ name, className, spanName }: TaskIconProps) { case "task-hook-onResume": case "task-hook-onComplete": case "task-hook-cleanup": + case "task-hook-onCancel": return ; case "task-hook-onFailure": case "task-hook-catchError": diff --git a/apps/webapp/app/v3/services/cancelTaskRun.server.ts b/apps/webapp/app/v3/services/cancelTaskRun.server.ts index 811fd54d64..d664d754e8 100644 --- a/apps/webapp/app/v3/services/cancelTaskRun.server.ts +++ b/apps/webapp/app/v3/services/cancelTaskRun.server.ts @@ -47,29 +47,6 @@ export class CancelTaskRunService extends BaseService { tx: this._prisma, }); - const inProgressEvents = await eventRepository.queryIncompleteEvents( - getTaskEventStoreTableForRun(taskRun), - { - runId: taskRun.friendlyId, - }, - taskRun.createdAt, - taskRun.completedAt ?? undefined - ); - - logger.debug("Cancelling in-progress events", { - inProgressEvents: inProgressEvents.map((event) => event.id), - }); - - await Promise.all( - inProgressEvents.map((event) => { - return eventRepository.cancelEvent( - event, - options?.cancelledAt ?? new Date(), - options?.reason ?? "Run cancelled" - ); - }) - ); - return { id: result.run.id, }; diff --git a/packages/cli-v3/src/entryPoints/dev-run-worker.ts b/packages/cli-v3/src/entryPoints/dev-run-worker.ts index fed66e6dc6..3a5ca0815c 100644 --- a/packages/cli-v3/src/entryPoints/dev-run-worker.ts +++ b/packages/cli-v3/src/entryPoints/dev-run-worker.ts @@ -23,6 +23,7 @@ import { TaskRunExecution, timeout, TriggerConfig, + UsageMeasurement, waitUntil, WorkerManifest, WorkerToExecutorMessageCatalog, @@ -232,7 +233,10 @@ async function bootstrap() { let _execution: TaskRunExecution | undefined; let _isRunning = false; +let _isCancelled = false; let _tracingSDK: TracingSDK | undefined; +let _executionMeasurement: UsageMeasurement | undefined; +const cancelController = new AbortController(); const zodIpc = new ZodIpcConnection({ listenSchema: WorkerToExecutorMessageCatalog, @@ -403,18 +407,17 @@ const zodIpc = new ZodIpcConnection({ getNumberEnvVar("TRIGGER_RUN_METADATA_FLUSH_INTERVAL", 1000) ); - const measurement = usage.start(); + _executionMeasurement = usage.start(); - // This lives outside of the executor because this will eventually be moved to the controller level - const signal = execution.run.maxDuration - ? timeout.abortAfterTimeout(execution.run.maxDuration) - : undefined; + const timeoutController = timeout.abortAfterTimeout(execution.run.maxDuration); + + const signal = AbortSignal.any([cancelController.signal, timeoutController.signal]); const { result } = await executor.execute(execution, metadata, traceContext, signal); - const usageSample = usage.stop(measurement); + if (_isRunning && !_isCancelled) { + const usageSample = usage.stop(_executionMeasurement); - if (_isRunning) { return sender.send("TASK_RUN_COMPLETED", { execution, result: { @@ -458,7 +461,16 @@ const zodIpc = new ZodIpcConnection({ WAIT_COMPLETED_NOTIFICATION: async () => { await managedWorkerRuntime.completeWaitpoints([]); }, - FLUSH: async ({ timeoutInMs }, sender) => { + CANCEL: async ({ timeoutInMs }) => { + _isCancelled = true; + cancelController.abort("run cancelled"); + await callCancelHooks(timeoutInMs); + if (_executionMeasurement) { + usage.stop(_executionMeasurement); + } + await flushAll(timeoutInMs); + }, + FLUSH: async ({ timeoutInMs }) => { await flushAll(timeoutInMs); }, WAITPOINT_CREATED: async ({ wait, waitpoint }) => { @@ -470,6 +482,18 @@ const zodIpc = new ZodIpcConnection({ }, }); +async function callCancelHooks(timeoutInMs: number = 10_000) { + const now = performance.now(); + + try { + await Promise.race([lifecycleHooks.callOnCancelHookListeners(), setTimeout(timeoutInMs)]); + } finally { + const duration = performance.now() - now; + + log(`Called cancel hooks in ${duration}ms`); + } +} + async function flushAll(timeoutInMs: number = 10_000) { const now = performance.now(); diff --git a/packages/cli-v3/src/entryPoints/managed-run-worker.ts b/packages/cli-v3/src/entryPoints/managed-run-worker.ts index 19c8718cce..894ed70aa1 100644 --- a/packages/cli-v3/src/entryPoints/managed-run-worker.ts +++ b/packages/cli-v3/src/entryPoints/managed-run-worker.ts @@ -22,6 +22,7 @@ import { TaskRunExecution, timeout, TriggerConfig, + UsageMeasurement, waitUntil, WorkerManifest, WorkerToExecutorMessageCatalog, @@ -229,7 +230,10 @@ async function bootstrap() { let _execution: TaskRunExecution | undefined; let _isRunning = false; +let _isCancelled = false; let _tracingSDK: TracingSDK | undefined; +let _executionMeasurement: UsageMeasurement | undefined; +const cancelController = new AbortController(); const zodIpc = new ZodIpcConnection({ listenSchema: WorkerToExecutorMessageCatalog, @@ -398,18 +402,17 @@ const zodIpc = new ZodIpcConnection({ getNumberEnvVar("TRIGGER_RUN_METADATA_FLUSH_INTERVAL", 1000) ); - const measurement = usage.start(); + _executionMeasurement = usage.start(); - // This lives outside of the executor because this will eventually be moved to the controller level - const signal = execution.run.maxDuration - ? timeout.abortAfterTimeout(execution.run.maxDuration) - : undefined; + const timeoutController = timeout.abortAfterTimeout(execution.run.maxDuration); + + const signal = AbortSignal.any([cancelController.signal, timeoutController.signal]); const { result } = await executor.execute(execution, metadata, traceContext, signal); - const usageSample = usage.stop(measurement); + if (_isRunning && !_isCancelled) { + const usageSample = usage.stop(_executionMeasurement); - if (_isRunning) { return sender.send("TASK_RUN_COMPLETED", { execution, result: { @@ -454,6 +457,15 @@ const zodIpc = new ZodIpcConnection({ FLUSH: async ({ timeoutInMs }, sender) => { await flushAll(timeoutInMs); }, + CANCEL: async ({ timeoutInMs }, sender) => { + _isCancelled = true; + cancelController.abort("run cancelled"); + await callCancelHooks(timeoutInMs); + if (_executionMeasurement) { + usage.stop(_executionMeasurement); + } + await flushAll(timeoutInMs); + }, WAITPOINT_CREATED: async ({ wait, waitpoint }) => { managedWorkerRuntime.associateWaitWithWaitpoint(wait.id, waitpoint.id); }, @@ -463,6 +475,18 @@ const zodIpc = new ZodIpcConnection({ }, }); +async function callCancelHooks(timeoutInMs: number = 10_000) { + const now = performance.now(); + + try { + await Promise.race([lifecycleHooks.callOnCancelHookListeners(), setTimeout(timeoutInMs)]); + } finally { + const duration = performance.now() - now; + + console.log(`Called cancel hooks in ${duration}ms`); + } +} + async function flushAll(timeoutInMs: number = 10_000) { const now = performance.now(); diff --git a/packages/cli-v3/src/executions/taskRunProcess.ts b/packages/cli-v3/src/executions/taskRunProcess.ts index ae95ecb109..f6e179195b 100644 --- a/packages/cli-v3/src/executions/taskRunProcess.ts +++ b/packages/cli-v3/src/executions/taskRunProcess.ts @@ -109,9 +109,9 @@ export class TaskRunProcess { this._isBeingCancelled = true; try { - await this.#flush(); + await this.#cancel(); } catch (err) { - console.error("Error flushing task run process", { err }); + console.error("Error cancelling task run process", { err }); } await this.kill(); @@ -120,6 +120,10 @@ export class TaskRunProcess { async cleanup(kill = true) { this._isPreparedForNextRun = false; + if (this._isBeingCancelled) { + return; + } + try { await this.#flush(); } catch (err) { @@ -224,10 +228,17 @@ export class TaskRunProcess { await this._ipc?.sendWithAck("FLUSH", { timeoutInMs }, timeoutInMs + 1_000); } + async #cancel(timeoutInMs: number = 30_000) { + logger.debug("sending cancel message to task run process", { pid: this.pid, timeoutInMs }); + + await this._ipc?.sendWithAck("CANCEL", { timeoutInMs }, timeoutInMs + 1_000); + } + async execute( params: TaskRunProcessExecuteParams, isWarmStart?: boolean ): Promise { + this._isBeingCancelled = false; this._isPreparedForNextRun = false; this._isPreparedForNextAttempt = false; diff --git a/packages/core/src/utils.ts b/packages/core/src/utils.ts index 4a214a4536..2930f61415 100644 --- a/packages/core/src/utils.ts +++ b/packages/core/src/utils.ts @@ -16,3 +16,25 @@ export async function tryCatch( return [error as E, null]; } } + +export type Deferred = { + promise: Promise; + resolve: (value: T) => void; + reject: (reason?: any) => void; +}; + +export function promiseWithResolvers(): Deferred { + let resolve!: (value: T) => void; + let reject!: (reason?: any) => void; + + const promise = new Promise((_resolve, _reject) => { + resolve = _resolve; + reject = _reject; + }); + + return { + promise, + resolve, + reject, + }; +} diff --git a/packages/core/src/v3/lifecycle-hooks-api.ts b/packages/core/src/v3/lifecycle-hooks-api.ts index ec9e87c998..3c74f719cb 100644 --- a/packages/core/src/v3/lifecycle-hooks-api.ts +++ b/packages/core/src/v3/lifecycle-hooks-api.ts @@ -32,4 +32,7 @@ export type { AnyOnCleanupHookFunction, TaskCleanupHookParams, TaskWait, + TaskCancelHookParams, + OnCancelHookFunction, + AnyOnCancelHookFunction, } from "./lifecycleHooks/types.js"; diff --git a/packages/core/src/v3/lifecycleHooks/index.ts b/packages/core/src/v3/lifecycleHooks/index.ts index 843ae92ce8..99ed47ae60 100644 --- a/packages/core/src/v3/lifecycleHooks/index.ts +++ b/packages/core/src/v3/lifecycleHooks/index.ts @@ -13,6 +13,7 @@ import { AnyOnStartHookFunction, AnyOnSuccessHookFunction, AnyOnWaitHookFunction, + AnyOnCancelHookFunction, RegisteredHookFunction, RegisterHookFunctionParams, TaskWait, @@ -260,6 +261,33 @@ export class LifecycleHooksAPI { this.#getManager().registerOnResumeHookListener(listener); } + public registerGlobalCancelHook(hook: RegisterHookFunctionParams): void { + this.#getManager().registerGlobalCancelHook(hook); + } + + public registerTaskCancelHook( + taskId: string, + hook: RegisterHookFunctionParams + ): void { + this.#getManager().registerTaskCancelHook(taskId, hook); + } + + public getTaskCancelHook(taskId: string): AnyOnCancelHookFunction | undefined { + return this.#getManager().getTaskCancelHook(taskId); + } + + public getGlobalCancelHooks(): RegisteredHookFunction[] { + return this.#getManager().getGlobalCancelHooks(); + } + + public callOnCancelHookListeners(): Promise { + return this.#getManager().callOnCancelHookListeners(); + } + + public registerOnCancelHookListener(listener: () => Promise): void { + this.#getManager().registerOnCancelHookListener(listener); + } + #getManager(): LifecycleHooksManager { return getGlobal(API_NAME) ?? NOOP_LIFECYCLE_HOOKS_MANAGER; } diff --git a/packages/core/src/v3/lifecycleHooks/manager.ts b/packages/core/src/v3/lifecycleHooks/manager.ts index 29f4968362..f6cceb8d55 100644 --- a/packages/core/src/v3/lifecycleHooks/manager.ts +++ b/packages/core/src/v3/lifecycleHooks/manager.ts @@ -13,6 +13,7 @@ import { AnyOnMiddlewareHookFunction, AnyOnCleanupHookFunction, TaskWait, + AnyOnCancelHookFunction, } from "./types.js"; export class StandardLifecycleHooksManager implements LifecycleHooksManager { @@ -37,9 +38,6 @@ export class StandardLifecycleHooksManager implements LifecycleHooksManager { private taskCompleteHooks: Map> = new Map(); - private globalWaitHooks: Map> = new Map(); - private taskWaitHooks: Map> = new Map(); - private globalResumeHooks: Map> = new Map(); private taskResumeHooks: Map> = new Map(); @@ -59,9 +57,25 @@ export class StandardLifecycleHooksManager implements LifecycleHooksManager { private taskCleanupHooks: Map> = new Map(); + private globalWaitHooks: Map> = new Map(); + private taskWaitHooks: Map> = new Map(); private onWaitHookListeners: ((wait: TaskWait) => Promise)[] = []; + private onResumeHookListeners: ((wait: TaskWait) => Promise)[] = []; + private globalCancelHooks: Map> = + new Map(); + private taskCancelHooks: Map> = new Map(); + private onCancelHookListeners: (() => Promise)[] = []; + + registerOnCancelHookListener(listener: () => Promise): void { + this.onCancelHookListeners.push(listener); + } + + async callOnCancelHookListeners(): Promise { + await Promise.allSettled(this.onCancelHookListeners.map((listener) => listener())); + } + registerOnWaitHookListener(listener: (wait: TaskWait) => Promise): void { this.onWaitHookListeners.push(listener); } @@ -394,9 +408,65 @@ export class StandardLifecycleHooksManager implements LifecycleHooksManager { getGlobalCleanupHooks(): RegisteredHookFunction[] { return Array.from(this.globalCleanupHooks.values()); } + + registerGlobalCancelHook(hook: RegisterHookFunctionParams): void { + const id = generateHookId(hook); + + this.globalCancelHooks.set(id, { + id, + name: hook.id, + fn: hook.fn, + }); + } + + registerTaskCancelHook( + taskId: string, + hook: RegisterHookFunctionParams + ): void { + const id = generateHookId(hook); + + this.taskCancelHooks.set(taskId, { + id, + name: hook.id, + fn: hook.fn, + }); + } + + getGlobalCancelHooks(): RegisteredHookFunction[] { + return Array.from(this.globalCancelHooks.values()); + } + + getTaskCancelHook(taskId: string): AnyOnCancelHookFunction | undefined { + return this.taskCancelHooks.get(taskId)?.fn; + } } export class NoopLifecycleHooksManager implements LifecycleHooksManager { + registerOnCancelHookListener(listener: () => Promise): void { + // Noop + } + + async callOnCancelHookListeners(): Promise { + // Noop + } + + registerGlobalCancelHook(hook: RegisterHookFunctionParams): void {} + + registerTaskCancelHook( + taskId: string, + hook: RegisterHookFunctionParams + ): void { + // Noop + } + + getTaskCancelHook(taskId: string): AnyOnCancelHookFunction | undefined { + return undefined; + } + + getGlobalCancelHooks(): RegisteredHookFunction[] { + return []; + } + registerOnWaitHookListener(listener: (wait: TaskWait) => Promise): void { // Noop } diff --git a/packages/core/src/v3/lifecycleHooks/types.ts b/packages/core/src/v3/lifecycleHooks/types.ts index 5d307c225b..9501216546 100644 --- a/packages/core/src/v3/lifecycleHooks/types.ts +++ b/packages/core/src/v3/lifecycleHooks/types.ts @@ -7,7 +7,7 @@ export type TaskInitHookParams = { ctx: TaskRunContext; payload: TPayload; task: string; - signal?: AbortSignal; + signal: AbortSignal; }; export type OnInitHookFunction = ( @@ -23,7 +23,7 @@ export type TaskStartHookParams< ctx: TaskRunContext; payload: TPayload; task: string; - signal?: AbortSignal; + signal: AbortSignal; init?: TInitOutput; }; @@ -60,7 +60,7 @@ export type TaskWaitHookParams< ctx: TaskRunContext; payload: TPayload; task: string; - signal?: AbortSignal; + signal: AbortSignal; init?: TInitOutput; }; @@ -78,7 +78,7 @@ export type TaskResumeHookParams< wait: TaskWait; payload: TPayload; task: string; - signal?: AbortSignal; + signal: AbortSignal; init?: TInitOutput; }; @@ -96,7 +96,7 @@ export type TaskFailureHookParams< payload: TPayload; task: string; error: unknown; - signal?: AbortSignal; + signal: AbortSignal; init?: TInitOutput; }; @@ -115,7 +115,7 @@ export type TaskSuccessHookParams< payload: TPayload; task: string; output: TOutput; - signal?: AbortSignal; + signal: AbortSignal; init?: TInitOutput; }; @@ -152,7 +152,7 @@ export type TaskCompleteHookParams< payload: TPayload; task: string; result: TaskCompleteResult; - signal?: AbortSignal; + signal: AbortSignal; init?: TInitOutput; }; @@ -188,7 +188,7 @@ export type TaskCatchErrorHookParams< retry?: RetryOptions; retryAt?: Date; retryDelayInMs?: number; - signal?: AbortSignal; + signal: AbortSignal; init?: TInitOutput; }; @@ -203,7 +203,7 @@ export type TaskMiddlewareHookParams = { ctx: TaskRunContext; payload: TPayload; task: string; - signal?: AbortSignal; + signal: AbortSignal; next: () => Promise; }; @@ -220,7 +220,7 @@ export type TaskCleanupHookParams< ctx: TaskRunContext; payload: TPayload; task: string; - signal?: AbortSignal; + signal: AbortSignal; init?: TInitOutput; }; @@ -230,6 +230,29 @@ export type OnCleanupHookFunction; +export type TaskCancelHookParams< + TPayload = unknown, + TRunOutput = any, + TInitOutput extends TaskInitOutput = TaskInitOutput, +> = { + ctx: TaskRunContext; + payload: TPayload; + task: string; + runPromise: Promise; + init?: TInitOutput; + signal: AbortSignal; +}; + +export type OnCancelHookFunction< + TPayload, + TRunOutput = any, + TInitOutput extends TaskInitOutput = TaskInitOutput, +> = ( + params: TaskCancelHookParams +) => undefined | void | Promise; + +export type AnyOnCancelHookFunction = OnCancelHookFunction; + export interface LifecycleHooksManager { registerGlobalInitHook(hook: RegisterHookFunctionParams): void; registerTaskInitHook( @@ -307,4 +330,15 @@ export interface LifecycleHooksManager { callOnResumeHookListeners(wait: TaskWait): Promise; registerOnResumeHookListener(listener: (wait: TaskWait) => Promise): void; + + registerGlobalCancelHook(hook: RegisterHookFunctionParams): void; + registerTaskCancelHook( + taskId: string, + hook: RegisterHookFunctionParams + ): void; + getGlobalCancelHooks(): RegisteredHookFunction[]; + getTaskCancelHook(taskId: string): AnyOnCancelHookFunction | undefined; + + registerOnCancelHookListener(listener: () => Promise): void; + callOnCancelHookListeners(): Promise; } diff --git a/packages/core/src/v3/schemas/messages.ts b/packages/core/src/v3/schemas/messages.ts index edbdfac3de..2bbfed7aa2 100644 --- a/packages/core/src/v3/schemas/messages.ts +++ b/packages/core/src/v3/schemas/messages.ts @@ -243,6 +243,12 @@ export const WorkerToExecutorMessageCatalog = { }), callback: z.void(), }, + CANCEL: { + message: z.object({ + timeoutInMs: z.number(), + }), + callback: z.void(), + }, WAITPOINT_CREATED: { message: z.object({ version: z.literal("v1").default("v1"), diff --git a/packages/core/src/v3/timeout/api.ts b/packages/core/src/v3/timeout/api.ts index d8370118e2..ed6bd506ec 100644 --- a/packages/core/src/v3/timeout/api.ts +++ b/packages/core/src/v3/timeout/api.ts @@ -4,8 +4,8 @@ import { TimeoutManager } from "./types.js"; const API_NAME = "timeout"; class NoopTimeoutManager implements TimeoutManager { - abortAfterTimeout(timeoutInSeconds: number): AbortSignal { - return new AbortController().signal; + abortAfterTimeout(timeoutInSeconds?: number): AbortController { + return new AbortController(); } } @@ -25,11 +25,11 @@ export class TimeoutAPI implements TimeoutManager { } public get signal(): AbortSignal | undefined { - return this.#getManagerManager().signal; + return this.#getManager().signal; } - public abortAfterTimeout(timeoutInSeconds: number): AbortSignal { - return this.#getManagerManager().abortAfterTimeout(timeoutInSeconds); + public abortAfterTimeout(timeoutInSeconds?: number): AbortController { + return this.#getManager().abortAfterTimeout(timeoutInSeconds); } public setGlobalManager(manager: TimeoutManager): boolean { @@ -40,7 +40,7 @@ export class TimeoutAPI implements TimeoutManager { unregisterGlobal(API_NAME); } - #getManagerManager(): TimeoutManager { + #getManager(): TimeoutManager { return getGlobal(API_NAME) ?? NOOP_TIMEOUT_MANAGER; } } diff --git a/packages/core/src/v3/timeout/types.ts b/packages/core/src/v3/timeout/types.ts index 9d2be0ef52..7f263bb4e2 100644 --- a/packages/core/src/v3/timeout/types.ts +++ b/packages/core/src/v3/timeout/types.ts @@ -1,5 +1,5 @@ export interface TimeoutManager { - abortAfterTimeout: (timeoutInSeconds: number) => AbortSignal; + abortAfterTimeout: (timeoutInSeconds?: number) => AbortController; signal?: AbortSignal; } diff --git a/packages/core/src/v3/timeout/usageTimeoutManager.ts b/packages/core/src/v3/timeout/usageTimeoutManager.ts index 030e30602d..b90546832b 100644 --- a/packages/core/src/v3/timeout/usageTimeoutManager.ts +++ b/packages/core/src/v3/timeout/usageTimeoutManager.ts @@ -4,6 +4,7 @@ import { TaskRunExceededMaxDuration, TimeoutManager } from "./types.js"; export class UsageTimeoutManager implements TimeoutManager { private _abortController: AbortController; private _abortSignal: AbortSignal | undefined; + private _intervalId: NodeJS.Timeout | undefined; constructor(private readonly usageManager: UsageManager) { this._abortController = new AbortController(); @@ -13,15 +14,23 @@ export class UsageTimeoutManager implements TimeoutManager { return this._abortSignal; } - abortAfterTimeout(timeoutInSeconds: number): AbortSignal { + abortAfterTimeout(timeoutInSeconds?: number): AbortController { this._abortSignal = this._abortController.signal; + if (!timeoutInSeconds) { + return this._abortController; + } + + if (this._intervalId) { + clearInterval(this._intervalId); + } + // Now we need to start an interval that will measure usage and abort the signal if the usage is too high - const intervalId = setInterval(() => { + this._intervalId = setInterval(() => { const sample = this.usageManager.sample(); if (sample) { if (sample.cpuTime > timeoutInSeconds * 1000) { - clearInterval(intervalId); + clearInterval(this._intervalId); this._abortController.abort( new TaskRunExceededMaxDuration(timeoutInSeconds, sample.cpuTime / 1000) @@ -30,6 +39,6 @@ export class UsageTimeoutManager implements TimeoutManager { } }, 1000); - return this._abortSignal; + return this._abortController; } } diff --git a/packages/core/src/v3/types/tasks.ts b/packages/core/src/v3/types/tasks.ts index c307bc25f2..e64e1684ac 100644 --- a/packages/core/src/v3/types/tasks.ts +++ b/packages/core/src/v3/types/tasks.ts @@ -12,6 +12,7 @@ import { OnStartHookFunction, OnSuccessHookFunction, OnWaitHookFunction, + OnCancelHookFunction, } from "../lifecycleHooks/types.js"; import { RunTags } from "../schemas/api.js"; import { @@ -88,28 +89,36 @@ export type RunFnParams = Prettify<{ ctx: Context; /** If you use the `init` function, this will be whatever you returned. */ init?: TInitOutput; - /** Abort signal that is aborted when a task run exceeds it's maxDuration. Can be used to automatically cancel downstream requests */ - signal?: AbortSignal; + /** Abort signal that is aborted when a task run exceeds it's maxDuration or if the task run is cancelled. Can be used to automatically cancel downstream requests */ + signal: AbortSignal; }>; export type MiddlewareFnParams = Prettify<{ ctx: Context; next: () => Promise; - /** Abort signal that is aborted when a task run exceeds it's maxDuration. Can be used to automatically cancel downstream requests */ - signal?: AbortSignal; + /** Abort signal that is aborted when a task run exceeds it's maxDuration or if the task run is cancelled. Can be used to automatically cancel downstream requests */ + signal: AbortSignal; }>; export type InitFnParams = Prettify<{ ctx: Context; - /** Abort signal that is aborted when a task run exceeds it's maxDuration. Can be used to automatically cancel downstream requests */ - signal?: AbortSignal; + /** Abort signal that is aborted when a task run exceeds it's maxDuration or if the task run is cancelled. Can be used to automatically cancel downstream requests */ + signal: AbortSignal; }>; export type StartFnParams = Prettify<{ ctx: Context; init?: InitOutput; - /** Abort signal that is aborted when a task run exceeds it's maxDuration. Can be used to automatically cancel downstream requests */ - signal?: AbortSignal; + /** Abort signal that is aborted when a task run exceeds it's maxDuration or if the task run is cancelled. Can be used to automatically cancel downstream requests */ + signal: AbortSignal; +}>; + +export type CancelFnParams = Prettify<{ + ctx: Context; + /** Abort signal that is aborted when a task run exceeds it's maxDuration or if the task run is cancelled. Can be used to automatically cancel downstream requests */ + signal: AbortSignal; + runPromise: Promise; + init?: InitOutput; }>; export type Context = TaskRunContext; @@ -296,6 +305,7 @@ type CommonTaskOptions< onResume?: OnResumeHookFunction; onWait?: OnWaitHookFunction; onComplete?: OnCompleteHookFunction; + onCancel?: OnCancelHookFunction; /** * middleware allows you to run code "around" the run function. This can be useful for logging, metrics, or other cross-cutting concerns. diff --git a/packages/core/src/v3/usage-api.ts b/packages/core/src/v3/usage-api.ts index 1bafc04f09..be0f0d1c28 100644 --- a/packages/core/src/v3/usage-api.ts +++ b/packages/core/src/v3/usage-api.ts @@ -3,3 +3,5 @@ import { UsageAPI } from "./usage/api.js"; /** Entrypoint for usage API */ export const usage = UsageAPI.getInstance(); + +export type { UsageMeasurement, UsageSample } from "./usage/types.js"; diff --git a/packages/core/src/v3/usage/devUsageManager.ts b/packages/core/src/v3/usage/devUsageManager.ts index fea5d2fa97..d8baa935f7 100644 --- a/packages/core/src/v3/usage/devUsageManager.ts +++ b/packages/core/src/v3/usage/devUsageManager.ts @@ -74,7 +74,9 @@ export class DevUsageManager implements UsageManager { const sample = measurement.sample(); - this._currentMeasurements.delete(measurement.id); + if (this._currentMeasurements.has(measurement.id)) { + this._currentMeasurements.delete(measurement.id); + } return sample; } diff --git a/packages/core/src/v3/workers/taskExecutor.ts b/packages/core/src/v3/workers/taskExecutor.ts index 8a13ab3bc0..74acc8dc82 100644 --- a/packages/core/src/v3/workers/taskExecutor.ts +++ b/packages/core/src/v3/workers/taskExecutor.ts @@ -1,4 +1,4 @@ -import { SpanKind } from "@opentelemetry/api"; +import { Context, context, SpanKind, trace } from "@opentelemetry/api"; import { VERSION } from "../../version.js"; import { ApiError, RateLimitError } from "../apiClient/errors.js"; import { ConsoleInterceptor } from "../consoleInterceptor.js"; @@ -51,6 +51,7 @@ import { stringifyIO, } from "../utils/ioSerialization.js"; import { calculateNextRetryDelay } from "../utils/retries.js"; +import { promiseWithResolvers } from "../../utils.js"; export type TaskExecutorOptions = { tracingSDK: TracingSDK; @@ -90,7 +91,7 @@ export class TaskExecutor { execution: TaskRunExecution, worker: ServerBackgroundWorker, traceContext: Record, - signal?: AbortSignal, + signal: AbortSignal, isWarmStart?: boolean ): Promise<{ result: TaskRunExecutionResult }> { const ctx = TaskRunContext.parse(execution); @@ -120,6 +121,8 @@ export class TaskExecutor { const result = await this._tracer.startActiveSpan( attemptMessage, async (span) => { + const attemptContext = context.active(); + return await this._consoleInterceptor.intercept(console, async () => { let parsedPayload: any; let initOutput: any; @@ -150,6 +153,26 @@ export class TaskExecutor { await this.#callOnResumeFunctions(wait, parsedPayload, ctx, initOutput, signal); }); + const { + promise: runPromise, + resolve: runResolve, + reject: runReject, + } = promiseWithResolvers(); + + // Make sure the run promise does not cause unhandled promise rejections + runPromise.catch(() => {}); + + lifecycleHooks.registerOnCancelHookListener(async () => { + await this.#callOnCancelFunctions( + runPromise, + parsedPayload, + ctx, + initOutput, + signal, + attemptContext + ); + }); + const executeTask = async (payload: any) => { const [runError, output] = await tryCatch( (async () => { @@ -172,6 +195,8 @@ export class TaskExecutor { ); if (runError) { + runReject(runError); + const [handleErrorError, handleErrorResult] = await tryCatch( this.#handleError(execution, runError, payload, ctx, initOutput, signal) ); @@ -220,6 +245,8 @@ export class TaskExecutor { } satisfies TaskRunExecutionResult; } + runResolve(output); + const [outputError, stringifiedOutput] = await tryCatch(stringifyIO(output)); if (outputError) { @@ -336,7 +363,7 @@ export class TaskExecutor { execution: TaskRunExecution, hooks: RegisteredHookFunction[], executeTask: (payload: unknown) => Promise, - signal?: AbortSignal + signal: AbortSignal ) { let output: any; let executeError: unknown; @@ -384,7 +411,7 @@ export class TaskExecutor { return output; } - async #callRun(payload: unknown, ctx: TaskRunContext, init: unknown, signal?: AbortSignal) { + async #callRun(payload: unknown, ctx: TaskRunContext, init: unknown, signal: AbortSignal) { const runFn = this.task.fns.run; if (!runFn) { @@ -392,30 +419,29 @@ export class TaskExecutor { } // Create a promise that rejects when the signal aborts - const abortPromise = signal - ? new Promise((_, reject) => { - signal.addEventListener("abort", () => { - const maxDuration = ctx.run.maxDuration; - reject( - new InternalError({ - code: TaskRunErrorCodes.MAX_DURATION_EXCEEDED, - message: `Run exceeded maximum compute time (maxDuration) of ${maxDuration} seconds`, - }) - ); - }); - }) - : undefined; + const abortPromise = new Promise((_, reject) => { + signal.addEventListener("abort", () => { + if (typeof signal.reason === "string" && signal.reason.includes("cancel")) { + console.log("abortPromise: cancel"); + return; + } + + const maxDuration = ctx.run.maxDuration; + reject( + new InternalError({ + code: TaskRunErrorCodes.MAX_DURATION_EXCEEDED, + message: `Run exceeded maximum compute time (maxDuration) of ${maxDuration} seconds`, + }) + ); + }); + }); return runTimelineMetrics.measureMetric("trigger.dev/execution", "run", async () => { return await this._tracer.startActiveSpan( "run()", async (span) => { - if (abortPromise) { - // Race between the run function and the abort promise - return await Promise.race([runFn(payload, { ctx, init, signal }), abortPromise]); - } - - return await runFn(payload, { ctx, init, signal }); + // Race between the run function and the abort promise + return await Promise.race([runFn(payload, { ctx, init, signal }), abortPromise]); }, { attributes: { [SemanticInternalAttributes.STYLE_ICON]: "task-fn-run" }, @@ -429,7 +455,7 @@ export class TaskExecutor { payload: unknown, ctx: TaskRunContext, initOutput: TaskInitOutput, - signal?: AbortSignal + signal: AbortSignal ) { const globalWaitHooks = lifecycleHooks.getGlobalWaitHooks(); const taskWaitHook = lifecycleHooks.getTaskWaitHook(this.task.id); @@ -496,12 +522,94 @@ export class TaskExecutor { ); } + async #callOnCancelFunctions( + runPromise: Promise, + payload: unknown, + ctx: TaskRunContext, + initOutput: TaskInitOutput, + signal: AbortSignal, + attemptContext: Context + ) { + const globalCancelHooks = lifecycleHooks.getGlobalCancelHooks(); + const taskCancelHook = lifecycleHooks.getTaskCancelHook(this.task.id); + + if (globalCancelHooks.length === 0 && !taskCancelHook) { + return; + } + + const result = await runTimelineMetrics.measureMetric( + "trigger.dev/execution", + "onCancel", + async () => { + for (const hook of globalCancelHooks) { + const [hookError] = await tryCatch( + this._tracer.startActiveSpan( + "onCancel()", + async (span) => { + await hook.fn({ + payload, + ctx, + signal, + task: this.task.id, + init: initOutput, + runPromise, + }); + }, + { + attributes: { + [SemanticInternalAttributes.STYLE_ICON]: "task-hook-onCancel", + [SemanticInternalAttributes.COLLAPSED]: true, + ...this.#lifecycleHookAccessoryAttributes(hook.name), + }, + }, + attemptContext + ) + ); + + if (hookError) { + throw hookError; + } + } + + if (taskCancelHook) { + const [hookError] = await tryCatch( + this._tracer.startActiveSpan( + "onCancel()", + async (span) => { + await taskCancelHook({ + payload, + ctx, + signal, + task: this.task.id, + init: initOutput, + runPromise, + }); + }, + { + attributes: { + [SemanticInternalAttributes.STYLE_ICON]: "task-hook-onCancel", + [SemanticInternalAttributes.COLLAPSED]: true, + ...this.#lifecycleHookAccessoryAttributes("task"), + }, + }, + attemptContext + ) + ); + + if (hookError) { + throw hookError; + } + } + } + ); + } + async #callOnResumeFunctions( wait: TaskWait, payload: unknown, ctx: TaskRunContext, initOutput: TaskInitOutput, - signal?: AbortSignal + signal: AbortSignal ) { const globalResumeHooks = lifecycleHooks.getGlobalResumeHooks(); const taskResumeHook = lifecycleHooks.getTaskResumeHook(this.task.id); @@ -568,7 +676,7 @@ export class TaskExecutor { ); } - async #callInitFunctions(payload: unknown, ctx: TaskRunContext, signal?: AbortSignal) { + async #callInitFunctions(payload: unknown, ctx: TaskRunContext, signal: AbortSignal) { const globalInitHooks = lifecycleHooks.getGlobalInitHooks(); const taskInitHook = lifecycleHooks.getTaskInitHook(this.task.id); @@ -671,7 +779,7 @@ export class TaskExecutor { output: any, ctx: TaskRunContext, initOutput: any, - signal?: AbortSignal + signal: AbortSignal ) { const globalSuccessHooks = lifecycleHooks.getGlobalSuccessHooks(); const taskSuccessHook = lifecycleHooks.getTaskSuccessHook(this.task.id); @@ -746,7 +854,7 @@ export class TaskExecutor { error: unknown, ctx: TaskRunContext, initOutput: any, - signal?: AbortSignal + signal: AbortSignal ) { const globalFailureHooks = lifecycleHooks.getGlobalFailureHooks(); const taskFailureHook = lifecycleHooks.getTaskFailureHook(this.task.id); @@ -832,7 +940,7 @@ export class TaskExecutor { payload: unknown, ctx: TaskRunContext, initOutput: any, - signal?: AbortSignal + signal: AbortSignal ) { const globalStartHooks = lifecycleHooks.getGlobalStartHooks(); const taskStartHook = lifecycleHooks.getTaskStartHook(this.task.id); @@ -898,7 +1006,7 @@ export class TaskExecutor { payload: unknown, ctx: TaskRunContext, initOutput: any, - signal?: AbortSignal + signal: AbortSignal ) { await this.#callCleanupFunctions(payload, ctx, initOutput, signal); await this.#blockForWaitUntil(); @@ -908,7 +1016,7 @@ export class TaskExecutor { payload: unknown, ctx: TaskRunContext, initOutput: any, - signal?: AbortSignal + signal: AbortSignal ) { const globalCleanupHooks = lifecycleHooks.getGlobalCleanupHooks(); const taskCleanupHook = lifecycleHooks.getTaskCleanupHook(this.task.id); @@ -1001,7 +1109,7 @@ export class TaskExecutor { payload: any, ctx: TaskRunContext, init: TaskInitOutput, - signal?: AbortSignal + signal: AbortSignal ): Promise< | { status: "retry"; retry: TaskRunExecutionRetry; error?: unknown } | { status: "skipped"; error?: unknown } @@ -1191,7 +1299,7 @@ export class TaskExecutor { result: TaskCompleteResult, ctx: TaskRunContext, initOutput: any, - signal?: AbortSignal + signal: AbortSignal ) { const globalCompleteHooks = lifecycleHooks.getGlobalCompleteHooks(); const taskCompleteHook = lifecycleHooks.getTaskCompleteHook(this.task.id); diff --git a/packages/core/test/taskExecutor.test.ts b/packages/core/test/taskExecutor.test.ts index c8e8e10bfe..531a872e4f 100644 --- a/packages/core/test/taskExecutor.test.ts +++ b/packages/core/test/taskExecutor.test.ts @@ -1905,5 +1905,7 @@ function executeTask( engine: "V2", }; - return executor.execute(execution, worker, {}, signal); + const $signal = signal ? signal : new AbortController().signal; + + return executor.execute(execution, worker, {}, $signal); } diff --git a/packages/trigger-sdk/src/v3/hooks.ts b/packages/trigger-sdk/src/v3/hooks.ts index d864f4ec8e..b4e9cd0988 100644 --- a/packages/trigger-sdk/src/v3/hooks.ts +++ b/packages/trigger-sdk/src/v3/hooks.ts @@ -11,6 +11,7 @@ import { type AnyOnResumeHookFunction, type AnyOnCatchErrorHookFunction, type AnyOnMiddlewareHookFunction, + type AnyOnCancelHookFunction, } from "@trigger.dev/core/v3"; export type { @@ -25,6 +26,7 @@ export type { AnyOnResumeHookFunction, AnyOnCatchErrorHookFunction, AnyOnMiddlewareHookFunction, + AnyOnCancelHookFunction, }; export function onStart(name: string, fn: AnyOnStartHookFunction): void; @@ -131,3 +133,15 @@ export function middleware( fn: typeof fnOrName === "function" ? fnOrName : fn!, }); } + +export function onCancel(name: string, fn: AnyOnCancelHookFunction): void; +export function onCancel(fn: AnyOnCancelHookFunction): void; +export function onCancel( + fnOrName: string | AnyOnCancelHookFunction, + fn?: AnyOnCancelHookFunction +): void { + lifecycleHooks.registerGlobalCancelHook({ + id: typeof fnOrName === "string" ? fnOrName : fnOrName.name ? fnOrName.name : undefined, + fn: typeof fnOrName === "function" ? fnOrName : fn!, + }); +} diff --git a/packages/trigger-sdk/src/v3/shared.ts b/packages/trigger-sdk/src/v3/shared.ts index a5b3fca392..801a138b5a 100644 --- a/packages/trigger-sdk/src/v3/shared.ts +++ b/packages/trigger-sdk/src/v3/shared.ts @@ -42,6 +42,7 @@ import type { AnyOnStartHookFunction, AnyOnSuccessHookFunction, AnyOnWaitHookFunction, + AnyOnCancelHookFunction, AnyRunHandle, AnyRunTypes, AnyTask, @@ -1637,4 +1638,10 @@ function registerTaskLifecycleHooks< fn: params.cleanup as AnyOnCleanupHookFunction, }); } + + if (params.onCancel) { + lifecycleHooks.registerTaskCancelHook(taskId, { + fn: params.onCancel as AnyOnCancelHookFunction, + }); + } } diff --git a/packages/trigger-sdk/src/v3/tasks.ts b/packages/trigger-sdk/src/v3/tasks.ts index a6089d090e..71cb6acd5b 100644 --- a/packages/trigger-sdk/src/v3/tasks.ts +++ b/packages/trigger-sdk/src/v3/tasks.ts @@ -8,6 +8,7 @@ import { onHandleError, onCatchError, middleware, + onCancel, } from "./hooks.js"; import { batchTrigger, @@ -95,6 +96,7 @@ export const tasks = { onComplete, onWait, onResume, + onCancel, /** @deprecated Use catchError instead */ handleError: onHandleError, catchError: onCatchError, diff --git a/references/d3-chat/src/trigger/chat.ts b/references/d3-chat/src/trigger/chat.ts index 184613b36f..cd7a6de71f 100644 --- a/references/d3-chat/src/trigger/chat.ts +++ b/references/d3-chat/src/trigger/chat.ts @@ -1,7 +1,7 @@ import { anthropic } from "@ai-sdk/anthropic"; import { openai } from "@ai-sdk/openai"; import { ai } from "@trigger.dev/sdk/ai"; -import { logger, metadata, schemaTask, wait } from "@trigger.dev/sdk/v3"; +import { logger, metadata, schemaTask, tasks, wait } from "@trigger.dev/sdk/v3"; import { sql } from "@vercel/postgres"; import { streamText, TextStreamPart, tool } from "ai"; import { nanoid } from "nanoid"; @@ -110,7 +110,7 @@ export const todoChat = schemaTask({ ), userId: z.string(), }), - run: async ({ input, userId }) => { + run: async ({ input, userId }, { signal }) => { metadata.set("user_id", userId); const system = ` @@ -157,6 +157,8 @@ export const todoChat = schemaTask({ const prompt = input; + const chunks: TextStreamPart[] = []; + const result = streamText({ model: getModel(), system, @@ -174,6 +176,10 @@ export const todoChat = schemaTask({ experimental_telemetry: { isEnabled: true, }, + abortSignal: signal, + onChunk: ({ chunk }) => { + chunks.push(chunk); + }, }); const stream = await metadata.stream("fullStream", result.fullStream); @@ -213,3 +219,49 @@ function getModel() { return anthropic("claude-3-5-sonnet-latest"); } } + +export const interruptibleChat = schemaTask({ + id: "interruptible-chat", + description: "Chat with the AI", + schema: z.object({ + prompt: z.string().describe("The prompt to chat with the AI"), + }), + run: async ({ prompt }, { signal }) => { + const chunks: TextStreamPart<{}>[] = []; + + // 👇 This is a global onCancel hook, but it's inside of the run function + tasks.onCancel(async () => { + // We have access to the chunks here + logger.info("interruptible-chat: task cancelled with chunks", { chunks }); + }); + + try { + const result = streamText({ + model: getModel(), + prompt, + experimental_telemetry: { + isEnabled: true, + }, + tools: {}, + abortSignal: signal, + onChunk: ({ chunk }) => { + chunks.push(chunk); + }, + }); + + const textParts = []; + + for await (const part of result.textStream) { + textParts.push(part); + } + + return textParts.join(""); + } catch (error) { + if (error instanceof Error && error.name === "AbortError") { + // streamText will throw an AbortError if the signal is aborted, so we can handle it here + } else { + throw error; + } + } + }, +}); diff --git a/references/hello-world/src/trigger/example.ts b/references/hello-world/src/trigger/example.ts index 8764638ebd..1a475d11de 100644 --- a/references/hello-world/src/trigger/example.ts +++ b/references/hello-world/src/trigger/example.ts @@ -1,4 +1,4 @@ -import { batch, logger, task, timeout, wait } from "@trigger.dev/sdk"; +import { batch, logger, task, tasks, timeout, wait } from "@trigger.dev/sdk"; import { setTimeout } from "timers/promises"; import { ResourceMonitor } from "../resourceMonitor.js"; @@ -207,6 +207,54 @@ export const hooksTask = task({ cleanup: async ({ ctx, payload }) => { logger.info("Hello, world from the cleanup hook", { payload }); }, + onCancel: async ({ payload }) => { + logger.info("Hello, world from the onCancel hook", { payload }); + }, +}); + +export const cancelExampleTask = task({ + id: "cancel-example", + // Signal will be aborted when the task is cancelled 👇 + run: async (payload: { timeoutInSeconds: number }, { signal }) => { + logger.info("Hello, world from the cancel task", { + timeoutInSeconds: payload.timeoutInSeconds, + }); + + // This is a global hook that will be called if the task is cancelled + tasks.onCancel(async () => { + logger.info("global task onCancel hook but inside of the run function baby!"); + }); + + await logger.trace("timeout", async (span) => { + try { + // We pass the signal to setTimeout to abort the timeout if the task is cancelled + await setTimeout(payload.timeoutInSeconds * 1000, undefined, { signal }); + } catch (error) { + // If the timeout is aborted, this error will be thrown, we can handle it here + logger.error("Timeout error", { error }); + } + }); + + logger.info("Hello, world from the cancel task after the timeout", { + timeoutInSeconds: payload.timeoutInSeconds, + }); + + return { + message: "Hello, world!", + }; + }, + onCancel: async ({ payload, runPromise }) => { + logger.info("Hello, world from the onCancel hook", { payload }); + // You can await the runPromise to get the output of the task + const output = await runPromise; + + logger.info("Hello, world from the onCancel hook after the run", { payload, output }); + + // You can do work inside the onCancel hook, up to 30 seconds + await setTimeout(10_000); + + logger.info("Hello, world from the onCancel hook after the timeout", { payload }); + }, }); export const resourceMonitorTest = task({ diff --git a/references/hello-world/src/trigger/init.ts b/references/hello-world/src/trigger/init.ts index e57bc89838..8512d395ca 100644 --- a/references/hello-world/src/trigger/init.ts +++ b/references/hello-world/src/trigger/init.ts @@ -6,6 +6,10 @@ tasks.middleware("db", ({ ctx, payload, next }) => { return next(); }); +tasks.onCancel(async ({ ctx, payload }) => { + logger.info("Hello, world from the global cancel", { ctx, payload }); +}); + // tasks.onSuccess(({ ctx, payload, output }) => { // logger.info("Hello, world from the success", { ctx, payload }); // });