Skip to content

Fix type hinting #389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from replicate.exceptions import ReplicateError
from replicate.file import Files
from replicate.hardware import HardwareNamespace as Hardware
from replicate.helpers import FileOutput
from replicate.model import Models
from replicate.prediction import Predictions
from replicate.run import async_run, run
Expand Down Expand Up @@ -160,19 +161,65 @@ def webhooks(self) -> Webhooks:
"""
return Webhooks(client=self)


@overload
def run(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
*,
use_file_output: Optional[bool] = True,
use_file_output: None = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
) -> Union[FileOutput, List[FileOutput]]: ...

@overload
def run(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
*,
use_file_output: True,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[FileOutput, List[FileOutput]]: ...

@overload
def run(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
*,
use_file_output: False,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[str, List[str]]: ...

def run(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
*,
use_file_output: Optional[bool] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Union[str, FileOutput], List[Union[str, FileOutput]]]:
"""
Run a model and wait for its output.
"""

return run(self, ref, input, use_file_output=use_file_output, **params)
Args:
ref: The reference identifier for the model to run.
input: Optional input data for the model.
use_file_output: If False, returns strings. If True or None (default), returns FileOutput objects.
**params: Additional parameters for prediction creation.

Returns:
If use_file_output is True or None (default):
- Single prediction: FileOutput
- Multiple predictions: List[FileOutput]
If use_file_output is False:
- Single prediction: str
- Multiple predictions: List[str]
"""
# If use_file_output is None, we treat it as True
effective_use_file_output = True if use_file_output is None else use_file_output
return run(self, ref, input, use_file_output=effective_use_file_output, **params)

async def async_run(
self,
Expand Down