Skip to content

Commit b478f16

Browse files
authored
Allow run and stream methods to take model arguments, when supported (#210)
```python import replicate for event in replicate.stream( "meta/llama-2-70b-chat", input={ "prompt": "Please write a haiku about llamas.", }, ): print(str(event)) ``` --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 9736fb0 commit b478f16

File tree

6 files changed

+186
-82
lines changed

6 files changed

+186
-82
lines changed

replicate/identifier.py

+26-25
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,9 @@
11
import re
2-
from typing import NamedTuple
2+
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
33

4-
5-
class ModelIdentifier(NamedTuple):
6-
"""
7-
A reference to a model in the format owner/name:version.
8-
"""
9-
10-
owner: str
11-
name: str
12-
13-
@classmethod
14-
def parse(cls, ref: str) -> "ModelIdentifier":
15-
"""
16-
Split a reference in the format owner/name:version into its components.
17-
"""
18-
19-
match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^:]+)$", ref)
20-
if not match:
21-
raise ValueError(
22-
f"Invalid reference to model version: {ref}. Expected format: owner/name"
23-
)
24-
25-
return cls(match.group("owner"), match.group("name"))
4+
if TYPE_CHECKING:
5+
from replicate.model import Model
6+
from replicate.version import Version
267

278

289
class ModelVersionIdentifier(NamedTuple):
@@ -32,18 +13,38 @@ class ModelVersionIdentifier(NamedTuple):
3213

3314
owner: str
3415
name: str
35-
version: str
16+
version: Optional[str] = None
3617

3718
@classmethod
3819
def parse(cls, ref: str) -> "ModelVersionIdentifier":
3920
"""
4021
Split a reference in the format owner/name:version into its components.
4122
"""
4223

43-
match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", ref)
24+
match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^/:]+)(:(?P<version>.+))?$", ref)
4425
if not match:
4526
raise ValueError(
4627
f"Invalid reference to model version: {ref}. Expected format: owner/name:version"
4728
)
4829

4930
return cls(match.group("owner"), match.group("name"), match.group("version"))
31+
32+
33+
def _resolve(
34+
ref: Union["Model", "Version", "ModelVersionIdentifier", str]
35+
) -> Tuple[Optional["Version"], Optional[str], Optional[str], Optional[str]]:
36+
from replicate.model import Model # pylint: disable=import-outside-toplevel
37+
from replicate.version import Version # pylint: disable=import-outside-toplevel
38+
39+
version = None
40+
owner, name, version_id = None, None, None
41+
if isinstance(ref, Model):
42+
owner, name = ref.owner, ref.name
43+
elif isinstance(ref, Version):
44+
version = ref
45+
version_id = ref.id
46+
elif isinstance(ref, ModelVersionIdentifier):
47+
owner, name, version_id = ref
48+
elif isinstance(ref, str):
49+
owner, name, version_id = ModelVersionIdentifier.parse(ref)
50+
return version, owner, name, version_id

replicate/model.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated
44

