Skip to content

Commit 6c7e05b

Browse files
committed
Add support for AbortSignal to all API methods
1 parent 0509238 commit 6c7e05b

11 files changed

+213
-64
lines changed

README.md

+8-2
Original file line numberDiff line numberDiff line change
@@ -1208,13 +1208,19 @@ const response = await replicate.request(route, parameters);
12081208

12091209
| name | type | description |
12101210
| -------------------- | ------ | ------------------------------------------------------------ |
1211-
| `options.route` | string | Required. REST API endpoint path. |
1212-
| `options.parameters` | object | URL, query, and request body parameters for the given route. |
1211+
| `options.route` | `string` | Required. REST API endpoint path. |
1212+
| `options.params` | `object` | URL query parameters for the given route. |
1213+
| `options.method` | `string` | HTTP method for the given route. |
1214+
| `options.headers` | `object` | Additional HTTP headers for the given route. |
1215+
| `options.data` | `object | FormData` | Request body. |
1216+
| `options.signal` | `AbortSignal` | Optional `AbortSignal`. |
12131217

12141218
The `replicate.request()` method is used by the other methods
12151219
to interact with the Replicate API.
12161220
You can call this method directly to make other requests to the API.
12171221

1222+
The method accepts an `AbortSignal` which can be used to cancel the request in flight.
1223+
12181224
### `FileOutput`
12191225

12201226
`FileOutput` is a `ReadableStream` instance that represents a model file output. It can be used to stream file data to disk or as a `Response` body to an HTTP request.

index.d.ts

+76-32
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,15 @@ declare module "replicate" {
201201
): Promise<Prediction>;
202202

203203
accounts: {
204-
current(): Promise<Account>;
204+
current(options?: { signal?: AbortSignal }): Promise<Account>;
205205
};
206206

207207
collections: {
208-
list(): Promise<Page<Collection>>;
209-
get(collection_slug: string): Promise<Collection>;
208+
list(options?: { signal?: AbortSignal }): Promise<Page<Collection>>;
209+
get(
210+
collection_slug: string,
211+
options?: { signal?: AbortSignal }
212+
): Promise<Collection>;
210213
};
211214

212215
deployments: {
@@ -221,21 +224,26 @@ declare module "replicate" {
221224
webhook?: string;
222225
webhook_events_filter?: WebhookEventType[];
223226
wait?: number | boolean;
227+
signal?: AbortSignal;
224228
}
225229
): Promise<Prediction>;
226230
};
227231
get(
228232
deployment_owner: string,
229-
deployment_name: string
233+
deployment_name: string,
234+
options?: { signal?: AbortSignal }
235+
): Promise<Deployment>;
236+
create(
237+
deployment_config: {
238+
name: string;
239+
model: string;
240+
version: string;
241+
hardware: string;
242+
min_instances: number;
243+
max_instances: number;
244+
},
245+
options?: { signal?: AbortSignal }
230246
): Promise<Deployment>;
231-
create(deployment_config: {
232-
name: string;
233-
model: string;
234-
version: string;
235-
hardware: string;
236-
min_instances: number;
237-
max_instances: number;
238-
}): Promise<Deployment>;
239247
update(
240248
deployment_owner: string,
241249
deployment_name: string,
@@ -249,32 +257,45 @@ declare module "replicate" {
249257
| { hardware: string }
250258
| { min_instances: number }
251259
| { max_instances: number }
252-
)
260+
),
261+
options?: { signal?: AbortSignal }
253262
): Promise<Deployment>;
254263
delete(
255264
deployment_owner: string,
256-
deployment_name: string
265+
deployment_name: string,
266+
options?: { signal?: AbortSignal }
257267
): Promise<boolean>;
258-
list(): Promise<Page<Deployment>>;
268+
list(options?: { signal?: AbortSignal }): Promise<Page<Deployment>>;
259269
};
260270

261271
files: {
262272
create(
263273
file: Blob | Buffer,
264-
metadata?: Record<string, unknown>
274+
metadata?: Record<string, unknown>,
275+
options?: { signal?: AbortSignal }
265276
): Promise<FileObject>;
266-
list(): Promise<Page<FileObject>>;
267-
get(file_id: string): Promise<FileObject>;
268-
delete(file_id: string): Promise<boolean>;
277+
list(options?: { signal?: AbortSignal }): Promise<Page<FileObject>>;
278+
get(
279+
file_id: string,
280+
options?: { signal?: AbortSignal }
281+
): Promise<FileObject>;
282+
delete(
283+
file_id: string,
284+
options?: { signal?: AbortSignal }
285+
): Promise<boolean>;
269286
};
270287

271288
hardware: {
272-
list(): Promise<Hardware[]>;
289+
list(options?: { signal?: AbortSignal }): Promise<Hardware[]>;
273290
};
274291

275292
models: {
276-
get(model_owner: string, model_name: string): Promise<Model>;
277-
list(): Promise<Page<Model>>;
293+
get(
294+
model_owner: string,
295+
model_name: string,
296+
options?: { signal?: AbortSignal }
297+
): Promise<Model>;
298+
list(options?: { signal?: AbortSignal }): Promise<Page<Model>>;
278299
create(
279300
model_owner: string,
280301
model_name: string,
@@ -286,17 +307,26 @@ declare module "replicate" {
286307
paper_url?: string;
287308
license_url?: string;
288309
cover_image_url?: string;
310+
signal?: AbortSignal;
289311
}
290312
): Promise<Model>;
291313
versions: {
292-
list(model_owner: string, model_name: string): Promise<ModelVersion[]>;
314+
list(
315+
model_owner: string,
316+
model_name: string,
317+
options?: { signal?: AbortSignal }
318+
): Promise<ModelVersion[]>;
293319
get(
294320
model_owner: string,
295321
model_name: string,
296-
version_id: string
322+
version_id: string,
323+
options?: { signal?: AbortSignal }
297324
): Promise<ModelVersion>;
298325
};
299-
search(query: string): Promise<Page<Model>>;
326+
search(
327+
query: string,
328+
options?: { signal?: AbortSignal }
329+
): Promise<Page<Model>>;
300330
};
301331

302332
predictions: {
@@ -310,11 +340,18 @@ declare module "replicate" {
310340
webhook?: string;
311341
webhook_events_filter?: WebhookEventType[];
312342
wait?: boolean | number;
343+
signal?: AbortSignal;
313344
} & ({ version: string } | { model: string })
314345
): Promise<Prediction>;
315-
get(prediction_id: string): Promise<Prediction>;
316-
cancel(prediction_id: string): Promise<Prediction>;
317-
list(): Promise<Page<Prediction>>;
346+
get(
347+
prediction_id: string,
348+
options?: { signal?: AbortSignal }
349+
): Promise<Prediction>;
350+
cancel(
351+
prediction_id: string,
352+
options?: { signal?: AbortSignal }
353+
): Promise<Prediction>;
354+
list(options?: { signal?: AbortSignal }): Promise<Page<Prediction>>;
318355
};
319356

320357
trainings: {
@@ -327,17 +364,24 @@ declare module "replicate" {
327364
input: object;
328365
webhook?: string;
329366
webhook_events_filter?: WebhookEventType[];
367+
signal?: AbortSignal;
330368
}
331369
): Promise<Training>;
332-
get(training_id: string): Promise<Training>;
333-
cancel(training_id: string): Promise<Training>;
334-
list(): Promise<Page<Training>>;
370+
get(
371+
training_id: string,
372+
options?: { signal?: AbortSignal }
373+
): Promise<Training>;
374+
cancel(
375+
training_id: string,
376+
options?: { signal?: AbortSignal }
377+
): Promise<Training>;
378+
list(options?: { signal?: AbortSignal }): Promise<Page<Training>>;
335379
};
336380

337381
webhooks: {
338382
default: {
339383
secret: {
340-
get(): Promise<WebhookSecret>;
384+
get(options?: { signal?: AbortSignal }): Promise<WebhookSecret>;
341385
};
342386
};
343387
};

