From 762ba603e5ed1436c77eea6660bc84a6e3038aa2 Mon Sep 17 00:00:00 2001 From: Elkin Andrew Date: Tue, 24 Sep 2024 10:53:08 +0300 Subject: [PATCH] steam completion --- examples/base_usage.py | 4 ++++ minds/minds.py | 22 +++++++++++++++++----- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/examples/base_usage.py b/examples/base_usage.py index 3f39d65..f6454c4 100644 --- a/examples/base_usage.py +++ b/examples/base_usage.py @@ -73,6 +73,10 @@ # call completion print(mind.completion('2+3')) +# stream completion +for chunk in mind.completion('2+3', stream=True): + print(chunk.content) + # --- managing datasources --- # create or replace diff --git a/minds/minds.py b/minds/minds.py index ad450e9..a49e6c0 100644 --- a/minds/minds.py +++ b/minds/minds.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Union, Iterable from urllib.parse import urlparse, urlunparse from openai import OpenAI @@ -90,7 +90,15 @@ def add_datasource(self, datasource: Datasource): def del_datasource(self, datasource: Union[Datasource, str]): raise NotImplementedError - def completion(self, message): + def completion(self, message: str, stream: bool = False) -> Union[str, Iterable[object]]: + """ + Call mind completion + + :param message: input question + :param stream: to enable stream mode + + :return: string if stream mode is off or iterator of ChoiceDelta objects (by openai) + """ parsed = urlparse(self.api.base_url) netloc = parsed.netloc @@ -107,14 +115,18 @@ def completion(self, message): base_url=base_url ) - completion = openai_client.chat.completions.create( + response = openai_client.chat.completions.create( model=self.name, messages=[ {'role': 'user', 'content': message} ], - stream=False + stream=stream ) - return completion.choices[0].message.content + if stream: + for chunk in response: + yield chunk.choices[0].delta + else: + return response.choices[0].message.content class Minds: