diff --git a/README.md b/README.md index 67c8cab..cf12865 100644 --- a/README.md +++ b/README.md @@ -1213,15 +1213,21 @@ Low-level method used by the Replicate client to interact with API endpoints. const response = await replicate.request(route, parameters); ``` -| name | type | description | -| -------------------- | ------ | ------------------------------------------------------------ | -| `options.route` | string | Required. REST API endpoint path. | -| `options.parameters` | object | URL, query, and request body parameters for the given route. | +| name | type | description | +| -------------------- | ------------------- | ----------- | +| `options.route` | `string` | Required. REST API endpoint path. +| `options.params` | `object` | URL query parameters for the given route. | +| `options.method` | `string` | HTTP method for the given route. | +| `options.headers` | `object` | Additional HTTP headers for the given route. | +| `options.data` | `object \| FormData` | Request body. | +| `options.signal` | `AbortSignal` | Optional `AbortSignal`. | The `replicate.request()` method is used by the other methods to interact with the Replicate API. You can call this method directly to make other requests to the API. +The method accepts an `AbortSignal` which can be used to cancel the request in flight. + ### `FileOutput` `FileOutput` is a `ReadableStream` instance that represents a model file output. It can be used to stream file data to disk or as a `Response` body to an HTTP request. diff --git a/index.d.ts b/index.d.ts index 5b35bf2..709f466 100644 --- a/index.d.ts +++ b/index.d.ts @@ -183,10 +183,14 @@ declare module "replicate" { headers?: object | Headers; params?: object; data?: object; + signal?: AbortSignal; } ): Promise; - paginate(endpoint: () => Promise>): AsyncGenerator<[T]>; + paginate( + endpoint: () => Promise>, + options?: { signal?: AbortSignal } + ): AsyncGenerator; wait( prediction: Prediction, @@ -197,12 +201,15 @@ declare module "replicate" { ): Promise; accounts: { - current(): Promise; + current(options?: { signal?: AbortSignal }): Promise; }; collections: { - list(): Promise>; - get(collection_slug: string): Promise; + list(options?: { signal?: AbortSignal }): Promise>; + get( + collection_slug: string, + options?: { signal?: AbortSignal } + ): Promise; }; deployments: { @@ -217,21 +224,26 @@ declare module "replicate" { webhook?: string; webhook_events_filter?: WebhookEventType[]; wait?: number | boolean; + signal?: AbortSignal; } ): Promise; }; get( deployment_owner: string, - deployment_name: string + deployment_name: string, + options?: { signal?: AbortSignal } + ): Promise; + create( + deployment_config: { + name: string; + model: string; + version: string; + hardware: string; + min_instances: number; + max_instances: number; + }, + options?: { signal?: AbortSignal } ): Promise; - create(deployment_config: { - name: string; - model: string; - version: string; - hardware: string; - min_instances: number; - max_instances: number; - }): Promise; update( deployment_owner: string, deployment_name: string, @@ -245,32 +257,45 @@ declare module "replicate" { | { hardware: string } | { min_instances: number } | { max_instances: number } - ) + ), + options?: { signal?: AbortSignal } ): Promise; delete( deployment_owner: string, - deployment_name: string + deployment_name: string, + options?: { signal?: AbortSignal } ): Promise; - list(): Promise>; + list(options?: { signal?: AbortSignal }): Promise>; }; files: { create( file: Blob | Buffer, - metadata?: Record + metadata?: Record, + options?: { signal?: AbortSignal } ): Promise; - list(): Promise>; - get(file_id: string): Promise; - delete(file_id: string): Promise; + list(options?: { signal?: AbortSignal }): Promise>; + get( + file_id: string, + options?: { signal?: AbortSignal } + ): Promise; + delete( + file_id: string, + options?: { signal?: AbortSignal } + ): Promise; }; hardware: { - list(): Promise; + list(options?: { signal?: AbortSignal }): Promise; }; models: { - get(model_owner: string, model_name: string): Promise; - list(): Promise>; + get( + model_owner: string, + model_name: string, + options?: { signal?: AbortSignal } + ): Promise; + list(options?: { signal?: AbortSignal }): Promise>; create( model_owner: string, model_name: string, @@ -282,17 +307,26 @@ declare module "replicate" { paper_url?: string; license_url?: string; cover_image_url?: string; + signal?: AbortSignal; } ): Promise; versions: { - list(model_owner: string, model_name: string): Promise; + list( + model_owner: string, + model_name: string, + options?: { signal?: AbortSignal } + ): Promise; get( model_owner: string, model_name: string, - version_id: string + version_id: string, + options?: { signal?: AbortSignal } ): Promise; }; - search(query: string): Promise>; + search( + query: string, + options?: { signal?: AbortSignal } + ): Promise>; }; predictions: { @@ -306,11 +340,18 @@ declare module "replicate" { webhook?: string; webhook_events_filter?: WebhookEventType[]; wait?: boolean | number; + signal?: AbortSignal; } & ({ version: string } | { model: string }) ): Promise; - get(prediction_id: string): Promise; - cancel(prediction_id: string): Promise; - list(): Promise>; + get( + prediction_id: string, + options?: { signal?: AbortSignal } + ): Promise; + cancel( + prediction_id: string, + options?: { signal?: AbortSignal } + ): Promise; + list(options?: { signal?: AbortSignal }): Promise>; }; trainings: { @@ -323,17 +364,24 @@ declare module "replicate" { input: object; webhook?: string; webhook_events_filter?: WebhookEventType[]; + signal?: AbortSignal; } ): Promise; - get(training_id: string): Promise; - cancel(training_id: string): Promise; - list(): Promise>; + get( + training_id: string, + options?: { signal?: AbortSignal } + ): Promise; + cancel( + training_id: string, + options?: { signal?: AbortSignal } + ): Promise; + list(options?: { signal?: AbortSignal }): Promise>; }; webhooks: { default: { secret: { - get(): Promise; + get(options?: { signal?: AbortSignal }): Promise; }; }; }; diff --git a/index.js b/index.js index a5755d9..b1248e7 100644 --- a/index.js +++ b/index.js @@ -225,6 +225,7 @@ class Replicate { * @param {object} [options.params] - Query parameters * @param {object|Headers} [options.headers] - HTTP headers * @param {object} [options.data] - Body parameters + * @param {AbortSignal} [options.signal] - AbortSignal to cancel the request * @returns {Promise} - Resolves with the response object * @throws {ApiError} If the request failed */ @@ -241,7 +242,7 @@ class Replicate { ); } - const { method = "GET", params = {}, data } = options; + const { method = "GET", params = {}, data, signal } = options; for (const [key, value] of Object.entries(params)) { url.searchParams.append(key, value); @@ -273,6 +274,7 @@ class Replicate { method, headers, body, + signal, }; const shouldRetry = @@ -354,15 +356,20 @@ class Replicate { * console.log(page); * } * @param {Function} endpoint - Function that returns a promise for the next page of results + * @param {object} [options] + * @param {AbortSignal} [options.signal] - AbortSignal to cancel the request. * @yields {object[]} Each page of results */ - async *paginate(endpoint) { + async *paginate(endpoint, options = {}) { const response = await endpoint(); yield response.results; - if (response.next) { + if (response.next && !(options.signal && options.signal.aborted)) { const nextPage = () => - this.request(response.next, { method: "GET" }).then((r) => r.json()); - yield* this.paginate(nextPage); + this.request(response.next, { + method: "GET", + signal: options.signal, + }).then((r) => r.json()); + yield* this.paginate(nextPage, options); } } diff --git a/index.test.ts b/index.test.ts index f5bd609..4905908 100644 --- a/index.test.ts +++ b/index.test.ts @@ -99,6 +99,90 @@ describe("Replicate client", () => { }); }); + describe("paginate", () => { + test("pages through results", async () => { + nock(BASE_URL) + .get("/collections") + .reply(200, { + results: [ + { + name: "Super resolution", + slug: "super-resolution", + description: + "Upscaling models that create high-quality images from low-quality images.", + }, + ], + next: `${BASE_URL}/collections?page=2`, + previous: null, + }); + nock(BASE_URL) + .get("/collections?page=2") + .reply(200, { + results: [ + { + name: "Image classification", + slug: "image-classification", + description: "Models that classify images.", + }, + ], + next: null, + previous: null, + }); + + const iterator = client.paginate(client.collections.list); + + const firstPage = (await iterator.next()).value; + expect(firstPage.length).toBe(1); + + const secondPage = (await iterator.next()).value; + expect(secondPage.length).toBe(1); + }); + + test("accepts an abort signal", async () => { + nock(BASE_URL) + .get("/collections") + .reply(200, { + results: [ + { + name: "Super resolution", + slug: "super-resolution", + description: + "Upscaling models that create high-quality images from low-quality images.", + }, + ], + next: `${BASE_URL}/collections?page=2`, + previous: null, + }); + nock(BASE_URL) + .get("/collections?page=2") + .reply(200, { + results: [ + { + name: "Image classification", + slug: "image-classification", + description: "Models that classify images.", + }, + ], + next: null, + previous: null, + }); + + const controller = new AbortController(); + const iterator = client.paginate(client.collections.list, { + signal: controller.signal, + }); + + const firstIteration = await iterator.next(); + expect(firstIteration.value.length).toBe(1); + + controller.abort(); + + const secondIteration = await iterator.next(); + expect(secondIteration.value).toBeUndefined(); + expect(secondIteration.done).toBe(true); + }); + }); + describe("account.get", () => { test("Calls the correct API route", async () => { nock(BASE_URL).get("/account").reply(200, { diff --git a/integration/next/pages/index.js b/integration/next/pages/index.js index fc9581a..0912438 100644 --- a/integration/next/pages/index.js +++ b/integration/next/pages/index.js @@ -1,5 +1,5 @@ export default () => ( -
-

Welcome to Next.js

-
-) +
+

Welcome to Next.js

+
+); diff --git a/lib/accounts.js b/lib/accounts.js index b3bbd9f..72a94af 100644 --- a/lib/accounts.js +++ b/lib/accounts.js @@ -1,11 +1,14 @@ /** * Get the current account * + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the current account */ -async function getCurrentAccount() { +async function getCurrentAccount({ signal } = {}) { const response = await this.request("/account", { method: "GET", + signal, }); return response.json(); diff --git a/lib/collections.js b/lib/collections.js index 9332aaa..7b8e8f1 100644 --- a/lib/collections.js +++ b/lib/collections.js @@ -2,11 +2,14 @@ * Fetch a model collection * * @param {string} collection_slug - Required. The slug of the collection. See http://replicate.com/collections + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} - Resolves with the collection data */ -async function getCollection(collection_slug) { +async function getCollection(collection_slug, { signal } = {}) { const response = await this.request(`/collections/${collection_slug}`, { method: "GET", + signal, }); return response.json(); @@ -15,11 +18,14 @@ async function getCollection(collection_slug) { /** * Fetch a list of model collections * + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} - Resolves with the collections data */ -async function listCollections() { +async function listCollections({ signal } = {}) { const response = await this.request("/collections", { method: "GET", + signal, }); return response.json(); diff --git a/lib/deployments.js b/lib/deployments.js index 716c8e1..d45c4f3 100644 --- a/lib/deployments.js +++ b/lib/deployments.js @@ -10,10 +10,11 @@ const { transformFileInputs } = require("./util"); * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) * @param {boolean|integer} [options.wait] - Whether to wait until the prediction is completed before returning. If an integer is provided, it will wait for that many seconds. Defaults to false + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the created prediction data */ async function createPrediction(deployment_owner, deployment_name, options) { - const { input, wait, ...data } = options; + const { input, wait, signal, ...data } = options; if (data.webhook) { try { @@ -47,6 +48,7 @@ async function createPrediction(deployment_owner, deployment_name, options) { this.fileEncodingStrategy ), }, + signal, } ); @@ -58,13 +60,20 @@ async function createPrediction(deployment_owner, deployment_name, options) { * * @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment * @param {string} deployment_name - Required. The name of the deployment + * @param {object] [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the deployment data */ -async function getDeployment(deployment_owner, deployment_name) { +async function getDeployment( + deployment_owner, + deployment_name, + { signal } = {} +) { const response = await this.request( `/deployments/${deployment_owner}/${deployment_name}`, { method: "GET", + signal, } ); @@ -84,13 +93,16 @@ async function getDeployment(deployment_owner, deployment_name) { /** * Create a deployment * - * @param {DeploymentCreateRequest} config - Required. The deployment config. + * @param {DeploymentCreateRequest} deployment_config - Required. The deployment config. + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the deployment data */ -async function createDeployment(deployment_config) { +async function createDeployment(deployment_config, { signal } = {}) { const response = await this.request("/deployments", { method: "POST", data: deployment_config, + signal, }); return response.json(); @@ -110,18 +122,22 @@ async function createDeployment(deployment_config) { * @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment * @param {string} deployment_name - Required. The name of the deployment * @param {DeploymentUpdateRequest} deployment_config - Required. The deployment changes. + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the deployment data */ async function updateDeployment( deployment_owner, deployment_name, - deployment_config + deployment_config, + { signal } = {} ) { const response = await this.request( `/deployments/${deployment_owner}/${deployment_name}`, { method: "PATCH", data: deployment_config, + signal, } ); @@ -133,13 +149,20 @@ async function updateDeployment( * * @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment * @param {string} deployment_name - Required. The name of the deployment + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with true if the deployment was deleted */ -async function deleteDeployment(deployment_owner, deployment_name) { +async function deleteDeployment( + deployment_owner, + deployment_name, + { signal } = {} +) { const response = await this.request( `/deployments/${deployment_owner}/${deployment_name}`, { method: "DELETE", + signal, } ); @@ -149,11 +172,14 @@ async function deleteDeployment(deployment_owner, deployment_name) { /** * List all deployments * + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} - Resolves with a page of deployments */ -async function listDeployments() { +async function listDeployments({ signal } = {}) { const response = await this.request("/deployments", { method: "GET", + signal, }); return response.json(); diff --git a/lib/files.js b/lib/files.js index de49c58..c810139 100644 --- a/lib/files.js +++ b/lib/files.js @@ -3,9 +3,11 @@ * * @param {object} file - Required. The file object. * @param {object} metadata - Optional. User-provided metadata associated with the file. + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} - Resolves with the file data */ -async function createFile(file, metadata = {}) { +async function createFile(file, metadata = {}, { signal } = {}) { const form = new FormData(); let filename; @@ -36,6 +38,7 @@ async function createFile(file, metadata = {}) { headers: { "Content-Type": "multipart/form-data", }, + signal, }); return response.json(); @@ -44,11 +47,14 @@ async function createFile(file, metadata = {}) { /** * List all files * + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} - Resolves with the files data */ -async function listFiles() { +async function listFiles({ signal } = {}) { const response = await this.request("/files", { method: "GET", + signal, }); return response.json(); @@ -58,11 +64,14 @@ async function listFiles() { * Get a file * * @param {string} file_id - Required. The ID of the file. + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} - Resolves with the file data */ -async function getFile(file_id) { +async function getFile(file_id, { signal } = {}) { const response = await this.request(`/files/${file_id}`, { method: "GET", + signal, }); return response.json(); @@ -72,11 +81,14 @@ async function getFile(file_id) { * Delete a file * * @param {string} file_id - Required. The ID of the file. + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} - Resolves with true if the file was deleted */ -async function deleteFile(file_id) { +async function deleteFile(file_id, { signal } = {}) { const response = await this.request(`/files/${file_id}`, { method: "DELETE", + signal, }); return response.status === 204; diff --git a/lib/hardware.js b/lib/hardware.js index d717548..e981b1f 100644 --- a/lib/hardware.js +++ b/lib/hardware.js @@ -1,11 +1,14 @@ /** * List hardware * + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the array of hardware */ -async function listHardware() { +async function listHardware({ signal } = {}) { const response = await this.request("/hardware", { method: "GET", + signal, }); return response.json(); diff --git a/lib/models.js b/lib/models.js index 272d9ed..4a3fcdd 100644 --- a/lib/models.js +++ b/lib/models.js @@ -3,11 +3,14 @@ * * @param {string} model_owner - Required. The name of the user or organization that owns the model * @param {string} model_name - Required. The name of the model + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the model data */ -async function getModel(model_owner, model_name) { +async function getModel(model_owner, model_name, { signal } = {}) { const response = await this.request(`/models/${model_owner}/${model_name}`, { method: "GET", + signal, }); return response.json(); @@ -18,13 +21,16 @@ async function getModel(model_owner, model_name) { * * @param {string} model_owner - Required. The name of the user or organization that owns the model * @param {string} model_name - Required. The name of the model + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the list of model versions */ -async function listModelVersions(model_owner, model_name) { +async function listModelVersions(model_owner, model_name, { signal } = {}) { const response = await this.request( `/models/${model_owner}/${model_name}/versions`, { method: "GET", + signal, } ); @@ -37,13 +43,21 @@ async function listModelVersions(model_owner, model_name) { * @param {string} model_owner - Required. The name of the user or organization that owns the model * @param {string} model_name - Required. The name of the model * @param {string} version_id - Required. The model version + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the model version data */ -async function getModelVersion(model_owner, model_name, version_id) { +async function getModelVersion( + model_owner, + model_name, + version_id, + { signal } = {} +) { const response = await this.request( `/models/${model_owner}/${model_name}/versions/${version_id}`, { method: "GET", + signal, } ); @@ -53,11 +67,14 @@ async function getModelVersion(model_owner, model_name, version_id) { /** * List all public models * + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the model version data */ -async function listModels() { +async function listModels({ signal } = {}) { const response = await this.request("/models", { method: "GET", + signal, }); return response.json(); @@ -76,14 +93,17 @@ async function listModels() { * @param {string} options.paper_url - A URL for the model's paper. * @param {string} options.license_url - A URL for the model's license. * @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file. + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the model version data */ async function createModel(model_owner, model_name, options) { - const data = { owner: model_owner, name: model_name, ...options }; + const { signal, ...rest } = options; + const data = { owner: model_owner, name: model_name, ...rest }; const response = await this.request("/models", { method: "POST", data, + signal, }); return response.json(); @@ -93,15 +113,18 @@ async function createModel(model_owner, model_name, options) { * Search for public models * * @param {string} query - The search query + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with a page of models matching the search query */ -async function search(query) { +async function search(query, { signal } = {}) { const response = await this.request("/models", { method: "QUERY", headers: { "Content-Type": "text/plain", }, data: query, + signal, }); return response.json(); diff --git a/lib/predictions.js b/lib/predictions.js index f8e1c5a..708d04b 100644 --- a/lib/predictions.js +++ b/lib/predictions.js @@ -10,10 +10,11 @@ const { transformFileInputs } = require("./util"); * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) * @param {boolean|integer} [options.wait] - Whether to wait until the prediction is completed before returning. If an integer is provided, it will wait for that many seconds. Defaults to false + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the created prediction */ async function createPrediction(options) { - const { model, version, input, wait, ...data } = options; + const { model, version, input, wait, signal, ...data } = options; if (data.webhook) { try { @@ -48,6 +49,7 @@ async function createPrediction(options) { ), version, }, + signal, }); } else if (model) { response = await this.request(`/models/${model}/predictions`, { @@ -61,6 +63,7 @@ async function createPrediction(options) { this.fileEncodingStrategy ), }, + signal, }); } else { throw new Error("Either model or version must be specified"); @@ -73,11 +76,14 @@ async function createPrediction(options) { * Fetch a prediction by ID * * @param {number} prediction_id - Required. The prediction ID + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the prediction data */ -async function getPrediction(prediction_id) { +async function getPrediction(prediction_id, { signal } = {}) { const response = await this.request(`/predictions/${prediction_id}`, { method: "GET", + signal, }); return response.json(); @@ -87,11 +93,14 @@ async function getPrediction(prediction_id) { * Cancel a prediction by ID * * @param {string} prediction_id - Required. The training ID + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the data for the training */ -async function cancelPrediction(prediction_id) { +async function cancelPrediction(prediction_id, { signal } = {}) { const response = await this.request(`/predictions/${prediction_id}/cancel`, { method: "POST", + signal, }); return response.json(); @@ -100,11 +109,14 @@ async function cancelPrediction(prediction_id) { /** * List all predictions * + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} - Resolves with a page of predictions */ -async function listPredictions() { +async function listPredictions({ signal } = {}) { const response = await this.request("/predictions", { method: "GET", + signal, }); return response.json(); diff --git a/lib/trainings.js b/lib/trainings.js index 6b13dca..49640b9 100644 --- a/lib/trainings.js +++ b/lib/trainings.js @@ -9,10 +9,11 @@ * @param {object} options.input - Required. An object with the model inputs * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the data for the created training */ async function createTraining(model_owner, model_name, version_id, options) { - const { ...data } = options; + const { signal, ...data } = options; if (data.webhook) { try { @@ -28,6 +29,7 @@ async function createTraining(model_owner, model_name, version_id, options) { { method: "POST", data, + signal, } ); @@ -38,11 +40,14 @@ async function createTraining(model_owner, model_name, version_id, options) { * Fetch a training by ID * * @param {string} training_id - Required. The training ID + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the data for the training */ -async function getTraining(training_id) { +async function getTraining(training_id, { signal } = {}) { const response = await this.request(`/trainings/${training_id}`, { method: "GET", + signal, }); return response.json(); @@ -52,11 +57,14 @@ async function getTraining(training_id) { * Cancel a training by ID * * @param {string} training_id - Required. The training ID + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the data for the training */ -async function cancelTraining(training_id) { +async function cancelTraining(training_id, { signal } = {}) { const response = await this.request(`/trainings/${training_id}/cancel`, { method: "POST", + signal, }); return response.json(); @@ -65,11 +73,14 @@ async function cancelTraining(training_id) { /** * List all trainings * + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} - Resolves with a page of trainings */ -async function listTrainings() { +async function listTrainings({ signal } = {}) { const response = await this.request("/trainings", { method: "GET", + signal, }); return response.json(); diff --git a/lib/util.js b/lib/util.js index 2fa4919..0966577 100644 --- a/lib/util.js +++ b/lib/util.js @@ -247,7 +247,7 @@ async function withAutomaticRetries(request, options = {}) { * @param {Replicate} client - The client used to upload the file * @param {object} inputs - The inputs to transform * @param {"default" | "upload" | "data-uri"} strategy - Whether to upload files to Replicate, encode as dataURIs or try both. - * @returns {object} - The transformed inputs + * @returns {Promise} - The transformed inputs * @throws {ApiError} If the request to upload the file fails */ async function transformFileInputs(client, inputs, strategy) { @@ -280,7 +280,7 @@ async function transformFileInputs(client, inputs, strategy) { * * @param {Replicate} client - The client used to upload the file * @param {object} inputs - The inputs to transform - * @returns {object} - The transformed inputs + * @returns {Promise} - The transformed inputs * @throws {ApiError} If the request to upload the file fails */ async function transformFileInputsToReplicateFileURLs(client, inputs) { @@ -301,8 +301,8 @@ const MAX_DATA_URI_SIZE = 10_000_000; * base64-encoded data URI. * * @param {object} inputs - The inputs to transform - * @returns {object} - The transformed inputs - * @throws {Error} If the size of inputs exceeds a given threshould set by MAX_DATA_URI_SIZE + * @returns {Promise} - The transformed inputs + * @throws {Error} If the size of inputs exceeds a given threshold set by MAX_DATA_URI_SIZE */ async function transformFileInputsToBase64EncodedDataURIs(inputs) { let totalBytes = 0; @@ -311,10 +311,10 @@ async function transformFileInputsToBase64EncodedDataURIs(inputs) { let mime; if (value instanceof Blob) { - // Currently we use a NodeJS only API for base64 encoding, as + // Currently, we use a NodeJS only API for base64 encoding, as // we move to support the browser we could support either using // btoa (which does string encoding), the FileReader API or - // a JavaScript implenentation like base64-js. + // a JavaScript implementation like base64-js. // See: https://developer.mozilla.org/en-US/docs/Glossary/Base64 // See: https://github.com/beatgammit/base64-js buffer = await value.arrayBuffer(); diff --git a/lib/webhooks.js b/lib/webhooks.js index f1324ec..8da6fdf 100644 --- a/lib/webhooks.js +++ b/lib/webhooks.js @@ -1,11 +1,14 @@ /** * Get the default webhook signing secret * + * @param {object} [options] + * @param {AbortSignal} [options.signal] - An optional AbortSignal * @returns {Promise} Resolves with the signing secret for the default webhook */ -async function getDefaultWebhookSecret() { +async function getDefaultWebhookSecret({ signal } = {}) { const response = await this.request("/webhooks/default/secret", { method: "GET", + signal, }); return response.json();