lib/accounts.js

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
/**
22
* Get the current account
33
*
4+
* @param {object} [options]
5+
* @param {AbortSignal} [options.signal] - An optional AbortSignal
46
* @returns {Promise<object>} Resolves with the current account
57
*/
6-
async function getCurrentAccount() {
8+
async function getCurrentAccount({ signal } = {}) {
79
const response = await this.request("/account", {
810
method: "GET",
11+
signal,
912
});
1013

1114
return response.json();

lib/collections.js

+8-2
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
* Fetch a model collection
33
*
44
* @param {string} collection_slug - Required. The slug of the collection. See http://replicate.com/collections
5+
* @param {object} [options]
6+
* @param {AbortSignal} [options.signal] - An optional AbortSignal
57
* @returns {Promise<object>} - Resolves with the collection data
68
*/
7-
async function getCollection(collection_slug) {
9+
async function getCollection(collection_slug, { signal } = {}) {
810
const response = await this.request(`/collections/${collection_slug}`, {
911
method: "GET",
12+
signal,
1013
});
1114

1215
return response.json();
@@ -15,11 +18,14 @@ async function getCollection(collection_slug) {
1518
/**
1619
* Fetch a list of model collections
1720
*
21+
* @param {object} [options]
22+
* @param {AbortSignal} [options.signal] - An optional AbortSignal
1823
* @returns {Promise<object>} - Resolves with the collections data
1924
*/
20-
async function listCollections() {
25+
async function listCollections({ signal } = {}) {
2126
const response = await this.request("/collections", {
2227
method: "GET",
28+
signal,
2329
});
2430

2531
return response.json();

lib/deployments.js

+33-7
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ const { transformFileInputs } = require("./util");
1010
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
1111
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
1212
* @param {boolean|integer} [options.wait] - Whether to wait until the prediction is completed before returning. If an integer is provided, it will wait for that many seconds. Defaults to false
13+
* @param {AbortSignal} [options.signal] - An optional AbortSignal
1314
* @returns {Promise<object>} Resolves with the created prediction data
1415
*/
1516
async function createPrediction(deployment_owner, deployment_name, options) {
16-
const { input, wait, ...data } = options;
17+
const { input, wait, signal, ...data } = options;
1718

1819
if (data.webhook) {
1920
try {
@@ -47,6 +48,7 @@ async function createPrediction(deployment_owner, deployment_name, options) {
4748
this.fileEncodingStrategy
4849
),
4950
},
51+
signal,
5052
}
5153
);
5254

@@ -58,13 +60,20 @@ async function createPrediction(deployment_owner, deployment_name, options) {
5860
*
5961
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
6062
* @param {string} deployment_name - Required. The name of the deployment
63+
* @param {object] [options]
64+
* @param {AbortSignal} [options.signal] - An optional AbortSignal
6165
* @returns {Promise<object>} Resolves with the deployment data
6266
*/
63-
async function getDeployment(deployment_owner, deployment_name) {
67+
async function getDeployment(
68+
deployment_owner,
69+
deployment_name,
70+
{ signal } = {}
71+
) {
6472
const response = await this.request(
6573
`/deployments/${deployment_owner}/${deployment_name}`,
6674
{
6775
method: "GET",
76+
signal,
6877
}
6978
);
7079

@@ -84,13 +93,16 @@ async function getDeployment(deployment_owner, deployment_name) {
8493
/**
8594
* Create a deployment
8695
*
87-
* @param {DeploymentCreateRequest} config - Required. The deployment config.
96+
* @param {DeploymentCreateRequest} deployment_config - Required. The deployment config.
97+
* @param {object} [options]
98+
* @param {AbortSignal} [options.signal] - An optional AbortSignal
8899
* @returns {Promise<object>} Resolves with the deployment data
89100
*/
90-
async function createDeployment(deployment_config) {
101+
async function createDeployment(deployment_config, { signal } = {}) {
91102
const response = await this.request("/deployments", {
92103
method: "POST",
93104
data: deployment_config,
105+
signal,
94106
});
95107

96108
return response.json();
@@ -110,18 +122,22 @@ async function createDeployment(deployment_config) {
110122
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
111123
* @param {string} deployment_name - Required. The name of the deployment
112124
* @param {DeploymentUpdateRequest} deployment_config - Required. The deployment changes.
125+
* @param {object} [options]
126+
* @param {AbortSignal} [options.signal] - An optional AbortSignal
113127
* @returns {Promise<object>} Resolves with the deployment data
114128
*/
115129
async function updateDeployment(
116130
deployment_owner,
117131
deployment_name,
118-
deployment_config
132+
deployment_config,
133+
{ signal } = {}
119134
) {
120135
const response = await this.request(
121136
`/deployments/${deployment_owner}/${deployment_name}`,
122137
{
123138
method: "PATCH",
124139
data: deployment_config,
140+
signal,
125141
}
126142
);
127143

@@ -133,13 +149,20 @@ async function updateDeployment(
133149
*
134150
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
135151
* @param {string} deployment_name - Required. The name of the deployment
152+
* @param {object} [options]
153+
* @param {AbortSignal} [options.signal] - An optional AbortSignal
136154
* @returns {Promise<boolean>} Resolves with true if the deployment was deleted
137155
*/
138-
async function deleteDeployment(deployment_owner, deployment_name) {
156+
async function deleteDeployment(
157+
deployment_owner,
158+
deployment_name,
159+
{ signal } = {}
160+
) {
139161
const response = await this.request(
140162
`/deployments/${deployment_owner}/${deployment_name}`,
141163
{
142164
method: "DELETE",
165+
signal,
143166
}
144167
);
145168

@@ -149,11 +172,14 @@ async function deleteDeployment(deployment_owner, deployment_name) {
149172
/**
150173
* List all deployments
151174
*
175+
* @param {object} [options]
176+
* @param {AbortSignal} [options.signal] - An optional AbortSignal
152177
* @returns {Promise<object>} - Resolves with a page of deployments
153178
*/
154-
async function listDeployments() {
179+
async function listDeployments({ signal } = {}) {
155180
const response = await this.request("/deployments", {
156181
method: "GET",
182+
signal,
157183
});
158184

159185
return response.json();

0 commit comments

Comments
 (0)