From f46a697b1373362d836bfd866656ab3dd24c6168 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Sun, 14 Jan 2024 22:15:09 +0000 Subject: [PATCH 1/5] Generate TypeScript definitions from source --- index.d.ts | 315 --------------------------- index.js | 71 +++++- index.test.ts | 3 +- integration/typescript/types.test.ts | 84 +++++++ jsconfig.json | 4 +- lib/collections.js | 10 +- lib/deployments.js | 9 +- lib/hardware.js | 3 +- lib/identifier.js | 12 +- lib/models.js | 29 ++- lib/predictions.js | 37 +++- lib/trainings.js | 16 +- lib/types.js | 71 ++++++ package.json | 7 +- tsconfig.json | 7 +- 15 files changed, 305 insertions(+), 373 deletions(-) delete mode 100644 index.d.ts create mode 100644 integration/typescript/types.test.ts create mode 100644 lib/types.js diff --git a/index.d.ts b/index.d.ts deleted file mode 100644 index d1602ad..0000000 --- a/index.d.ts +++ /dev/null @@ -1,315 +0,0 @@ -declare module "replicate" { - type Status = "starting" | "processing" | "succeeded" | "failed" | "canceled"; - type Visibility = "public" | "private"; - type WebhookEventType = "start" | "output" | "logs" | "completed"; - - export interface ApiError extends Error { - request: Request; - response: Response; - } - - export interface Account { - type: "user" | "organization"; - username: string; - name: string; - github_url?: string; - } - - export interface Collection { - name: string; - slug: string; - description: string; - models?: Model[]; - } - - export interface Deployment { - owner: string; - name: string; - current_release: { - number: number; - model: string; - version: string; - created_at: string; - created_by: Account; - configuration: { - hardware: string; - scaling: { - min_instances: number; - max_instances: number; - }; - }; - }; - } - - export interface Hardware { - sku: string; - name: string; - } - - export interface Model { - url: string; - owner: string; - name: string; - description?: string; - visibility: "public" | "private"; - github_url?: string; - paper_url?: string; - license_url?: string; - run_count: number; - cover_image_url?: string; - default_example?: Prediction; - latest_version?: ModelVersion; - } - - export interface ModelVersion { - id: string; - created_at: string; - cog_version: string; - openapi_schema: object; - } - - export interface Prediction { - id: string; - status: Status; - model: string; - version: string; - input: object; - output?: any; - source: "api" | "web"; - error?: any; - logs?: string; - metrics?: { - predict_time?: number; - }; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - created_at: string; - started_at?: string; - completed_at?: string; - urls: { - get: string; - cancel: string; - stream?: string; - }; - } - - export type Training = Prediction; - - export interface Page { - previous?: string; - next?: string; - results: T[]; - } - - export interface ServerSentEvent { - event: string; - data: string; - id?: string; - retry?: number; - } - - export interface WebhookSecret { - key: string; - } - - export default class Replicate { - constructor(options?: { - auth?: string; - userAgent?: string; - baseUrl?: string; - fetch?: ( - input: Request | string, - init?: RequestInit - ) => Promise; - }); - - auth: string; - userAgent?: string; - baseUrl?: string; - fetch: (input: Request | string, init?: RequestInit) => Promise; - - run( - identifier: `${string}/${string}` | `${string}/${string}:${string}`, - options: { - input: object; - wait?: { interval?: number }; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - signal?: AbortSignal; - }, - progress?: (prediction: Prediction) => void - ): Promise; - - stream( - identifier: `${string}/${string}` | `${string}/${string}:${string}`, - options: { - input: object; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - signal?: AbortSignal; - } - ): AsyncGenerator; - - request( - route: string | URL, - options: { - method?: string; - headers?: object | Headers; - params?: object; - data?: object; - } - ): Promise; - - paginate(endpoint: () => Promise>): AsyncGenerator<[T]>; - - wait( - prediction: Prediction, - options?: { - interval?: number; - }, - stop?: (prediction: Prediction) => Promise - ): Promise; - - accounts: { - current(): Promise; - }; - - collections: { - list(): Promise>; - get(collection_slug: string): Promise; - }; - - deployments: { - predictions: { - create( - deployment_owner: string, - deployment_name: string, - options: { - input: object; - stream?: boolean; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } - ): Promise; - }; - get( - deployment_owner: string, - deployment_name: string - ): Promise; - create(deployment_config: { - name: string; - model: string; - version: string; - hardware: string; - min_instances: number; - max_instances: number; - }): Promise; - update( - deployment_owner: string, - deployment_name: string, - deployment_config: { - version?: string; - hardware?: string; - min_instances?: number; - max_instances?: number; - } & ( - | { version: string } - | { hardware: string } - | { min_instances: number } - | { max_instances: number } - ) - ): Promise; - list(): Promise>; - }; - - hardware: { - list(): Promise; - }; - - models: { - get(model_owner: string, model_name: string): Promise; - list(): Promise>; - create( - model_owner: string, - model_name: string, - options: { - visibility: Visibility; - hardware: string; - description?: string; - github_url?: string; - paper_url?: string; - license_url?: string; - cover_image_url?: string; - } - ): Promise; - versions: { - list(model_owner: string, model_name: string): Promise; - get( - model_owner: string, - model_name: string, - version_id: string - ): Promise; - }; - }; - - predictions: { - create( - options: { - model?: string; - version?: string; - input: object; - stream?: boolean; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } & ({ version: string } | { model: string }) - ): Promise; - get(prediction_id: string): Promise; - cancel(prediction_id: string): Promise; - list(): Promise>; - }; - - trainings: { - create( - model_owner: string, - model_name: string, - version_id: string, - options: { - destination: `${string}/${string}`; - input: object; - webhook?: string; - webhook_events_filter?: WebhookEventType[]; - } - ): Promise; - get(training_id: string): Promise; - cancel(training_id: string): Promise; - list(): Promise>; - }; - - webhooks: { - default: { - secret: { - get(): Promise; - }; - }; - }; - } - - export function validateWebhook( - requestData: - | Request - | { - id?: string; - timestamp?: string; - body: string; - secret?: string; - signature?: string; - }, - secret: string - ): Promise; - - export function parseProgressFromLogs(logs: Prediction | string): { - percentage: number; - current: number; - total: number; - } | null; -} diff --git a/index.js b/index.js index 042af91..d6311bf 100644 --- a/index.js +++ b/index.js @@ -41,30 +41,54 @@ class Replicate { /** * Create a new Replicate API client instance. * - * @param {object} options - Configuration options for the client - * @param {string} options.auth - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable. - * @param {string} options.userAgent - Identifier of your app + * @example + * // Create a new Replicate API client instance + * const Replicate = require("replicate"); + * const replicate = new Replicate({ + * // get your token from https://replicate.com/account + * auth: process.env.REPLICATE_API_TOKEN, + * userAgent: "my-app/1.2.3" + * }); + * + * // Run a model and await the result: + * const model = 'owner/model:version-id' + * const input = {text: 'Hello, world!'} + * const output = await replicate.run(model, { input }); + * + * @param {Object} [options] - Configuration options for the client + * @param {string} [options.auth] - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable. + * @param {string} [options.userAgent] - Identifier of your app * @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1 * @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch` */ constructor(options = {}) { + /** @type {string} */ this.auth = options.auth || (typeof process !== "undefined" ? process.env.REPLICATE_API_TOKEN : null); + + /** @type {string} */ this.userAgent = options.userAgent || `replicate-javascript/${packageJSON.version}`; + + /** @type {string} */ this.baseUrl = options.baseUrl || "https://api.replicate.com/v1"; + + /** @type {fetch} */ this.fetch = options.fetch || globalThis.fetch; + /** @type {accounts} */ this.accounts = { current: accounts.current.bind(this), }; + /** @type {collections} */ this.collections = { list: collections.list.bind(this), get: collections.get.bind(this), }; + /** @type {deployments} */ this.deployments = { get: deployments.get.bind(this), create: deployments.create.bind(this), @@ -75,10 +99,12 @@ class Replicate { }, }; + /** @type {hardware} */ this.hardware = { list: hardware.list.bind(this), }; + /** @type {models} */ this.models = { get: models.get.bind(this), list: models.list.bind(this), @@ -89,6 +115,7 @@ class Replicate { }, }; + /** @type {predictions} */ this.predictions = { create: predictions.create.bind(this), get: predictions.get.bind(this), @@ -96,6 +123,7 @@ class Replicate { list: predictions.list.bind(this), }; + /** @type {trainings} */ this.trainings = { create: trainings.create.bind(this), get: trainings.get.bind(this), @@ -115,18 +143,18 @@ class Replicate { /** * Run a model and wait for its output. * - * @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version" + * @param {`${string}/${string}` | `${string}/${string}:${string}`} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version" * @param {object} options * @param {object} options.input - Required. An object with the model inputs * @param {object} [options.wait] - Options for waiting for the prediction to finish * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500 * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction * @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed. * @throws {Error} If the reference is invalid * @throws {Error} If the prediction failed - * @returns {Promise} - Resolves with the output of running the model + * @returns {Promise} - Resolves with the output of running the model */ async run(ref, options, progress) { const { wait, ...data } = options; @@ -262,7 +290,7 @@ class Replicate { /** * Stream a model and wait for its output. * - * @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}" + * @param {string} ref - Required. The model version identifier in the format "{owner}/{name}:{version}" * @param {object} options * @param {object} options.input - Required. An object with the model inputs * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output @@ -315,8 +343,10 @@ class Replicate { * for await (const page of replicate.paginate(replicate.predictions.list) { * console.log(page); * } - * @param {Function} endpoint - Function that returns a promise for the next page of results - * @yields {object[]} Each page of results + * @template T + * @param {() => Promise>} endpoint - Function that returns a promise for the next page of results + * @yields {T[]} Each page of results + * @returns {AsyncGenerator} */ async *paginate(endpoint) { const response = await endpoint(); @@ -342,7 +372,7 @@ class Replicate { * @param {Function} [stop] - Async callback function that is called after each polling attempt. Receives the prediction object as an argument. Return false to cancel polling. * @throws {Error} If the prediction doesn't complete within the maximum number of attempts * @throws {Error} If the prediction failed - * @returns {Promise} Resolves with the completed prediction object + * @returns {Promise} Resolves with the completed prediction object */ async wait(prediction, options, stop) { const { id } = prediction; @@ -391,3 +421,24 @@ class Replicate { module.exports = Replicate; module.exports.validateWebhook = validateWebhook; module.exports.parseProgressFromLogs = parseProgressFromLogs; + +// - Type Definitions + +/** + * @typedef {import("./lib/error")} ApiError + * @typedef {import("./lib/types").Collection} Collection + * @typedef {import("./lib/types").ModelVersion} ModelVersion + * @typedef {import("./lib/types").Hardware} Hardware + * @typedef {import("./lib/types").Model} Model + * @typedef {import("./lib/types").Prediction} Prediction + * @typedef {import("./lib/types").Training} Training + * @typedef {import("./lib/types").ServerSentEvent} ServerSentEvent + * @typedef {import("./lib/types").Status} Status + * @typedef {import("./lib/types").Visibility} Visibility + * @typedef {import("./lib/types").WebhookEventType} WebhookEventType + */ + +/** + * @template T + * @typedef {import("./lib/types").Page} Page + */ diff --git a/index.test.ts b/index.test.ts index 55adee5..ad083cf 100644 --- a/index.test.ts +++ b/index.test.ts @@ -5,9 +5,8 @@ import Replicate, { Prediction, validateWebhook, parseProgressFromLogs, -} from "replicate"; +} from "./"; import nock from "nock"; -import { Readable } from "node:stream"; import { createReadableStream } from "./lib/stream"; let client: Replicate; diff --git a/integration/typescript/types.test.ts b/integration/typescript/types.test.ts new file mode 100644 index 0000000..d58484b --- /dev/null +++ b/integration/typescript/types.test.ts @@ -0,0 +1,84 @@ +import { ApiError, Collection, Hardware, Model, ModelVersion, Page, Prediction, Status, Training, Visibility, WebhookEventType } from "replicate"; + +export type Equals = + (() => T extends X ? 1 : 2) extends + (() => T extends Y ? 1 : 2) ? true : false; + + +type AssertFalse = A + +// @ts-expect-error +export type TestAssertion = AssertFalse> + +export type TestApiError = AssertFalse> +export type TestCollection = AssertFalse> +export type TestHardware = AssertFalse> +export type TestModel = AssertFalse> +export type TestModelVersion = AssertFalse> +export type TestPage = AssertFalse, any>> +export type TestPrediction = AssertFalse> +export type TestStatus = AssertFalse> +export type TestTraining = AssertFalse> +export type TestVisibility = AssertFalse> +export type TestWebhookEventType = AssertFalse> + + +// NOTE: We export the constants to avoid unused varaible issues. + +export const collection: Collection = { name: "", slug: "", description: "", models: [] }; +export const status: Status = "starting"; +export const visibility: Visibility = "public"; +export const webhookType: WebhookEventType = "start"; +export const err: ApiError = Object.assign(new Error(), {request: new Request("file://"), response: new Response()}); +export const hardware: Hardware = { sku: "", name: "" }; +export const model: Model = { + url: "", + owner: "", + name: "", + description: "", + visibility: "public", + github_url: "", + paper_url: "", + license_url: "", + run_count: 10, + cover_image_url: "", + default_example: undefined, + latest_version: undefined, +}; +export const version: ModelVersion = { + id: "", + created_at: "", + cog_version: "", + openapi_schema: "", +}; +export const prediction: Prediction = { + id: "", + status: "starting", + model: "", + version: "", + input: {}, + output: {}, + source: "api", + error: undefined, + logs: "", + metrics: { + predict_time: 100, + }, + webhook: "", + webhook_events_filter: [], + created_at: "", + started_at: "", + completed_at: "", + urls: { + get: "", + cancel: "", + stream: "", + }, +}; +export const training: Training = prediction; + +export const page: Page = { + previous: "", + next: "", + results: [version], +}; diff --git a/jsconfig.json b/jsconfig.json index b83b3f3..3d6fa2f 100644 --- a/jsconfig.json +++ b/jsconfig.json @@ -6,9 +6,11 @@ "target": "ES2020", "resolveJsonModule": true, "strictNullChecks": true, - "strictFunctionTypes": true + "strictFunctionTypes": true, + "types": [], }, "exclude": [ + "dist", "node_modules", "**/node_modules/*" ] diff --git a/lib/collections.js b/lib/collections.js index 9332aaa..4175934 100644 --- a/lib/collections.js +++ b/lib/collections.js @@ -1,8 +1,14 @@ +/** @typedef {import("./types").Collection} Collection */ +/** + * @template T + * @typedef {import("./types").Page} Page + */ + /** * Fetch a model collection * * @param {string} collection_slug - Required. The slug of the collection. See http://replicate.com/collections - * @returns {Promise} - Resolves with the collection data + * @returns {Promise} - Resolves with the collection data */ async function getCollection(collection_slug) { const response = await this.request(`/collections/${collection_slug}`, { @@ -15,7 +21,7 @@ async function getCollection(collection_slug) { /** * Fetch a list of model collections * - * @returns {Promise} - Resolves with the collections data + * @returns {Promise>} - Resolves with the collections data */ async function listCollections() { const response = await this.request("/collections", { diff --git a/lib/deployments.js b/lib/deployments.js index 4f6f3c6..9c2be3e 100644 --- a/lib/deployments.js +++ b/lib/deployments.js @@ -1,3 +1,5 @@ +/** @typedef {import("./types").Prediction} Prediction */ + const { transformFileInputs } = require("./util"); /** @@ -6,18 +8,17 @@ const { transformFileInputs } = require("./util"); * @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment * @param {string} deployment_name - Required. The name of the deployment * @param {object} options - * @param {object} options.input - Required. An object with the model inputs + * @param {unknown} options.input - Required. An object with the model inputs * @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @returns {Promise} Resolves with the created prediction data + * @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @returns {Promise} Resolves with the created prediction data */ async function createPrediction(deployment_owner, deployment_name, options) { const { stream, input, ...data } = options; if (data.webhook) { try { - // eslint-disable-next-line no-new new URL(data.webhook); } catch (err) { throw new Error("Invalid webhook URL"); diff --git a/lib/hardware.js b/lib/hardware.js index d717548..755bd91 100644 --- a/lib/hardware.js +++ b/lib/hardware.js @@ -1,7 +1,8 @@ +/** @typedef {import("./types").Hardware} Hardware */ /** * List hardware * - * @returns {Promise} Resolves with the array of hardware + * @returns {Promise} Resolves with the array of hardware */ async function listHardware() { const response = await this.request("/hardware", { diff --git a/lib/identifier.js b/lib/identifier.js index 86e23ee..f9e9786 100644 --- a/lib/identifier.js +++ b/lib/identifier.js @@ -2,10 +2,10 @@ * A reference to a model version in the format `owner/name` or `owner/name:version`. */ class ModelVersionIdentifier { - /* - * @param {string} Required. The model owner. - * @param {string} Required. The model name. - * @param {string} The model version. + /** + * @param {string} owner Required. The model owner. + * @param {string} name Required. The model name. + * @param {string | null=} version The model version. */ constructor(owner, name, version = null) { this.owner = owner; @@ -13,10 +13,10 @@ class ModelVersionIdentifier { this.version = version; } - /* + /** * Parse a reference to a model version * - * @param {string} + * @param {string} ref * @returns {ModelVersionIdentifier} * @throws {Error} If the reference is invalid. */ diff --git a/lib/models.js b/lib/models.js index c6a02fc..e7cbcd8 100644 --- a/lib/models.js +++ b/lib/models.js @@ -1,9 +1,18 @@ +/** @typedef {import("./types").Model} Model */ +/** @typedef {import("./types").ModelVersion} ModelVersion */ +/** @typedef {import("./types").Prediction} Prediction */ +/** @typedef {import("./types").Visibility} Visibility */ +/** + * @template T + * @typedef {import("./types").Page} Page + */ + /** * Get information about a model * * @param {string} model_owner - Required. The name of the user or organization that owns the model * @param {string} model_name - Required. The name of the model - * @returns {Promise} Resolves with the model data + * @returns {Promise} Resolves with the model data */ async function getModel(model_owner, model_name) { const response = await this.request(`/models/${model_owner}/${model_name}`, { @@ -18,7 +27,7 @@ async function getModel(model_owner, model_name) { * * @param {string} model_owner - Required. The name of the user or organization that owns the model * @param {string} model_name - Required. The name of the model - * @returns {Promise} Resolves with the list of model versions + * @returns {Promise>} Resolves with the list of model versions */ async function listModelVersions(model_owner, model_name) { const response = await this.request( @@ -37,7 +46,7 @@ async function listModelVersions(model_owner, model_name) { * @param {string} model_owner - Required. The name of the user or organization that owns the model * @param {string} model_name - Required. The name of the model * @param {string} version_id - Required. The model version - * @returns {Promise} Resolves with the model version data + * @returns {Promise} Resolves with the model version data */ async function getModelVersion(model_owner, model_name, version_id) { const response = await this.request( @@ -53,7 +62,7 @@ async function getModelVersion(model_owner, model_name, version_id) { /** * List all public models * - * @returns {Promise} Resolves with the model version data + * @returns {Promise>} Resolves with the model version data */ async function listModels() { const response = await this.request("/models", { @@ -69,14 +78,14 @@ async function listModels() { * @param {string} model_owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. * @param {string} model_name - Required. The name of the model. This must be unique among all models owned by the user or organization. * @param {object} options - * @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. + * @param {Visibility} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. * @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`. * @param {string} options.description - A description of the model. - * @param {string} options.github_url - A URL for the model's source code on GitHub. - * @param {string} options.paper_url - A URL for the model's paper. - * @param {string} options.license_url - A URL for the model's license. - * @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file. - * @returns {Promise} Resolves with the model version data + * @param {string=} options.github_url - A URL for the model's source code on GitHub. + * @param {string=} options.paper_url - A URL for the model's paper. + * @param {string=} options.license_url - A URL for the model's license. + * @param {string=} options.cover_image_url - A URL for the model's cover image. This should be an image file. + * @returns {Promise} Resolves with the model version data */ async function createModel(model_owner, model_name, options) { const data = { owner: model_owner, name: model_name, ...options }; diff --git a/lib/predictions.js b/lib/predictions.js index 5b0370e..d354e77 100644 --- a/lib/predictions.js +++ b/lib/predictions.js @@ -1,16 +1,29 @@ +/** + * @template T + * @typedef {import("./types").Page} Page + * @typedef {import("./types").Prediction} Prediction + * @typedef {Object} BasePredictionOptions + * @property {unknown} input - Required. An object with the model inputs + * @property {string} [webhook] - An HTTPS URL for receiving a webhook when the prediction has new output + * @property {string[]} [webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @property {boolean} [stream] - Whether to stream the prediction output. Defaults to false + * + * @typedef {Object} ModelPredictionOptions + * @property {string} model The model name (for official models) + * @property {never=} version + * + * @typedef {Object} VersionPredictionOptions + * @property {string} version The model version + * @property {never=} model + */ + const { transformFileInputs } = require("./util"); /** * Create a new prediction * - * @param {object} options - * @param {string} options.model - The model. - * @param {string} options.version - The model version. - * @param {object} options.input - Required. An object with the model inputs - * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output - * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false - * @returns {Promise} Resolves with the created prediction + * @param {BasePredictionOptions & (ModelPredictionOptions | VersionPredictionOptions)} options + * @returns {Promise} Resolves with the created prediction */ async function createPrediction(options) { const { model, version, stream, input, ...data } = options; @@ -54,8 +67,8 @@ async function createPrediction(options) { /** * Fetch a prediction by ID * - * @param {number} prediction_id - Required. The prediction ID - * @returns {Promise} Resolves with the prediction data + * @param {string} prediction_id - Required. The prediction ID + * @returns {Promise} Resolves with the prediction data */ async function getPrediction(prediction_id) { const response = await this.request(`/predictions/${prediction_id}`, { @@ -69,7 +82,7 @@ async function getPrediction(prediction_id) { * Cancel a prediction by ID * * @param {string} prediction_id - Required. The training ID - * @returns {Promise} Resolves with the data for the training + * @returns {Promise} Resolves with the data for the training */ async function cancelPrediction(prediction_id) { const response = await this.request(`/predictions/${prediction_id}/cancel`, { @@ -82,7 +95,7 @@ async function cancelPrediction(prediction_id) { /** * List all predictions * - * @returns {Promise} - Resolves with a page of predictions + * @returns {Promise>} - Resolves with a page of predictions */ async function listPredictions() { const response = await this.request("/predictions", { diff --git a/lib/trainings.js b/lib/trainings.js index 6b13dca..e469b96 100644 --- a/lib/trainings.js +++ b/lib/trainings.js @@ -1,3 +1,9 @@ +/** + * @template T + * @typedef {import("./types").Page} Page + */ +/** @typedef {import("./types").Training} Training */ + /** * Create a new training * @@ -6,10 +12,10 @@ * @param {string} version_id - Required. The version ID * @param {object} options * @param {string} options.destination - Required. The destination for the trained version in the form "{username}/{model_name}" - * @param {object} options.input - Required. An object with the model inputs + * @param {unknown} options.input - Required. An object with the model inputs * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @returns {Promise} Resolves with the data for the created training + * @returns {Promise} Resolves with the data for the created training */ async function createTraining(model_owner, model_name, version_id, options) { const { ...data } = options; @@ -38,7 +44,7 @@ async function createTraining(model_owner, model_name, version_id, options) { * Fetch a training by ID * * @param {string} training_id - Required. The training ID - * @returns {Promise} Resolves with the data for the training + * @returns {Promise} Resolves with the data for the training */ async function getTraining(training_id) { const response = await this.request(`/trainings/${training_id}`, { @@ -52,7 +58,7 @@ async function getTraining(training_id) { * Cancel a training by ID * * @param {string} training_id - Required. The training ID - * @returns {Promise} Resolves with the data for the training + * @returns {Promise} Resolves with the data for the training */ async function cancelTraining(training_id) { const response = await this.request(`/trainings/${training_id}/cancel`, { @@ -65,7 +71,7 @@ async function cancelTraining(training_id) { /** * List all trainings * - * @returns {Promise} - Resolves with a page of trainings + * @returns {Promise>} - Resolves with a page of trainings */ async function listTrainings() { const response = await this.request("/trainings", { diff --git a/lib/types.js b/lib/types.js new file mode 100644 index 0000000..fd05845 --- /dev/null +++ b/lib/types.js @@ -0,0 +1,71 @@ +/** + * @typedef {"starting" | "processing" | "succeeded" | "failed" | "canceled"} Status + * @typedef {"public" | "private"} Visibility + * @typedef {"start" | "output" | "logs" | "completed"} WebhookEventType + * + * @typedef {Object} Collection + * @property {string} name + * @property {string} slug + * @property {string} description + * @property {Model[]=} models + * + * @typedef {Object} Hardware + * @property {string} sku + * @property {string} name + * + * @typedef {Object} Model + * @property {string} url + * @property {string} owner + * @property {string} name + * @property {string=} description + * @property {Visibility} visibility + * @property {string=} github_url + * @property {string=} paper_url + * @property {string=} license_url + * @property {number} run_count + * @property {string=} cover_image_url + * @property {Prediction=} default_example + * @property {ModelVersion=} latest_version + * + * @typedef {Object} ModelVersion + * @property {string} id + * @property {string} created_at + * @property {string} cog_version + * @property {string} openapi_schema + * + * @typedef {Object} Prediction + * @property {string} id + * @property {Status} status + * @property {string=} model + * @property {string} version + * @property {object} input + * @property {unknown=} output + * @property {"api" | "web"} source + * @property {unknown=} error + * @property {string=} logs + * @property {{predict_time?: number}=} metrics + * @property {string=} webhook + * @property {WebhookEventType[]=} webhook_events_filter + * @property {string} created_at + * @property {string=} started_at + * @property {string=} completed_at + * @property {{get: string; cancel: string; stream?: string}} urls + * + * @typedef {Prediction} Training + * + * @typedef {Object} ServerSentEvent + * @property {string} event + * @property {string} data + * @property {string=} id + * @property {number=} retry + */ + +/** + * @template T + * @typedef {Object} Page + * @property {string=} previous + * @property {string=} next + * @property {T[]} results + */ + +module.exports = {}; diff --git a/package.json b/package.json index 26f546e..06650ef 100644 --- a/package.json +++ b/package.json @@ -8,7 +8,7 @@ "license": "Apache-2.0", "main": "index.js", "type": "commonjs", - "types": "index.d.ts", + "types": "dist/types/index.d.ts", "files": [ "CONTRIBUTING.md", "LICENSE", @@ -26,12 +26,15 @@ "yarn": ">=1.7.0" }, "scripts": { + "build-types": "tsc --target ES2022 --declaration --emitDeclarationOnly --allowJs --types node --outDir ./dist/types index.js", "check": "tsc", "format": "biome format . --write", "lint-biome": "biome lint .", "lint-publint": "publint", "lint": "npm run lint-biome && npm run lint-publint", - "test": "jest" + "test": "jest", + "test-integration": "npm run build-types; for x in commonjs esm typescript; do npm --prefix integration/$x install --omit=dev && npm --prefix integration/$x test; done;", + "test-all": "npm run check; npm run test; npm run test-integration" }, "optionalDependencies": { "readable-stream": ">=4.0.0" diff --git a/tsconfig.json b/tsconfig.json index d77efdc..7c43e22 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,9 +1,10 @@ { "compilerOptions": { + "allowJs": true, "esModuleInterop": true, "noEmit": true, - "strict": true, - "allowJs": true + "strict": true }, - "exclude": ["**/node_modules", "integration"] + "types": ["node"], + "exclude": ["dist", "integration", "**/node_modules"] } From 331a654b47ccfcdd09b3930d121232ee01e17ef4 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 23 Jan 2024 12:27:12 +0000 Subject: [PATCH 2/5] Tidy up the scripts We now have basic and `:integration` flavors of the common commands `lint` and `test` as well as an `:all` flavor that will run everything. We also now ensure that the types are built before running integration tests as well as part of the packaging workflow via `prepack`. --- package.json | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/package.json b/package.json index 06650ef..73629af 100644 --- a/package.json +++ b/package.json @@ -13,6 +13,7 @@ "CONTRIBUTING.md", "LICENSE", "README.md", + "dist/**/*", "index.d.ts", "index.js", "lib/**/*.js", @@ -26,15 +27,17 @@ "yarn": ">=1.7.0" }, "scripts": { - "build-types": "tsc --target ES2022 --declaration --emitDeclarationOnly --allowJs --types node --outDir ./dist/types index.js", + "build": "npm run build:types", + "build:types": "tsc --target ES2022 --declaration --emitDeclarationOnly --allowJs --types node --outDir ./dist/types index.js", "check": "tsc", "format": "biome format . --write", - "lint-biome": "biome lint .", - "lint-publint": "publint", - "lint": "npm run lint-biome && npm run lint-publint", + "lint": "biome lint .", + "lint:integration": "npm run build; publint", + "lint:all": "npm run tsc; npm run lint; npm run lint:integration", + "prepack": "npm run build", "test": "jest", - "test-integration": "npm run build-types; for x in commonjs esm typescript; do npm --prefix integration/$x install --omit=dev && npm --prefix integration/$x test; done;", - "test-all": "npm run check; npm run test; npm run test-integration" + "test:integration": "npm run build; for x in commonjs esm typescript; do npm --prefix integration/$x install --omit=dev && npm --prefix integration/$x test; done;", + "test:all": "npm run check; npm run test; npm run test:integration" }, "optionalDependencies": { "readable-stream": ">=4.0.0" From 8cca8557d616e8e85b18bc61a7005bf17b5a95c4 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 26 Jan 2024 14:16:48 +0000 Subject: [PATCH 3/5] Run `npm install` before `npm pack` --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 74995fc..ed18be5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,6 +45,7 @@ jobs: - name: Build tarball id: pack run: | + npm clean-install echo "tarball-name=$(npm --loglevel error pack)" >> $GITHUB_OUTPUT - uses: actions/upload-artifact@v3 with: From fb22f14727211ab5e38b68d3914e2c1a74662212 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 15 Mar 2024 17:50:24 +0000 Subject: [PATCH 4/5] Add missing TypeScript definitions --- index.js | 3 ++- lib/accounts.js | 4 +++- lib/deployments.js | 4 +++- lib/predictions.js | 3 +++ lib/stream.js | 2 +- lib/types.js | 24 ++++++++++++++++++++++-- lib/util.js | 31 ++++++++----------------------- lib/webhooks.js | 2 +- package.json | 6 +----- 9 files changed, 44 insertions(+), 35 deletions(-) diff --git a/index.js b/index.js index d6311bf..aac0bcb 100644 --- a/index.js +++ b/index.js @@ -151,7 +151,7 @@ class Replicate { * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output * @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) * @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction - * @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed. + * @param {(p: Prediction) => void} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed. * @throws {Error} If the reference is invalid * @throws {Error} If the prediction failed * @returns {Promise} - Resolves with the output of running the model @@ -426,6 +426,7 @@ module.exports.parseProgressFromLogs = parseProgressFromLogs; /** * @typedef {import("./lib/error")} ApiError + * @typedef {import("./lib/types").Account} Account * @typedef {import("./lib/types").Collection} Collection * @typedef {import("./lib/types").ModelVersion} ModelVersion * @typedef {import("./lib/types").Hardware} Hardware diff --git a/lib/accounts.js b/lib/accounts.js index b3bbd9f..fda3a40 100644 --- a/lib/accounts.js +++ b/lib/accounts.js @@ -1,7 +1,9 @@ +/** @typedef {import("./types").Account} Account */ + /** * Get the current account * - * @returns {Promise} Resolves with the current account + * @returns {Promise} Resolves with the current account */ async function getCurrentAccount() { const response = await this.request("/account", { diff --git a/lib/deployments.js b/lib/deployments.js index 9c2be3e..dd0e78b 100644 --- a/lib/deployments.js +++ b/lib/deployments.js @@ -1,4 +1,6 @@ +/** @typedef {import("./types").Deployment} Deployment */ /** @typedef {import("./types").Prediction} Prediction */ +/** @typedef {import("./types").WebhookEventType} WebhookEventType */ const { transformFileInputs } = require("./util"); @@ -45,7 +47,7 @@ async function createPrediction(deployment_owner, deployment_name, options) { * * @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment * @param {string} deployment_name - Required. The name of the deployment - * @returns {Promise} Resolves with the deployment data + * @returns {Promise} Resolves with the deployment data */ async function getDeployment(deployment_owner, deployment_name) { const response = await this.request( diff --git a/lib/predictions.js b/lib/predictions.js index d354e77..a3ab440 100644 --- a/lib/predictions.js +++ b/lib/predictions.js @@ -1,6 +1,9 @@ /** * @template T * @typedef {import("./types").Page} Page + */ + +/** * @typedef {import("./types").Prediction} Prediction * @typedef {Object} BasePredictionOptions * @property {unknown} input - Required. An object with the model inputs diff --git a/lib/stream.js b/lib/stream.js index 2e0bbde..82fcda4 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -46,7 +46,7 @@ class ServerSentEvent { * * @param {object} config * @param {string} config.url The URL to connect to. - * @param {typeof fetch} [config.fetch] The URL to connect to. + * @param {(url: URL | RequestInfo, init?: RequestInit | undefined) => Promise} [config.fetch] The URL to connect to. * @param {object} [config.options] The EventSource options. * @returns {ReadableStream & AsyncIterable} */ diff --git a/lib/types.js b/lib/types.js index fd05845..9e93e3f 100644 --- a/lib/types.js +++ b/lib/types.js @@ -1,7 +1,13 @@ -/** +/** * @typedef {"starting" | "processing" | "succeeded" | "failed" | "canceled"} Status * @typedef {"public" | "private"} Visibility - * @typedef {"start" | "output" | "logs" | "completed"} WebhookEventType + * @typedef {"start" | "output" | "logs" | "completed"} WebhookEventType + * + * @typedef {Object} Account + * @property {"user" | "organization"} type + * @property {string} username + * @property {string} name + * @property {string=} github_url * * @typedef {Object} Collection * @property {string} name @@ -9,6 +15,20 @@ * @property {string} description * @property {Model[]=} models * + * @typedef {Object} Deployment + * @property {string} owner + * @property {string} name + * @property {object} current_release + * @property {number} current_release.number + * @property {string} current_release.model + * @property {string} current_release.version + * @property {string} current_release.created_at + * @property {Account} current_release.created_by + * @property {object} current_release.configuration + * @property {string} current_release.configuration.hardware + * @property {number} current_release.configuration.min_instances + * @property {number} current_release.configuration.max_instances + * * @typedef {Object} Hardware * @property {string} sku * @property {string} name diff --git a/lib/util.js b/lib/util.js index 22a14c8..4f0044f 100644 --- a/lib/util.js +++ b/lib/util.js @@ -1,31 +1,16 @@ const ApiError = require("./error"); -/** - * @see {@link validateWebhook} - * @overload - * @param {object} requestData - The request data - * @param {string} requestData.id - The webhook ID header from the incoming request. - * @param {string} requestData.timestamp - The webhook timestamp header from the incoming request. - * @param {string} requestData.body - The raw body of the incoming webhook request. - * @param {string} requestData.secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method. - * @param {string} requestData.signature - The webhook signature header from the incoming request, comprising one or more space-delimited signatures. - */ - -/** - * @see {@link validateWebhook} - * @overload - * @param {object} requestData - The request object - * @param {object} requestData.headers - The request headers - * @param {string} requestData.headers["webhook-id"] - The webhook ID header from the incoming request - * @param {string} requestData.headers["webhook-timestamp"] - The webhook timestamp header from the incoming request - * @param {string} requestData.headers["webhook-signature"] - The webhook signature header from the incoming request, comprising one or more space-delimited signatures - * @param {string} requestData.body - The raw body of the incoming webhook request - * @param {string} secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method - */ - /** * Validate a webhook signature * + * @typedef {Object} WebhookPayload + * @property {string} id - The webhook ID header from the incoming request. + * @property {string} timestamp - The webhook timestamp header from the incoming request. + * @property {string} body - The raw body of the incoming webhook request. + * @property {string} signature - The webhook signature header from the incoming request, comprising one or more space-delimited signatures. + * + * @param {Request | WebhookPayload} requestData + * @param {string} secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method. * @returns {Promise} - True if the signature is valid * @throws {Error} - If the request is missing required headers, body, or secret */ diff --git a/lib/webhooks.js b/lib/webhooks.js index f1324ec..e484df8 100644 --- a/lib/webhooks.js +++ b/lib/webhooks.js @@ -1,7 +1,7 @@ /** * Get the default webhook signing secret * - * @returns {Promise} Resolves with the signing secret for the default webhook + * @returns {Promise<{key: string}>} Resolves with the signing secret for the default webhook */ async function getDefaultWebhookSecret() { const response = await this.request("/webhooks/default/secret", { diff --git a/package.json b/package.json index 73629af..6008eaa 100644 --- a/package.json +++ b/package.json @@ -27,7 +27,7 @@ "yarn": ">=1.7.0" }, "scripts": { - "build": "npm run build:types", + "build": "npm run build:types && tsc --noEmit dist/types/**/*.d.ts", "build:types": "tsc --target ES2022 --declaration --emitDeclarationOnly --allowJs --types node --outDir ./dist/types index.js", "check": "tsc", "format": "biome format . --write", @@ -39,14 +39,10 @@ "test:integration": "npm run build; for x in commonjs esm typescript; do npm --prefix integration/$x install --omit=dev && npm --prefix integration/$x test; done;", "test:all": "npm run check; npm run test; npm run test:integration" }, - "optionalDependencies": { - "readable-stream": ">=4.0.0" - }, "devDependencies": { "@biomejs/biome": "^1.4.1", "@types/jest": "^29.5.3", "@typescript-eslint/eslint-plugin": "^5.56.0", - "cross-fetch": "^3.1.5", "jest": "^29.6.2", "nock": "^14.0.0-beta.4", "publint": "^0.2.7", From 6b6eb15e92c04a6dc208ac3b4a5653205ac0c92e Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 18 Mar 2024 22:28:24 -0700 Subject: [PATCH 5/5] Add missing webhooks type annotation --- index.js | 2 + index.test.ts | 28 ++--- integration/typescript/types.test.ts | 175 +++++++++++++++++---------- lib/collections.js | 2 +- lib/deployments.js | 24 ++-- lib/models.js | 2 +- lib/trainings.js | 2 +- 7 files changed, 141 insertions(+), 94 deletions(-) diff --git a/index.js b/index.js index aac0bcb..81c9234 100644 --- a/index.js +++ b/index.js @@ -131,6 +131,7 @@ class Replicate { list: trainings.list.bind(this), }; + /** @type {webhooks} */ this.webhooks = { default: { secret: { @@ -428,6 +429,7 @@ module.exports.parseProgressFromLogs = parseProgressFromLogs; * @typedef {import("./lib/error")} ApiError * @typedef {import("./lib/types").Account} Account * @typedef {import("./lib/types").Collection} Collection + * @typedef {import("./lib/types").Deployment} Deployment * @typedef {import("./lib/types").ModelVersion} ModelVersion * @typedef {import("./lib/types").Hardware} Hardware * @typedef {import("./lib/types").Model} Model diff --git a/index.test.ts b/index.test.ts index ad083cf..0e48e05 100644 --- a/index.test.ts +++ b/index.test.ts @@ -790,10 +790,8 @@ describe("Replicate client", () => { }, configuration: { hardware: "gpu-t4", - scaling: { - min_instances: 1, - max_instances: 5, - }, + min_instances: 1, + max_instances: 5, }, }, }); @@ -831,10 +829,8 @@ describe("Replicate client", () => { }, configuration: { hardware: "gpu-t4", - scaling: { - min_instances: 1, - max_instances: 5, - }, + min_instances: 1, + max_instances: 5, }, }, }); @@ -877,10 +873,8 @@ describe("Replicate client", () => { }, configuration: { hardware: "gpu-a40-large", - scaling: { - min_instances: 3, - max_instances: 10, - }, + min_instances: 3, + max_instances: 10, }, }, }); @@ -904,12 +898,8 @@ describe("Replicate client", () => { expect(deployment.current_release.configuration.hardware).toBe( "gpu-a40-large" ); - expect( - deployment.current_release.configuration.scaling?.min_instances - ).toBe(3); - expect( - deployment.current_release.configuration.scaling?.max_instances - ).toBe(10); + expect(deployment.current_release.configuration.min_instances).toBe(3); + expect(deployment.current_release.configuration.max_instances).toBe(10); }); // Add more tests for error handling, edge cases, etc. }); @@ -934,7 +924,7 @@ describe("Replicate client", () => { }); const deployments = await client.deployments.list(); - expect(deployments.results.length).toBe(1) + expect(deployments.results.length).toBe(1); }); // Add more tests for pagination, error handling, edge cases, etc. }); diff --git a/integration/typescript/types.test.ts b/integration/typescript/types.test.ts index d58484b..efbfa0b 100644 --- a/integration/typescript/types.test.ts +++ b/integration/typescript/types.test.ts @@ -1,84 +1,135 @@ -import { ApiError, Collection, Hardware, Model, ModelVersion, Page, Prediction, Status, Training, Visibility, WebhookEventType } from "replicate"; +import { + Account, + ApiError, + Collection, + Deployment, + Hardware, + Model, + ModelVersion, + Page, + Prediction, + Status, + Training, + Visibility, + WebhookEventType, +} from "replicate"; -export type Equals = - (() => T extends X ? 1 : 2) extends - (() => T extends Y ? 1 : 2) ? true : false; +export type Equals = (() => T extends X ? 1 : 2) extends < + T, +>() => T extends Y ? 1 : 2 + ? true + : false; - -type AssertFalse = A +type AssertFalse = A; // @ts-expect-error -export type TestAssertion = AssertFalse> - -export type TestApiError = AssertFalse> -export type TestCollection = AssertFalse> -export type TestHardware = AssertFalse> -export type TestModel = AssertFalse> -export type TestModelVersion = AssertFalse> -export type TestPage = AssertFalse, any>> -export type TestPrediction = AssertFalse> -export type TestStatus = AssertFalse> -export type TestTraining = AssertFalse> -export type TestVisibility = AssertFalse> -export type TestWebhookEventType = AssertFalse> +export type TestAssertion = AssertFalse>; +export type TestAccount = AssertFalse>; +export type TestApiError = AssertFalse>; +export type TestCollection = AssertFalse>; +export type TestDeployment = AssertFalse>; +export type TestHardware = AssertFalse>; +export type TestModel = AssertFalse>; +export type TestModelVersion = AssertFalse>; +export type TestPage = AssertFalse, any>>; +export type TestPrediction = AssertFalse>; +export type TestStatus = AssertFalse>; +export type TestTraining = AssertFalse>; +export type TestVisibility = AssertFalse>; +export type TestWebhookEventType = AssertFalse>; // NOTE: We export the constants to avoid unused varaible issues. -export const collection: Collection = { name: "", slug: "", description: "", models: [] }; +export const account: Account = { + type: "user", + name: "", + username: "", + github_url: "", +}; +export const collection: Collection = { + name: "", + slug: "", + description: "", + models: [], +}; +export const deployment: Deployment = { + owner: "", + name: "", + current_release: { + number: 1, + model: "", + version: "", + created_at: "", + created_by: { + type: "user", + username: "", + name: "", + github_url: "", + }, + configuration: { + hardware: "gpu-a100", + min_instances: 0, + max_instances: 5, + }, + }, +}; export const status: Status = "starting"; export const visibility: Visibility = "public"; export const webhookType: WebhookEventType = "start"; -export const err: ApiError = Object.assign(new Error(), {request: new Request("file://"), response: new Response()}); +export const err: ApiError = Object.assign(new Error(), { + request: new Request("file://"), + response: new Response(), +}); export const hardware: Hardware = { sku: "", name: "" }; export const model: Model = { - url: "", - owner: "", - name: "", - description: "", - visibility: "public", - github_url: "", - paper_url: "", - license_url: "", - run_count: 10, - cover_image_url: "", - default_example: undefined, - latest_version: undefined, + url: "", + owner: "", + name: "", + description: "", + visibility: "public", + github_url: "", + paper_url: "", + license_url: "", + run_count: 10, + cover_image_url: "", + default_example: undefined, + latest_version: undefined, }; export const version: ModelVersion = { - id: "", - created_at: "", - cog_version: "", - openapi_schema: "", + id: "", + created_at: "", + cog_version: "", + openapi_schema: "", }; export const prediction: Prediction = { - id: "", - status: "starting", - model: "", - version: "", - input: {}, - output: {}, - source: "api", - error: undefined, - logs: "", - metrics: { - predict_time: 100, - }, - webhook: "", - webhook_events_filter: [], - created_at: "", - started_at: "", - completed_at: "", - urls: { - get: "", - cancel: "", - stream: "", - }, + id: "", + status: "starting", + model: "", + version: "", + input: {}, + output: {}, + source: "api", + error: undefined, + logs: "", + metrics: { + predict_time: 100, + }, + webhook: "", + webhook_events_filter: [], + created_at: "", + started_at: "", + completed_at: "", + urls: { + get: "", + cancel: "", + stream: "", + }, }; export const training: Training = prediction; export const page: Page = { - previous: "", - next: "", - results: [version], + previous: "", + next: "", + results: [version], }; diff --git a/lib/collections.js b/lib/collections.js index 4175934..e1b6cda 100644 --- a/lib/collections.js +++ b/lib/collections.js @@ -1,5 +1,5 @@ /** @typedef {import("./types").Collection} Collection */ -/** +/** * @template T * @typedef {import("./types").Page} Page */ diff --git a/lib/deployments.js b/lib/deployments.js index dd0e78b..dced368 100644 --- a/lib/deployments.js +++ b/lib/deployments.js @@ -1,3 +1,7 @@ +/** + * @template T + * @typedef {import("./types").Page} Page + */ /** @typedef {import("./types").Deployment} Deployment */ /** @typedef {import("./types").Prediction} Prediction */ /** @typedef {import("./types").WebhookEventType} WebhookEventType */ @@ -74,12 +78,12 @@ async function getDeployment(deployment_owner, deployment_name) { * Create a deployment * * @param {DeploymentCreateRequest} config - Required. The deployment config. - * @returns {Promise} Resolves with the deployment data + * @returns {Promise} Resolves with the deployment data */ -async function createDeployment(deployment_config) { +async function createDeployment(config) { const response = await this.request("/deployments", { method: "POST", - data: deployment_config, + data: config, }); return response.json(); @@ -87,10 +91,10 @@ async function createDeployment(deployment_config) { /** * @typedef {Object} DeploymentUpdateRequest - Request body for `deployments.update` - * @property {string} version - the 64-character string ID of the model version that you want to deploy - * @property {string} hardware - the SKU for the hardware used to run the model, via `replicate.hardware.list()` - * @property {number} min_instances - the minimum number of instances for scaling - * @property {number} max_instances - the maximum number of instances for scaling + * @property {string=} version - the 64-character string ID of the model version that you want to deploy + * @property {string=} hardware - the SKU for the hardware used to run the model, via `replicate.hardware.list()` + * @property {number=} min_instances - the minimum number of instances for scaling + * @property {number=} max_instances - the maximum number of instances for scaling */ /** @@ -98,8 +102,8 @@ async function createDeployment(deployment_config) { * * @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment * @param {string} deployment_name - Required. The name of the deployment - * @param {DeploymentUpdateRequest} deployment_config - Required. The deployment changes. - * @returns {Promise} Resolves with the deployment data + * @param {DeploymentUpdateRequest | {version: string} | {hardware: string} | {min_instances: number} | {max_instance: number}} deployment_config - Required. The deployment changes. + * @returns {Promise} Resolves with the deployment data */ async function updateDeployment( deployment_owner, @@ -120,7 +124,7 @@ async function updateDeployment( /** * List all deployments * - * @returns {Promise} - Resolves with a page of deployments + * @returns {Promise>} - Resolves with a page of deployments */ async function listDeployments() { const response = await this.request("/deployments", { diff --git a/lib/models.js b/lib/models.js index e7cbcd8..9282e2d 100644 --- a/lib/models.js +++ b/lib/models.js @@ -2,7 +2,7 @@ /** @typedef {import("./types").ModelVersion} ModelVersion */ /** @typedef {import("./types").Prediction} Prediction */ /** @typedef {import("./types").Visibility} Visibility */ -/** +/** * @template T * @typedef {import("./types").Page} Page */ diff --git a/lib/trainings.js b/lib/trainings.js index e469b96..b0de908 100644 --- a/lib/trainings.js +++ b/lib/trainings.js @@ -1,4 +1,4 @@ -/** +/** * @template T * @typedef {import("./types").Page} Page */