Skip to content

Commit

Permalink
Bug fixes for the wait option in replicate.run (#315)
Browse files Browse the repository at this point in the history
There were a couple of small bugs in the current implementation:

1. We would pass non-boolean, non-integer values through to `predictions.create` when it was an object with an interval, resulting in the blocking mode being used accidentally.
2. We would pass the boolean/integer values through to `wait` which would create a runtime error when the `wait` function expects an object.
3. We continued to poll for the prediction despite the blocking response returning the output data.

This PR addresses these three issues by checking if the run should be blocking and passing the correct arguments in the correct places. We also assume that if the returned prediction is not in `starting` state then it is completed. This isn't ideal but works for the moment.

Lastly, in  the case where the blocking request times out the client will fall back to polling at the default interval.
  • Loading branch information
aron authored Oct 1, 2024
1 parent be0f323 commit c1a12b0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 23 deletions.
10 changes: 5 additions & 5 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ declare module "replicate" {
model: string;
version: string;
input: object;
output?: any;
output?: any; // TODO: this should be `unknown`
source: "api" | "web";
error?: any;
error?: unknown;
logs?: string;
metrics?: {
predict_time?: number;
Expand Down Expand Up @@ -156,7 +156,7 @@ declare module "replicate" {
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
options: {
input: object;
wait?: boolean | number | { mode?: "poll"; interval?: number };
wait?: boolean | number | { interval?: number };
webhook?: string;
webhook_events_filter?: WebhookEventType[];
signal?: AbortSignal;
Expand Down Expand Up @@ -215,7 +215,7 @@ declare module "replicate" {
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
block?: boolean;
wait?: boolean | number | { mode?: "poll"; interval?: number };
}
): Promise<Prediction>;
};
Expand Down Expand Up @@ -304,7 +304,7 @@ declare module "replicate" {
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
block?: boolean;
wait?: boolean | number | { mode?: "poll"; interval?: number };
} & ({ version: string } | { model: string })
): Promise<Prediction>;
get(prediction_id: string): Promise<Prediction>;
Expand Down
40 changes: 22 additions & 18 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,20 @@ class Replicate {
const { wait, signal, ...data } = options;

const identifier = ModelVersionIdentifier.parse(ref);
const isBlocking = typeof wait === "boolean" || typeof wait === "number";

let prediction;
if (identifier.version) {
prediction = await this.predictions.create({
...data,
version: identifier.version,
wait: wait,
wait: isBlocking ? wait : false,
});
} else if (identifier.owner && identifier.name) {
prediction = await this.predictions.create({
...data,
model: `${identifier.owner}/${identifier.name}`,
wait: wait,
wait: isBlocking ? wait : false,
});
} else {
throw new Error("Invalid model version identifier");
Expand All @@ -170,23 +171,26 @@ class Replicate {
progress(prediction);
}

prediction = await this.wait(
prediction,
wait || {},
async (updatedPrediction) => {
// Call progress callback with the updated prediction object
if (progress) {
progress(updatedPrediction);
const isDone = isBlocking && prediction.status !== "starting";
if (!isDone) {
prediction = await this.wait(
prediction,
isBlocking ? {} : wait,
async (updatedPrediction) => {
// Call progress callback with the updated prediction object
if (progress) {
progress(updatedPrediction);
}

// We handle the cancel later in the function.
if (signal && signal.aborted) {
return true; // stop polling
}

return false; // continue polling
}

// We handle the cancel later in the function.
if (signal && signal.aborted) {
return true; // stop polling
}

return false; // continue polling
}
);
);
}

if (signal && signal.aborted) {
prediction = await this.predictions.cancel(prediction.id);
Expand Down

0 comments on commit c1a12b0

Please sign in to comment.