Skip to content

Commit c716722

Browse files
authored
Add stream option to predictions.create operation (#99)
* Update request method to return response instead of JSON * Add stream parameter to predictions.create * Remove wait parameter from predictions.create * Rename wait parameter maxAttempts to max_attempts * Add documentation for ApiError class * Specify withCredentials options in EventSource constructor
1 parent 398cca5 commit c716722

File tree

9 files changed

+172
-70
lines changed

9 files changed

+172
-70
lines changed

README.md

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,17 +251,14 @@ const response = await replicate.predictions.create(options);
251251
| ------------------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------- |
252252
| `options.version` | string | **Required**. The model version |
253253
| `options.input` | object | **Required**. An object with the model's inputs |
254+
| `options.stream` | boolean | Requests a URL for streaming output |
254255
| `options.webhook` | string | An HTTPS URL for receiving a webhook when the prediction has new output |
255256
| `options.webhook_events_filter` | string[] | You can change which events trigger webhook requests by specifying webhook events (`start` \| `output` \| `logs` \| `completed`) |
256257

257258
```jsonc
258259
{
259260
"id": "ufawqhfynnddngldkgtslldrkq",
260261
"version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
261-
"urls": {
262-
"get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
263-
"cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel"
264-
},
265262
"status": "succeeded",
266263
"input": {
267264
"text": "Alice"
@@ -272,10 +269,56 @@ const response = await replicate.predictions.create(options);
272269
"metrics": {},
273270
"created_at": "2022-04-26T22:13:06.224088Z",
274271
"started_at": null,
275-
"completed_at": null
272+
"completed_at": null,
273+
"urls": {
274+
"get": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
275+
"cancel": "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel",
276+
"stream": "https://streaming.api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq" // Present only if `options.stream` is `true`
277+
}
276278
}
277279
```
278280

281+
#### Streaming
282+
283+
Specify the `stream` option when creating a prediction
284+
to request a URL to receive streaming output using
285+
[server-sent events (SSE)](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events).
286+
287+
If the requested model version supports streaming,
288+
then the returned prediction will have a `stream` entry in its `urls` property
289+
with a URL that you can use to construct an
290+
[`EventSource`](https://developer.mozilla.org/en-US/docs/Web/API/EventSource).
291+
292+
```js
293+
if (prediction && prediction.urls && prediction.urls.stream) {
294+
const source = new EventSource(prediction.urls.stream, { withCredentials: true });
295+
296+
source.addEventListener("output", (e) => {
297+
console.log("output", e.data);
298+
});
299+
300+
source.addEventListener("error", (e) => {
301+
console.error("error", JSON.parse(e.data));
302+
});
303+
304+
source.addEventListener("done", (e) => {
305+
source.close();
306+
console.log("done", JSON.parse(e.data));
307+
});
308+
}
309+
```
310+
311+
A prediction's event stream consists of the following event types:
312+
313+
| event | format | description |
314+
| -------- | ---------- | ---------------------------------------------- |
315+
| `output` | plain text | Emitted when the prediction returns new output |
316+
| `error` | JSON | Emitted when the prediction returns an error |
317+
| `done` | JSON | Emitted when the prediction finishes |
318+
319+
A `done` event is emitted when a prediction finishes successfully,
320+
is cancelled, or produces an error.
321+
279322
### `replicate.predictions.get`
280323

281324
```js

index.d.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ declare module 'replicate' {
5353
created_at: string;
5454
updated_at: string;
5555
completed_at?: string;
56+
urls: {
57+
get: string;
58+
cancel: string;
59+
stream?: string;
60+
};
5661
}
5762

5863
export type Training = Prediction;
@@ -80,18 +85,26 @@ declare module 'replicate' {
8085
identifier: `${string}/${string}:${string}`,
8186
options: {
8287
input: object;
83-
wait?: boolean | { interval?: number; maxAttempts?: number };
88+
wait?: { interval?: number; max_attempts?: number };
8489
webhook?: string;
8590
webhook_events_filter?: WebhookEventType[];
8691
}
8792
): Promise<object>;
88-
request(route: string, parameters: any): Promise<any>;
93+
94+
request(route: string | URL, options: {
95+
method?: string;
96+
headers?: object | Headers;
97+
params?: object;
98+
data?: object;
99+
}): Promise<Response>;
100+
89101
paginate<T>(endpoint: () => Promise<Page<T>>): AsyncGenerator<[ T ]>;
102+
90103
wait(
91104
prediction: Prediction,
92105
options: {
93106
interval?: number;
94-
maxAttempts?: number;
107+
max_attempts?: number;
95108
}
96109
): Promise<Prediction>;
97110

@@ -116,6 +129,7 @@ declare module 'replicate' {
116129
create(options: {
117130
version: string;
118131
input: object;
132+
stream?: boolean;
119133
webhook?: string;
120134
webhook_events_filter?: WebhookEventType[];
121135
}): Promise<Prediction>;

index.js

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,17 @@ class Replicate {
8080
* @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}"
8181
* @param {object} options
8282
* @param {object} options.input - Required. An object with the model inputs
83-
* @param {object} [options.wait] - Whether to wait for the prediction to finish. Defaults to false
83+
* @param {object} [options.wait] - Options for waiting for the prediction to finish
8484
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 250
85-
* @param {number} [options.wait.maxAttempts] - Maximum number of polling attempts. Defaults to no limit
85+
* @param {number} [options.wait.max_attempts] - Maximum number of polling attempts. Defaults to no limit
8686
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
8787
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
8888
* @throws {Error} If the prediction failed
8989
* @returns {Promise<object>} - Resolves with the output of running the model
9090
*/
9191
async run(identifier, options) {
92+
const { wait, ...data } = options;
93+
9294
// Define a pattern for owner and model names that allows
9395
// letters, digits, and certain special characters.
9496
// Example: "user123", "abc__123", "user.name"
@@ -108,12 +110,14 @@ class Replicate {
108110
}
109111

110112
const { version } = match.groups;
111-
const prediction = await this.predictions.create({
112-
wait: true,
113-
...options,
113+
114+
let prediction = await this.predictions.create({
115+
...data,
114116
version,
115117
});
116118

119+
prediction = await this.wait(prediction, wait || {});
120+
117121
if (prediction.status === 'failed') {
118122
throw new Error(`Prediction failed: ${prediction.error}`);
119123
}
@@ -125,43 +129,53 @@ class Replicate {
125129
* Make a request to the Replicate API.
126130
*
127131
* @param {string} route - REST API endpoint path
128-
* @param {object} parameters - Request parameters
129-
* @param {string} [parameters.method] - HTTP method. Defaults to GET
130-
* @param {object} [parameters.params] - Query parameters
131-
* @param {object} [parameters.data] - Body parameters
132-
* @returns {Promise<object>} - Resolves with the API response data
132+
* @param {object} options - Request parameters
133+
* @param {string} [options.method] - HTTP method. Defaults to GET
134+
* @param {object} [options.params] - Query parameters
135+
* @param {object|Headers} [options.headers] - HTTP headers
136+
* @param {object} [options.data] - Body parameters
137+
* @returns {Promise<Response>} - Resolves with the response object
133138
* @throws {ApiError} If the request failed
134139
*/
135-
async request(route, parameters) {
140+
async request(route, options) {
136141
const { auth, baseUrl, userAgent } = this;
137142

138-
const url = new URL(
139-
route.startsWith('/') ? route.slice(1) : route,
140-
baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`
141-
);
143+
let url;
144+
if (route instanceof URL) {
145+
url = route;
146+
} else {
147+
url = new URL(
148+
route.startsWith('/') ? route.slice(1) : route,
149+
baseUrl.endsWith('/') ? baseUrl : `${baseUrl}/`
150+
);
151+
}
142152

143-
const { method = 'GET', params = {}, data } = parameters;
153+
const { method = 'GET', params = {}, data } = options;
144154

145155
Object.entries(params).forEach(([key, value]) => {
146156
url.searchParams.append(key, value);
147157
});
148158

149-
const headers = {
150-
Authorization: `Token ${auth}`,
151-
'Content-Type': 'application/json',
152-
'User-Agent': userAgent,
153-
};
159+
const headers = new Headers();
160+
headers.append('Authorization', `Token ${auth}`);
161+
headers.append('Content-Type', 'application/json');
162+
headers.append('User-Agent', userAgent);
163+
if (options.headers) {
164+
options.headers.forEach((value, key) => {
165+
headers.append(key, value);
166+
});
167+
}
154168

155-
const options = {
169+
const init = {
156170
method,
157171
headers,
158172
body: data ? JSON.stringify(data) : undefined,
159173
};
160174

161-
const response = await this.fetch(url, options);
175+
const response = await this.fetch(url, init);
162176

163177
if (!response.ok) {
164-
const request = new Request(url, options);
178+
const request = new Request(url, init);
165179
const responseText = await response.text();
166180
throw new ApiError(
167181
`Request to ${url} failed with status ${response.status} ${response.statusText}: ${responseText}.`,
@@ -170,7 +184,7 @@ class Replicate {
170184
);
171185
}
172186

173-
return response.json();
187+
return response;
174188
}
175189

176190
/**
@@ -188,7 +202,7 @@ class Replicate {
188202
const response = await endpoint();
189203
yield response.results;
190204
if (response.next) {
191-
const nextPage = () => this.request(response.next, { method: 'GET' });
205+
const nextPage = () => this.request(response.next, { method: 'GET' }).then((r) => r.json());
192206
yield* this.paginate(nextPage);
193207
}
194208
}
@@ -204,7 +218,7 @@ class Replicate {
204218
* @param {object} prediction - Prediction object
205219
* @param {object} options - Options
206220
* @param {number} [options.interval] - Polling interval in milliseconds. Defaults to 250
207-
* @param {number} [options.maxAttempts] - Maximum number of polling attempts. Defaults to no limit
221+
* @param {number} [options.max_attempts] - Maximum number of polling attempts. Defaults to no limit
208222
* @throws {Error} If the prediction doesn't complete within the maximum number of attempts
209223
* @throws {Error} If the prediction failed
210224
* @returns {Promise<object>} Resolves with the completed prediction object
@@ -230,17 +244,17 @@ class Replicate {
230244

231245
let attempts = 0;
232246
const interval = options.interval || 250;
233-
const maxAttempts = options.maxAttempts || null;
247+
const max_attempts = options.max_attempts || null;
234248

235249
while (
236250
updatedPrediction.status !== 'succeeded' &&
237251
updatedPrediction.status !== 'failed' &&
238252
updatedPrediction.status !== 'canceled'
239253
) {
240254
attempts += 1;
241-
if (maxAttempts && attempts > maxAttempts) {
255+
if (max_attempts && attempts > max_attempts) {
242256
throw new Error(
243-
`Prediction ${id} did not finish after ${maxAttempts} attempts`
257+
`Prediction ${id} did not finish after ${max_attempts} attempts`
244258
);
245259
}
246260

index.test.ts

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,24 @@ describe('Replicate client', () => {
152152
expect(prediction.id).toBe('ufawqhfynnddngldkgtslldrkq');
153153
});
154154

155+
test('Passes stream parameter to API endpoint', async () => {
156+
nock(BASE_URL)
157+
.post('/predictions')
158+
.reply(201, (_uri, body) => {
159+
expect(body[ 'stream' ]).toBe(true);
160+
return body
161+
})
162+
163+
await client.predictions.create({
164+
version:
165+
'5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa',
166+
input: {
167+
prompt: 'Tell me a story',
168+
},
169+
stream: true
170+
});
171+
});
172+
155173
test('Throws an error if webhook URL is invalid', async () => {
156174
await expect(async () => {
157175
await client.predictions.create({
@@ -506,7 +524,7 @@ describe('Replicate client', () => {
506524
status: 'processing',
507525
})
508526
.get('/predictions/ufawqhfynnddngldkgtslldrkq')
509-
.reply(200, {
527+
.reply(201, {
510528
id: 'ufawqhfynnddngldkgtslldrkq',
511529
status: 'succeeded',
512530
output: 'foobar',

lib/collections.js

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
* @returns {Promise<object>} - Resolves with the collection data
66
*/
77
async function getCollection(collection_slug) {
8-
return this.request(`/collections/${collection_slug}`, {
8+
const response = await this.request(`/collections/${collection_slug}`, {
99
method: 'GET',
1010
});
11+
12+
return response.json();
1113
}
1214

1315
/**
@@ -16,9 +18,11 @@ async function getCollection(collection_slug) {
1618
* @returns {Promise<object>} - Resolves with the collections data
1719
*/
1820
async function listCollections() {
19-
return this.request('/collections', {
21+
const response = await this.request('/collections', {
2022
method: 'GET',
2123
});
24+
25+
return response.json();
2226
}
2327

2428
module.exports = { get: getCollection, list: listCollections };

lib/error.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
/**
2+
* A representation of an API error.
3+
*/
14
class ApiError extends Error {
25
/**
36
* Creates a representation of an API error.

lib/models.js

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
* @returns {Promise<object>} Resolves with the model data
77
*/
88
async function getModel(model_owner, model_name) {
9-
return this.request(`/models/${model_owner}/${model_name}`, {
9+
const response = await this.request(`/models/${model_owner}/${model_name}`, {
1010
method: 'GET',
1111
});
12+
13+
return response.json();
1214
}
1315

1416
/**
@@ -19,9 +21,11 @@ async function getModel(model_owner, model_name) {
1921
* @returns {Promise<object>} Resolves with the list of model versions
2022
*/
2123
async function listModelVersions(model_owner, model_name) {
22-
return this.request(`/models/${model_owner}/${model_name}/versions`, {
24+
const response = await this.request(`/models/${model_owner}/${model_name}/versions`, {
2325
method: 'GET',
2426
});
27+
28+
return response.json();
2529
}
2630

2731
/**
@@ -33,12 +37,14 @@ async function listModelVersions(model_owner, model_name) {
3337
* @returns {Promise<object>} Resolves with the model version data
3438
*/
3539
async function getModelVersion(model_owner, model_name, version_id) {
36-
return this.request(
40+
const response = await this.request(
3741
`/models/${model_owner}/${model_name}/versions/${version_id}`,
3842
{
3943
method: 'GET',
4044
}
4145
);
46+
47+
return response.json();
4248
}
4349

4450
module.exports = {

0 commit comments

Comments
 (0)