Skip to content

Add support for AbortSignal to all API methods #339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1213,15 +1213,21 @@ Low-level method used by the Replicate client to interact with API endpoints.
const response = await replicate.request(route, parameters);
```

| name | type | description |
| -------------------- | ------ | ------------------------------------------------------------ |
| `options.route` | string | Required. REST API endpoint path. |
| `options.parameters` | object | URL, query, and request body parameters for the given route. |
| name | type | description |
| -------------------- | ------------------- | ----------- |
| `options.route` | `string` | Required. REST API endpoint path.
| `options.params` | `object` | URL query parameters for the given route. |
| `options.method` | `string` | HTTP method for the given route. |
| `options.headers` | `object` | Additional HTTP headers for the given route. |
| `options.data` | `object \| FormData` | Request body. |
| `options.signal` | `AbortSignal` | Optional `AbortSignal`. |

The `replicate.request()` method is used by the other methods
to interact with the Replicate API.
You can call this method directly to make other requests to the API.

The method accepts an `AbortSignal` which can be used to cancel the request in flight.

### `FileOutput`

`FileOutput` is a `ReadableStream` instance that represents a model file output. It can be used to stream file data to disk or as a `Response` body to an HTTP request.
Expand Down
114 changes: 81 additions & 33 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,14 @@ declare module "replicate" {
headers?: object | Headers;
params?: object;
data?: object;
signal?: AbortSignal;
}
): Promise<Response>;

paginate<T>(endpoint: () => Promise<Page<T>>): AsyncGenerator<[T]>;
paginate<T>(
endpoint: () => Promise<Page<T>>,
options?: { signal?: AbortSignal }
): AsyncGenerator<T[]>;

wait(
prediction: Prediction,
Expand All @@ -197,12 +201,15 @@ declare module "replicate" {
): Promise<Prediction>;

accounts: {
current(): Promise<Account>;
current(options?: { signal?: AbortSignal }): Promise<Account>;
};

collections: {
list(): Promise<Page<Collection>>;
get(collection_slug: string): Promise<Collection>;
list(options?: { signal?: AbortSignal }): Promise<Page<Collection>>;
get(
collection_slug: string,
options?: { signal?: AbortSignal }
): Promise<Collection>;
};

deployments: {
Expand All @@ -217,21 +224,26 @@ declare module "replicate" {
webhook?: string;
webhook_events_filter?: WebhookEventType[];
wait?: number | boolean;
signal?: AbortSignal;
}
): Promise<Prediction>;
};
get(
deployment_owner: string,
deployment_name: string
deployment_name: string,
options?: { signal?: AbortSignal }
): Promise<Deployment>;
create(
deployment_config: {
name: string;
model: string;
version: string;
hardware: string;
min_instances: number;
max_instances: number;
},
options?: { signal?: AbortSignal }
): Promise<Deployment>;
create(deployment_config: {
name: string;
model: string;
version: string;
hardware: string;
min_instances: number;
max_instances: number;
}): Promise<Deployment>;
update(
deployment_owner: string,
deployment_name: string,
Expand All @@ -245,32 +257,45 @@ declare module "replicate" {
| { hardware: string }
| { min_instances: number }
| { max_instances: number }
)
),
options?: { signal?: AbortSignal }
): Promise<Deployment>;
delete(
deployment_owner: string,
deployment_name: string
deployment_name: string,
options?: { signal?: AbortSignal }
): Promise<boolean>;
list(): Promise<Page<Deployment>>;
list(options?: { signal?: AbortSignal }): Promise<Page<Deployment>>;
};

files: {
create(
file: Blob | Buffer,
metadata?: Record<string, unknown>
metadata?: Record<string, unknown>,
options?: { signal?: AbortSignal }
): Promise<FileObject>;
list(): Promise<Page<FileObject>>;
get(file_id: string): Promise<FileObject>;
delete(file_id: string): Promise<boolean>;
list(options?: { signal?: AbortSignal }): Promise<Page<FileObject>>;
get(
file_id: string,
options?: { signal?: AbortSignal }
): Promise<FileObject>;
delete(
file_id: string,
options?: { signal?: AbortSignal }
): Promise<boolean>;
};

hardware: {
list(): Promise<Hardware[]>;
list(options?: { signal?: AbortSignal }): Promise<Hardware[]>;
};

models: {
get(model_owner: string, model_name: string): Promise<Model>;
list(): Promise<Page<Model>>;
get(
model_owner: string,
model_name: string,
options?: { signal?: AbortSignal }
): Promise<Model>;
list(options?: { signal?: AbortSignal }): Promise<Page<Model>>;
create(
model_owner: string,
model_name: string,
Expand All @@ -282,17 +307,26 @@ declare module "replicate" {
paper_url?: string;
license_url?: string;
cover_image_url?: string;
signal?: AbortSignal;
}
): Promise<Model>;
versions: {
list(model_owner: string, model_name: string): Promise<ModelVersion[]>;
list(
model_owner: string,
model_name: string,
options?: { signal?: AbortSignal }
): Promise<ModelVersion[]>;
get(
model_owner: string,
model_name: string,
version_id: string
version_id: string,
options?: { signal?: AbortSignal }
): Promise<ModelVersion>;
};
search(query: string): Promise<Page<Model>>;
search(
query: string,
options?: { signal?: AbortSignal }
): Promise<Page<Model>>;
};

predictions: {
Expand All @@ -306,11 +340,18 @@ declare module "replicate" {
webhook?: string;
webhook_events_filter?: WebhookEventType[];
wait?: boolean | number;
signal?: AbortSignal;
} & ({ version: string } | { model: string })
): Promise<Prediction>;
get(prediction_id: string): Promise<Prediction>;
cancel(prediction_id: string): Promise<Prediction>;
list(): Promise<Page<Prediction>>;
get(
prediction_id: string,
options?: { signal?: AbortSignal }
): Promise<Prediction>;
cancel(
prediction_id: string,
options?: { signal?: AbortSignal }
): Promise<Prediction>;
list(options?: { signal?: AbortSignal }): Promise<Page<Prediction>>;
};

trainings: {
Expand All @@ -323,17 +364,24 @@ declare module "replicate" {
input: object;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
signal?: AbortSignal;
}
): Promise<Training>;
get(training_id: string): Promise<Training>;
cancel(training_id: string): Promise<Training>;
list(): Promise<Page<Training>>;
get(
training_id: string,
options?: { signal?: AbortSignal }
): Promise<Training>;
cancel(
training_id: string,
options?: { signal?: AbortSignal }
): Promise<Training>;
list(options?: { signal?: AbortSignal }): Promise<Page<Training>>;
};

webhooks: {
default: {
secret: {
get(): Promise<WebhookSecret>;
get(options?: { signal?: AbortSignal }): Promise<WebhookSecret>;
};
};
};
Expand Down
17 changes: 12 additions & 5 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class Replicate {
* @param {object} [options.params] - Query parameters
* @param {object|Headers} [options.headers] - HTTP headers
* @param {object} [options.data] - Body parameters
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the request
* @returns {Promise<Response>} - Resolves with the response object
* @throws {ApiError} If the request failed
*/
Expand All @@ -241,7 +242,7 @@ class Replicate {
);
}

const { method = "GET", params = {}, data } = options;
const { method = "GET", params = {}, data, signal } = options;

for (const [key, value] of Object.entries(params)) {
url.searchParams.append(key, value);
Expand Down Expand Up @@ -273,6 +274,7 @@ class Replicate {
method,
headers,
body,
signal,
};

const shouldRetry =
Expand Down Expand Up @@ -354,15 +356,20 @@ class Replicate {
* console.log(page);
* }
* @param {Function} endpoint - Function that returns a promise for the next page of results
* @param {object} [options]
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the request.
* @yields {object[]} Each page of results
*/
async *paginate(endpoint) {
async *paginate(endpoint, options = {}) {
const response = await endpoint();
yield response.results;
if (response.next) {
if (response.next && !(options.signal && options.signal.aborted)) {
const nextPage = () =>
this.request(response.next, { method: "GET" }).then((r) => r.json());
yield* this.paginate(nextPage);
this.request(response.next, {
method: "GET",
signal: options.signal,
}).then((r) => r.json());
yield* this.paginate(nextPage, options);
}
}

Expand Down
84 changes: 84 additions & 0 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,90 @@ describe("Replicate client", () => {
});
});

describe("paginate", () => {
test("pages through results", async () => {
nock(BASE_URL)
.get("/collections")
.reply(200, {
results: [
{
name: "Super resolution",
slug: "super-resolution",
description:
"Upscaling models that create high-quality images from low-quality images.",
},
],
next: `${BASE_URL}/collections?page=2`,
previous: null,
});
nock(BASE_URL)
.get("/collections?page=2")
.reply(200, {
results: [
{
name: "Image classification",
slug: "image-classification",
description: "Models that classify images.",
},
],
next: null,
previous: null,
});

const iterator = client.paginate(client.collections.list);

const firstPage = (await iterator.next()).value;
expect(firstPage.length).toBe(1);

const secondPage = (await iterator.next()).value;
expect(secondPage.length).toBe(1);
});

test("accepts an abort signal", async () => {
nock(BASE_URL)
.get("/collections")
.reply(200, {
results: [
{
name: "Super resolution",
slug: "super-resolution",
description:
"Upscaling models that create high-quality images from low-quality images.",
},
],
next: `${BASE_URL}/collections?page=2`,
previous: null,
});
nock(BASE_URL)
.get("/collections?page=2")
.reply(200, {
results: [
{
name: "Image classification",
slug: "image-classification",
description: "Models that classify images.",
},
],
next: null,
previous: null,
});

const controller = new AbortController();
const iterator = client.paginate(client.collections.list, {
signal: controller.signal,
});

const firstIteration = await iterator.next();
expect(firstIteration.value.length).toBe(1);

controller.abort();

const secondIteration = await iterator.next();
expect(secondIteration.value).toBeUndefined();
expect(secondIteration.done).toBe(true);
});
});

describe("account.get", () => {
test("Calls the correct API route", async () => {
nock(BASE_URL).get("/account").reply(200, {
Expand Down
8 changes: 4 additions & 4 deletions integration/next/pages/index.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export default () => (
<main>
<h1>Welcome to Next.js</h1>
</main>
)
<main>
<h1>Welcome to Next.js</h1>
</main>
);
Loading