diff --git a/examples/base_usage.py b/examples/base_usage.py index 3f39d65..f6454c4 100644 --- a/examples/base_usage.py +++ b/examples/base_usage.py @@ -73,6 +73,10 @@ # call completion print(mind.completion('2+3')) +# stream completion +for chunk in mind.completion('2+3', stream=True): + print(chunk.content) + # --- managing datasources --- # create or replace diff --git a/minds/minds.py b/minds/minds.py index c095c20..ea203c6 100644 --- a/minds/minds.py +++ b/minds/minds.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Union, Iterable from urllib.parse import urlparse, urlunparse from openai import OpenAI @@ -84,7 +84,15 @@ def add_datasource(self, datasource: Datasource): def del_datasource(self, datasource: Union[Datasource, str]): raise NotImplementedError - def completion(self, message): + def completion(self, message: str, stream: bool = False) -> Union[str, Iterable[object]]: + """ + Call mind completion + + :param message: input question + :param stream: to enable stream mode + + :return: string if stream mode is off or iterator of ChoiceDelta objects (by openai) + """ parsed = urlparse(self.api.base_url) netloc = parsed.netloc @@ -101,14 +109,18 @@ def completion(self, message): base_url=base_url ) - completion = openai_client.chat.completions.create( + response = openai_client.chat.completions.create( model=self.name, messages=[ {'role': 'user', 'content': message} ], - stream=False + stream=stream ) - return completion.choices[0].message.content + if stream: + for chunk in response: + yield chunk.choices[0].delta + else: + return response.choices[0].message.content class Minds: