Skip to content

Commit 4994995

Browse files
committedJan 14, 2024
Generate TypeScript definitions from source
1 parent 09a117a commit 4994995

14 files changed

+205
-335
lines changed
 

Diff for: ‎index.d.ts

-296
This file was deleted.

Diff for: ‎index.js

+29-5
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1-
const Replicate = require("./lib/replicate");
1+
const ReplicateClass = require("./lib/replicate");
22
const ApiError = require("./lib/error");
3+
require("./lib/types");
34

45
/**
56
* Placeholder class used to warn of deprecated constructor.
67
* @deprecated use exported Replicate class instead
78
*/
8-
class DeprecatedReplicate extends Replicate {
9+
class DeprecatedReplicate extends ReplicateClass {
910
/** @deprecated Use `import { Replicate } from "replicate";` instead */
1011
// biome-ignore lint/complexity/noUselessConstructor: exists for the tsdoc comment
1112
constructor(...args) {
1213
super(...args);
1314
}
1415
}
1516

16-
const named = { ApiError, Replicate };
17-
const singleton = new Replicate();
17+
const named = { ApiError, Replicate: ReplicateClass };
18+
const singleton = new ReplicateClass();
1819

1920
/**
2021
* Default instance of the Replicate class that gets the access token
@@ -48,7 +49,7 @@ const singleton = new Replicate();
4849
* const client = new Replicate({...});
4950
* ```
5051
*
51-
* @type { Replicate & typeof DeprecatedReplicate & {ApiError: ApiError, Replicate: Replicate} }
52+
* @type { Replicate & typeof DeprecatedReplicate & {Replicate: typeof ReplicateClass} }
5253
*/
5354
const replicate = new Proxy(DeprecatedReplicate, {
5455
get(target, prop, receiver) {
@@ -70,3 +71,26 @@ const replicate = new Proxy(DeprecatedReplicate, {
7071
});
7172

7273
module.exports = replicate;
74+
75+
// - Type Definitions
76+
77+
/**
78+
* @typedef {import("./lib/replicate")} Replicate
79+
* @typedef {import("./lib/error")} ApiError
80+
* @typedef {typeof import("./lib/types").Collection} Collection
81+
* @typedef {typeof import("./lib/types").ModelVersion} ModelVersion
82+
* @typedef {typeof import("./lib/types").Hardware} Hardware
83+
* @typedef {typeof import("./lib/types").Model} Model
84+
* @typedef {typeof import("./lib/types").Prediction} Prediction
85+
* @typedef {typeof import("./lib/types").Training} Training
86+
* @typedef {typeof import("./lib/types").ServerSentEvent} ServerSentEvent
87+
* @typedef {typeof import("./lib/types").Status} Status
88+
* @typedef {typeof import("./lib/types").Visibility} Visibility
89+
* @typedef {typeof import("./lib/types").WebhookEventType} WebhookEventType
90+
*/
91+
92+
/**
93+
* @template T
94+
* @typedef {typeof import("./lib/types").Page} Page
95+
*/
96+

Diff for: ‎index.test.ts

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { expect, jest, test } from "@jest/globals";
2-
import replicate, { ApiError, Model, Prediction, Replicate } from "replicate";
2+
import replicate, { ApiError, Model, Prediction, Replicate } from "./";
33
import nock from "nock";
44
import fetch from "cross-fetch";
55
import assert from "node:assert";
@@ -216,6 +216,7 @@ function testInstance(createClient: (opts?: object) => Replicate) {
216216
});
217217
const client = createClient();
218218
const prediction = await client.predictions.create({
219+
model: "foo",
219220
version:
220221
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
221222
input: {
@@ -237,6 +238,7 @@ function testInstance(createClient: (opts?: object) => Replicate) {
237238

238239
const client = createClient();
239240
await client.predictions.create({
241+
model: "foo",
240242
version:
241243
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
242244
input: {
@@ -250,6 +252,7 @@ function testInstance(createClient: (opts?: object) => Replicate) {
250252
await expect(async () => {
251253
const client = createClient();
252254
await client.predictions.create({
255+
model: "foo",
253256
version:
254257
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
255258
input: {
@@ -275,6 +278,7 @@ function testInstance(createClient: (opts?: object) => Replicate) {
275278

276279
const client = createClient();
277280
await client.predictions.create({
281+
model: "foo",
278282
version:
279283
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
280284
input: {
@@ -303,6 +307,7 @@ function testInstance(createClient: (opts?: object) => Replicate) {
303307
});
304308
const client = createClient();
305309
const prediction = await client.predictions.create({
310+
model: "foo",
306311
version:
307312
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
308313
input: {
@@ -324,6 +329,7 @@ function testInstance(createClient: (opts?: object) => Replicate) {
324329
const client = createClient();
325330
await expect(
326331
client.predictions.create({
332+
model: "foo",
327333
version:
328334
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
329335
input: {
@@ -885,7 +891,6 @@ function testInstance(createClient: (opts?: object) => Replicate) {
885891
});
886892

887893
test("Calls the correct API routes for a model", async () => {
888-
const firstPollingRequest = true;
889894

890895
nock(BASE_URL)
891896
.post("/models/replicate/hello-world/predictions")
@@ -968,12 +973,10 @@ function testInstance(createClient: (opts?: object) => Replicate) {
968973
const client = createClient();
969974
const options = { input: { text: "Hello, world!" } };
970975

971-
// @ts-expect-error
972976
await expect(client.run("owner:abc123", options)).rejects.toThrow();
973977

974978
await expect(client.run("/model:abc123", options)).rejects.toThrow();
975979

976-
// @ts-expect-error
977980
await expect(client.run(":abc123", options)).rejects.toThrow();
978981
});
979982

Diff for: ‎integration/typescript/types.test.ts

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import { ApiError, Hardware, Model, ModelVersion, Page, Prediction, Status, Training, Visibility, WebhookEventType } from "replicate";
2+
3+
// NOTE: We export the constants to avoid unused varaible issues.
4+
5+
export const collection: Collection = { name: "", slug: "", description: "", models: [] };
6+
export const status: Status = "starting";
7+
export const visibility: Visibility = "public";
8+
export const webhookType: WebhookEventType = "start";
9+
export const err: ApiError = Object.assign(new Error(), {request: new Request("file://"), response: new Response()});
10+
export const hardware: Hardware = { sku: "", name: "" };
11+
export const model: Model = {
12+
url: "",
13+
owner: "",
14+
name: "",
15+
description: "",
16+
visibility: "public",
17+
github_url: "",
18+
paper_url: "",
19+
license_url: "",
20+
run_count: 10,
21+
cover_image_url: "",
22+
default_example: undefined,
23+
latest_version: undefined,
24+
};
25+
export const version: ModelVersion = {
26+
id: "",
27+
created_at: "",
28+
cog_version: "",
29+
openapi_schema: "",
30+
};
31+
export const prediction: Prediction = {
32+
id: "",
33+
status: "starting",
34+
model: "",
35+
version: "",
36+
input: {},
37+
output: {},
38+
source: "api",
39+
error: undefined,
40+
logs: "",
41+
metrics: {
42+
predict_time: 100,
43+
},
44+
webhook: "",
45+
webhook_events_filter: [],
46+
created_at: "",
47+
started_at: "",
48+
completed_at: "",
49+
urls: {
50+
get: "",
51+
cancel: "",
52+
stream: "",
53+
},
54+
};
55+
export const training: Training = prediction;
56+
57+
export const page: Page<ModelVersion> = {
58+
previous: "",
59+
next: "",
60+
results: [version],
61+
};

Diff for: ‎lib/collections.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* Fetch a model collection
33
*
44
* @param {string} collection_slug - Required. The slug of the collection. See http://replicate.com/collections
5-
* @returns {Promise<object>} - Resolves with the collection data
5+
* @returns {Promise<Collection>} - Resolves with the collection data
66
*/
77
async function getCollection(collection_slug) {
88
const response = await this.request(`/collections/${collection_slug}`, {
@@ -15,7 +15,7 @@ async function getCollection(collection_slug) {
1515
/**
1616
* Fetch a list of model collections
1717
*
18-
* @returns {Promise<object>} - Resolves with the collections data
18+
* @returns {Promise<Page<Collection>>} - Resolves with the collections data
1919
*/
2020
async function listCollections() {
2121
const response = await this.request("/collections", {

Diff for: ‎lib/deployments.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
* @param {object} options.input - Required. An object with the model inputs
88
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
99
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
10-
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
11-
* @returns {Promise<object>} Resolves with the created prediction data
10+
* @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
11+
* @returns {Promise<Prediction>} Resolves with the created prediction data
1212
*/
1313
async function createPrediction(deployment_owner, deployment_name, options) {
1414
const { stream, ...data } = options;

Diff for: ‎lib/hardware.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/**
22
* List hardware
33
*
4-
* @returns {Promise<object[]>} Resolves with the array of hardware
4+
* @returns {Promise<Hardware[]>} Resolves with the array of hardware
55
*/
66
async function listHardware() {
77
const response = await this.request("/hardware", {

Diff for: ‎lib/models.js

+9-9
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
*
44
* @param {string} model_owner - Required. The name of the user or organization that owns the model
55
* @param {string} model_name - Required. The name of the model
6-
* @returns {Promise<object>} Resolves with the model data
6+
* @returns {Promise<Model>} Resolves with the model data
77
*/
88
async function getModel(model_owner, model_name) {
99
const response = await this.request(`/models/${model_owner}/${model_name}`, {
@@ -18,7 +18,7 @@ async function getModel(model_owner, model_name) {
1818
*
1919
* @param {string} model_owner - Required. The name of the user or organization that owns the model
2020
* @param {string} model_name - Required. The name of the model
21-
* @returns {Promise<object>} Resolves with the list of model versions
21+
* @returns {Promise<Page<ModelVersion>>} Resolves with the list of model versions
2222
*/
2323
async function listModelVersions(model_owner, model_name) {
2424
const response = await this.request(
@@ -37,7 +37,7 @@ async function listModelVersions(model_owner, model_name) {
3737
* @param {string} model_owner - Required. The name of the user or organization that owns the model
3838
* @param {string} model_name - Required. The name of the model
3939
* @param {string} version_id - Required. The model version
40-
* @returns {Promise<object>} Resolves with the model version data
40+
* @returns {Promise<ModelVersion>} Resolves with the model version data
4141
*/
4242
async function getModelVersion(model_owner, model_name, version_id) {
4343
const response = await this.request(
@@ -53,7 +53,7 @@ async function getModelVersion(model_owner, model_name, version_id) {
5353
/**
5454
* List all public models
5555
*
56-
* @returns {Promise<object>} Resolves with the model version data
56+
* @returns {Promise<Page<Model>>} Resolves with the model version data
5757
*/
5858
async function listModels() {
5959
const response = await this.request("/models", {
@@ -72,11 +72,11 @@ async function listModels() {
7272
* @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
7373
* @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`.
7474
* @param {string} options.description - A description of the model.
75-
* @param {string} options.github_url - A URL for the model's source code on GitHub.
76-
* @param {string} options.paper_url - A URL for the model's paper.
77-
* @param {string} options.license_url - A URL for the model's license.
78-
* @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file.
79-
* @returns {Promise<object>} Resolves with the model version data
75+
* @param {string=} options.github_url - A URL for the model's source code on GitHub.
76+
* @param {string=} options.paper_url - A URL for the model's paper.
77+
* @param {string=} options.license_url - A URL for the model's license.
78+
* @param {string=} options.cover_image_url - A URL for the model's cover image. This should be an image file.
79+
* @returns {Promise<Model>} Resolves with the model version data
8080
*/
8181
async function createModel(model_owner, model_name, options) {
8282
const data = { owner: model_owner, name: model_name, ...options };

Diff for: ‎lib/predictions.js

+7-7
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
* Create a new prediction
33
*
44
* @param {object} options
5-
* @param {string} options.model - The model.
6-
* @param {string} options.version - The model version.
5+
* @param {string=} options.model - The model (for official models)
6+
* @param {string=} options.version - The model version.
77
* @param {object} options.input - Required. An object with the model inputs
88
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
99
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
1010
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
11-
* @returns {Promise<object>} Resolves with the created prediction
11+
* @returns {Promise<Prediction>} Resolves with the created prediction
1212
*/
1313
async function createPrediction(options) {
1414
const { model, version, stream, ...data } = options;
@@ -43,8 +43,8 @@ async function createPrediction(options) {
4343
/**
4444
* Fetch a prediction by ID
4545
*
46-
* @param {number} prediction_id - Required. The prediction ID
47-
* @returns {Promise<object>} Resolves with the prediction data
46+
* @param {string} prediction_id - Required. The prediction ID
47+
* @returns {Promise<Prediction>} Resolves with the prediction data
4848
*/
4949
async function getPrediction(prediction_id) {
5050
const response = await this.request(`/predictions/${prediction_id}`, {
@@ -58,7 +58,7 @@ async function getPrediction(prediction_id) {
5858
* Cancel a prediction by ID
5959
*
6060
* @param {string} prediction_id - Required. The training ID
61-
* @returns {Promise<object>} Resolves with the data for the training
61+
* @returns {Promise<Prediction>} Resolves with the data for the training
6262
*/
6363
async function cancelPrediction(prediction_id) {
6464
const response = await this.request(`/predictions/${prediction_id}/cancel`, {
@@ -71,7 +71,7 @@ async function cancelPrediction(prediction_id) {
7171
/**
7272
* List all predictions
7373
*
74-
* @returns {Promise<object>} - Resolves with a page of predictions
74+
* @returns {Promise<Page<Prediction>>} - Resolves with a page of predictions
7575
*/
7676
async function listPredictions() {
7777
const response = await this.request("/predictions", {

Diff for: ‎lib/replicate.js

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ module.exports = class Replicate {
4949
* const input = {text: 'Hello, world!'}
5050
* const output = await replicate.run(model, { input });
5151
*
52-
* @param {object} options - Configuration options for the client
53-
* @param {string} options.auth - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable.
54-
* @param {string} options.userAgent - Identifier of your app
52+
* @param {Object={}} options - Configuration options for the client
53+
* @param {string} [options.auth] - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable.
54+
* @param {string} [options.userAgent] - Identifier of your app
5555
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
5656
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
5757
*/

Diff for: ‎lib/trainings.js

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
* @param {object} options.input - Required. An object with the model inputs
1010
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates
1111
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
12-
* @returns {Promise<object>} Resolves with the data for the created training
12+
* @returns {Promise<Training>} Resolves with the data for the created training
1313
*/
1414
async function createTraining(model_owner, model_name, version_id, options) {
1515
const { ...data } = options;
@@ -38,7 +38,7 @@ async function createTraining(model_owner, model_name, version_id, options) {
3838
* Fetch a training by ID
3939
*
4040
* @param {string} training_id - Required. The training ID
41-
* @returns {Promise<object>} Resolves with the data for the training
41+
* @returns {Promise<Training>} Resolves with the data for the training
4242
*/
4343
async function getTraining(training_id) {
4444
const response = await this.request(`/trainings/${training_id}`, {
@@ -52,7 +52,7 @@ async function getTraining(training_id) {
5252
* Cancel a training by ID
5353
*
5454
* @param {string} training_id - Required. The training ID
55-
* @returns {Promise<object>} Resolves with the data for the training
55+
* @returns {Promise<Training>} Resolves with the data for the training
5656
*/
5757
async function cancelTraining(training_id) {
5858
const response = await this.request(`/trainings/${training_id}/cancel`, {
@@ -65,7 +65,7 @@ async function cancelTraining(training_id) {
6565
/**
6666
* List all trainings
6767
*
68-
* @returns {Promise<object>} - Resolves with a page of trainings
68+
* @returns {Promise<Page<Training>>} - Resolves with a page of trainings
6969
*/
7070
async function listTrainings() {
7171
const response = await this.request("/trainings", {

Diff for: ‎lib/types.js

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/**
2+
* @typedef {"starting" | "processing" | "succeeded" | "failed" | "canceled"} Status
3+
* @typedef {"public" | "private"} Visibility
4+
* @typedef {"start" | "output" | "logs" | "completed"} WebhookEventType
5+
*
6+
* @typedef {import('./lib/error')} ApiError
7+
*
8+
* @typedef {Object} Collection
9+
* @property {string} name
10+
* @property {string} slug
11+
* @property {string} description
12+
* @property {Model[]=} models
13+
*
14+
* @typedef {Object} Hardware
15+
* @property {string} sku
16+
* @property {string} name
17+
*
18+
* @typedef {Object} Model
19+
* @property {string} url
20+
* @property {string} owner
21+
* @property {string} name
22+
* @property {string=} description
23+
* @property {Visibility} visibility
24+
* @property {string=} github_url
25+
* @property {string=} paper_url
26+
* @property {string=} license_url
27+
* @property {number} run_count
28+
* @property {string=} cover_image_url
29+
* @property {Prediction=} default_example
30+
* @property {ModelVersion=} latest_version
31+
*
32+
* @typedef {Object} ModelVersion
33+
* @property {string} id
34+
* @property {string} created_at
35+
* @property {string} cog_version
36+
* @property {string} openapi_schema
37+
*
38+
* @typedef {Object} Prediction
39+
* @property {string} id
40+
* @property {Status} status
41+
* @property {string=} model
42+
* @property {string} version
43+
* @property {object} input
44+
* @property {unknown=} output
45+
* @property {"api" | "web"} source
46+
* @property {unknown=} error
47+
* @property {string=} logs
48+
* @property {{predict_time?: number}=} metrics
49+
* @property {string=} webhook
50+
* @property {WebhookEventType[]=} webhook_events_filter
51+
* @property {string} created_at
52+
* @property {string=} started_at
53+
* @property {string=} completed_at
54+
* @property {{get: string; cancel: string; stream?: string}} urls
55+
*
56+
* @typedef {Prediction} Training
57+
*
58+
* @property {Object} ServerSentEvent
59+
* @property {string} event
60+
* @property {string} data
61+
* @property {string=} id
62+
* @property {number=} retry
63+
*/
64+
65+
/**
66+
* @template T
67+
* @typedef {Object} Page
68+
* @property {string=} previous
69+
* @property {string=} next
70+
* @property {T[]} results
71+
*/

Diff for: ‎package.json

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"bugs": "https://github.com/replicate/replicate-javascript/issues",
88
"license": "Apache-2.0",
99
"main": "index.js",
10+
"types": "dist/types/index.d.ts",
1011
"engines": {
1112
"node": ">=18.0.0",
1213
"npm": ">=7.19.0",
@@ -17,7 +18,10 @@
1718
"check": "tsc",
1819
"format": "biome format . --write",
1920
"lint": "biome lint .",
20-
"test": "REPLICATE_API_TOKEN=test-token jest"
21+
"test": "REPLICATE_API_TOKEN=test-token jest",
22+
"types": "tsc --target ES2022 --declaration --emitDeclarationOnly --allowJs --types node --outDir ./dist/types index.js ./lib/types.js",
23+
"test:integration": "npm --prefix integration/commonjs test;npm --prefix integration/esm test;npm --prefix integration/typescript test",
24+
"test:all": "npm run test && npm run test:integration"
2125
},
2226
"optionalDependencies": {
2327
"readable-stream": ">=4.0.0"

Diff for: ‎tsconfig.json

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
"compilerOptions": {
33
"esModuleInterop": true,
44
"noEmit": true,
5-
"strict": true
5+
"strict": true,
6+
"allowJs": true,
67
},
78
"exclude": [
9+
"dist",
10+
"integration",
811
"**/node_modules"
912
]
1013
}

0 commit comments

Comments
 (0)
Please sign in to comment.