55
from replicate.exceptions import ReplicateException
6-
from replicate.identifier import ModelIdentifier
6+
from replicate.identifier import ModelVersionIdentifier
77
from replicate.pagination import Page
88
from replicate.prediction import (
99
Prediction,
@@ -296,7 +296,7 @@ class ModelsPredictions(Namespace):
296296

297297
def create(
298298
self,
299-
model: Optional[Union[str, Tuple[str, str], "Model"]],
299+
model: Union[str, Tuple[str, str], "Model"],
300300
input: Dict[str, Any],
301301
**params: Unpack["Predictions.CreatePredictionParams"],
302302
) -> Prediction:
@@ -317,7 +317,7 @@ def create(
317317

318318
async def async_create(
319319
self,
320-
model: Optional[Union[str, Tuple[str, str], "Model"]],
320+
model: Union[str, Tuple[str, str], "Model"],
321321
input: Dict[str, Any],
322322
**params: Unpack["Predictions.CreatePredictionParams"],
323323
) -> Prediction:
@@ -391,7 +391,11 @@ def _create_prediction_url_from_model(
391391
elif isinstance(model, tuple):
392392
owner, name = model[0], model[1]
393393
elif isinstance(model, str):
394-
owner, name = ModelIdentifier.parse(model)
394+
owner, name, version_id = ModelVersionIdentifier.parse(model)
395+
if version_id is not None:
396+
raise ValueError(
397+
f"Invalid reference to model version: {model}. Expected model or reference in the format owner/name"
398+
)
395399

396400
if owner is None or name is None:
397401
raise ValueError(

replicate/run.py

+54-39
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,50 @@
1-
import asyncio
21
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
32

43
from typing_extensions import Unpack
54

5+
from replicate import identifier
66
from replicate.exceptions import ModelError
7-
from replicate.identifier import ModelVersionIdentifier
7+
from replicate.model import Model
8+
from replicate.prediction import Prediction
89
from replicate.schema import make_schema_backwards_compatible
9-
from replicate.version import Versions
10+
from replicate.version import Version, Versions
1011

1112
if TYPE_CHECKING:
1213
from replicate.client import Client
14+
from replicate.identifier import ModelVersionIdentifier
1315
from replicate.prediction import Predictions
1416

1517

1618
def run(
1719
client: "Client",
18-
ref: str,
20+
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
1921
input: Optional[Dict[str, Any]] = None,
2022
**params: Unpack["Predictions.CreatePredictionParams"],
2123
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
2224
"""
2325
Run a model and wait for its output.
2426
"""
2527

26-
owner, name, version_id = ModelVersionIdentifier.parse(ref)
28+
version, owner, name, version_id = identifier._resolve(ref)
2729

28-
prediction = client.predictions.create(
29-
version=version_id, input=input or {}, **params
30-
)
30+
if version_id is not None:
31+
prediction = client.predictions.create(
32+
version=version_id, input=input or {}, **params
33+
)
34+
elif owner and name:
35+
prediction = client.models.predictions.create(
36+
model=(owner, name), input=input or {}, **params
37+
)
38+
else:
39+
raise ValueError(
40+
f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version"
41+
)
3142

32-
if owner and name:
43+
if not version and (owner and name and version_id):
3344
version = Versions(client, model=(owner, name)).get(version_id)
3445

35-
# Return an iterator of the output
36-
schema = make_schema_backwards_compatible(
37-
version.openapi_schema, version.cog_version
38-
)
39-
output = schema["components"]["schemas"]["Output"]
40-
if (
41-
output.get("type") == "array"
42-
and output.get("x-cog-array-type") == "iterator"
43-
):
44-
return prediction.output_iterator()
46+
if version and (iterator := _make_output_iterator(version, prediction)):
47+
return iterator
4548

4649
prediction.wait()
4750

@@ -53,42 +56,54 @@ def run(
5356

5457
async def async_run(
5558
client: "Client",
56-
ref: str,
59+
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
5760
input: Optional[Dict[str, Any]] = None,
5861
**params: Unpack["Predictions.CreatePredictionParams"],
5962
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
6063
"""
6164
Run a model and wait for its output asynchronously.
6265
"""
6366

64-
owner, name, version_id = ModelVersionIdentifier.parse(ref)
67+
version, owner, name, version_id = identifier._resolve(ref)
6568

66-
prediction = await client.predictions.async_create(
67-
version=version_id, input=input or {}, **params
68-
)
69+
if version or version_id:
70+
prediction = await client.predictions.async_create(
71+
version=(version or version_id), input=input or {}, **params
72+
)
73+
elif owner and name:
74+
prediction = await client.models.predictions.async_create(
75+
model=(owner, name), input=input or {}, **params
76+
)
77+
else:
78+
raise ValueError(
79+
f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version"
80+
)
6981

70-
if owner and name:
71-
version = await Versions(client, model=(owner, name)).async_get(version_id)
82+
if not version and (owner and name and version_id):
83+
version = Versions(client, model=(owner, name)).get(version_id)
7284

73-
# Return an iterator of the output
74-
schema = make_schema_backwards_compatible(
75-
version.openapi_schema, version.cog_version
76-
)
77-
output = schema["components"]["schemas"]["Output"]
78-
if (
79-
output.get("type") == "array"
80-
and output.get("x-cog-array-type") == "iterator"
81-
):
82-
return prediction.output_iterator()
85+
if version and (iterator := _make_output_iterator(version, prediction)):
86+
return iterator
8387

84-
while prediction.status not in ["succeeded", "failed", "canceled"]:
85-
await asyncio.sleep(client.poll_interval)
86-
prediction = await client.predictions.async_get(prediction.id)
88+
prediction.wait()
8789

8890
if prediction.status == "failed":
8991
raise ModelError(prediction.error)
9092

9193
return prediction.output
9294

9395

96+
def _make_output_iterator(
97+
version: Version, prediction: Prediction
98+
) -> Optional[Iterator[Any]]:
99+
schema = make_schema_backwards_compatible(
100+
version.openapi_schema, version.cog_version
101+
)
102+
output = schema["components"]["schemas"]["Output"]
103+
if output.get("type") == "array" and output.get("x-cog-array-type") == "iterator":
104+
return prediction.output_iterator()
105+
106+
return None
107+
108+
94109
__all__: List = []

replicate/stream.py

+38-11
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
Iterator,
88
List,
99
Optional,
10+
Union,
1011
)
1112

1213
from typing_extensions import Unpack
1314

15+
from replicate import identifier
1416
from replicate.exceptions import ReplicateError
15-
from replicate.identifier import ModelVersionIdentifier
1617

1718
try:
1819
from pydantic import v1 as pydantic # type: ignore
@@ -24,7 +25,10 @@
2425
import httpx
2526

2627
from replicate.client import Client
28+
from replicate.identifier import ModelVersionIdentifier
29+
from replicate.model import Model
2730
from replicate.prediction import Predictions
31+
from replicate.version import Version
2832

2933

3034
class ServerSentEvent(pydantic.BaseModel): # type: ignore
@@ -157,7 +161,7 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
157161

158162
def stream(
159163
client: "Client",
160-
ref: str,
164+
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
161165
input: Optional[Dict[str, Any]] = None,
162166
**params: Unpack["Predictions.CreatePredictionParams"],
163167
) -> Iterator[ServerSentEvent]:
@@ -168,10 +172,20 @@ def stream(
168172
params = params or {}
169173
params["stream"] = True
170174

171-
_, _, version_id = ModelVersionIdentifier.parse(ref)
172-
prediction = client.predictions.create(
173-
version=version_id, input=input or {}, **params
174-
)
175+
version, owner, name, version_id = identifier._resolve(ref)
176+
177+
if version or version_id:
178+
prediction = client.predictions.create(
179+
version=(version or version_id), input=input or {}, **params
180+
)
181+
elif owner and name:
182+
prediction = client.models.predictions.create(
183+
model=(owner, name), input=input or {}, **params
184+
)
185+
else:
186+
raise ValueError(
187+
f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version"
188+
)
175189

176190
url = prediction.urls and prediction.urls.get("stream", None)
177191
if not url or not isinstance(url, str):
@@ -187,7 +201,7 @@ def stream(
187201

188202
async def async_stream(
189203
client: "Client",
190-
ref: str,
204+
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
191205
input: Optional[Dict[str, Any]] = None,
192206
**params: Unpack["Predictions.CreatePredictionParams"],
193207
) -> AsyncIterator[ServerSentEvent]:
@@ -198,10 +212,20 @@ async def async_stream(
198212
params = params or {}
199213
params["stream"] = True
200214

201-
_, _, version_id = ModelVersionIdentifier.parse(ref)
202-
prediction = await client.predictions.async_create(
203-
version=version_id, input=input or {}, **params
204-
)
215+
version, owner, name, version_id = identifier._resolve(ref)
216+
217+
if version or version_id:
218+
prediction = await client.predictions.async_create(
219+
version=(version or version_id), input=input or {}, **params
220+
)
221+
elif owner and name:
222+
prediction = await client.models.predictions.async_create(
223+
model=(owner, name), input=input or {}, **params
224+
)
225+
else:
226+
raise ValueError(
227+
f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version"
228+
)
205229

206230
url = prediction.urls and prediction.urls.get("stream", None)
207231
if not url or not isinstance(url, str):
@@ -214,3 +238,6 @@ async def async_stream(
214238
async with client._async_client.stream("GET", url, headers=headers) as response:
215239
async for event in EventSource(response):
216240
yield event
241+
242+
243+
__all__ = ["ServerSentEvent"]

replicate/training.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing_extensions import NotRequired, Unpack
1515

1616
from replicate.files import upload_file
17-
from replicate.identifier import ModelIdentifier, ModelVersionIdentifier
17+
from replicate.identifier import ModelVersionIdentifier
1818
from replicate.json import encode_json
1919
from replicate.model import Model
2020
from replicate.pagination import Page
@@ -373,14 +373,14 @@ def _create_training_url_from_shorthand(ref: str) -> str:
373373

374374
def _create_training_url_from_model_and_version(
375375
model: Union[str, Tuple[str, str], "Model"],
376-
version: Union[str, Version],
376+
version: Union[str, "Version"],
377377
) -> str:
378378
if isinstance(model, Model):
379379
owner, name = model.owner, model.name
380380
elif isinstance(model, tuple):
381381
owner, name = model[0], model[1]
382382
elif isinstance(model, str):
383-
owner, name = ModelIdentifier.parse(model)
383+
owner, name, _ = ModelVersionIdentifier.parse(model)
384384
else:
385385
raise ValueError(
386386
"model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'"

0 commit comments

Comments
 (0)