From c1a12b0f60122442c4115c23cc964fa96f2042fc Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 1 Oct 2024 23:44:15 +0100 Subject: [PATCH] Bug fixes for the `wait` option in `replicate.run` (#315) There were a couple of small bugs in the current implementation: 1. We would pass non-boolean, non-integer values through to `predictions.create` when it was an object with an interval, resulting in the blocking mode being used accidentally. 2. We would pass the boolean/integer values through to `wait` which would create a runtime error when the `wait` function expects an object. 3. We continued to poll for the prediction despite the blocking response returning the output data. This PR addresses these three issues by checking if the run should be blocking and passing the correct arguments in the correct places. We also assume that if the returned prediction is not in `starting` state then it is completed. This isn't ideal but works for the moment. Lastly, in the case where the blocking request times out the client will fall back to polling at the default interval. --- index.d.ts | 10 +++++----- index.js | 40 ++++++++++++++++++++++------------------ 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/index.d.ts b/index.d.ts index eabcc9b..45a2430 100644 --- a/index.d.ts +++ b/index.d.ts @@ -93,9 +93,9 @@ declare module "replicate" { model: string; version: string; input: object; - output?: any; + output?: any; // TODO: this should be `unknown` source: "api" | "web"; - error?: any; + error?: unknown; logs?: string; metrics?: { predict_time?: number; @@ -156,7 +156,7 @@ declare module "replicate" { identifier: `${string}/${string}` | `${string}/${string}:${string}`, options: { input: object; - wait?: boolean | number | { mode?: "poll"; interval?: number }; + wait?: boolean | number | { interval?: number }; webhook?: string; webhook_events_filter?: WebhookEventType[]; signal?: AbortSignal; @@ -215,7 +215,7 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; - block?: boolean; + wait?: boolean | number | { mode?: "poll"; interval?: number }; } ): Promise; }; @@ -304,7 +304,7 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; - block?: boolean; + wait?: boolean | number | { mode?: "poll"; interval?: number }; } & ({ version: string } | { model: string }) ): Promise; get(prediction_id: string): Promise; diff --git a/index.js b/index.js index 712bc59..ac4b815 100644 --- a/index.js +++ b/index.js @@ -147,19 +147,20 @@ class Replicate { const { wait, signal, ...data } = options; const identifier = ModelVersionIdentifier.parse(ref); + const isBlocking = typeof wait === "boolean" || typeof wait === "number"; let prediction; if (identifier.version) { prediction = await this.predictions.create({ ...data, version: identifier.version, - wait: wait, + wait: isBlocking ? wait : false, }); } else if (identifier.owner && identifier.name) { prediction = await this.predictions.create({ ...data, model: `${identifier.owner}/${identifier.name}`, - wait: wait, + wait: isBlocking ? wait : false, }); } else { throw new Error("Invalid model version identifier"); @@ -170,23 +171,26 @@ class Replicate { progress(prediction); } - prediction = await this.wait( - prediction, - wait || {}, - async (updatedPrediction) => { - // Call progress callback with the updated prediction object - if (progress) { - progress(updatedPrediction); + const isDone = isBlocking && prediction.status !== "starting"; + if (!isDone) { + prediction = await this.wait( + prediction, + isBlocking ? {} : wait, + async (updatedPrediction) => { + // Call progress callback with the updated prediction object + if (progress) { + progress(updatedPrediction); + } + + // We handle the cancel later in the function. + if (signal && signal.aborted) { + return true; // stop polling + } + + return false; // continue polling } - - // We handle the cancel later in the function. - if (signal && signal.aborted) { - return true; // stop polling - } - - return false; // continue polling - } - ); + ); + } if (signal && signal.aborted) { prediction = await this.predictions.cancel(prediction.id);