diff --git a/replicate/client.py b/replicate/client.py index 3e767d6..a808e63 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -26,6 +26,7 @@ from replicate.exceptions import ReplicateError from replicate.file import Files from replicate.hardware import HardwareNamespace as Hardware +from replicate.helpers import FileOutput from replicate.model import Models from replicate.prediction import Predictions from replicate.run import async_run, run @@ -160,19 +161,65 @@ def webhooks(self) -> Webhooks: """ return Webhooks(client=self) + + @overload def run( self, ref: str, input: Optional[Dict[str, Any]] = None, *, - use_file_output: Optional[bool] = True, + use_file_output: None = None, **params: Unpack["Predictions.CreatePredictionParams"], - ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 + ) -> Union[FileOutput, List[FileOutput]]: ... + + @overload + def run( + self, + ref: str, + input: Optional[Dict[str, Any]] = None, + *, + use_file_output: True, + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Union[FileOutput, List[FileOutput]]: ... + + @overload + def run( + self, + ref: str, + input: Optional[Dict[str, Any]] = None, + *, + use_file_output: False, + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Union[str, List[str]]: ... + + def run( + self, + ref: str, + input: Optional[Dict[str, Any]] = None, + *, + use_file_output: Optional[bool] = None, + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Union[Union[str, FileOutput], List[Union[str, FileOutput]]]: """ Run a model and wait for its output. - """ - return run(self, ref, input, use_file_output=use_file_output, **params) + Args: + ref: The reference identifier for the model to run. + input: Optional input data for the model. + use_file_output: If False, returns strings. If True or None (default), returns FileOutput objects. + **params: Additional parameters for prediction creation. + + Returns: + If use_file_output is True or None (default): + - Single prediction: FileOutput + - Multiple predictions: List[FileOutput] + If use_file_output is False: + - Single prediction: str + - Multiple predictions: List[str] + """ + # If use_file_output is None, we treat it as True + effective_use_file_output = True if use_file_output is None else use_file_output + return run(self, ref, input, use_file_output=effective_use_file_output, **params) async def async_run( self,