diff --git a/mindsdb_sdk/agents.py b/mindsdb_sdk/agents.py index 6240e3c..26daac5 100644 --- a/mindsdb_sdk/agents.py +++ b/mindsdb_sdk/agents.py @@ -1,5 +1,5 @@ from requests.exceptions import HTTPError -from typing import List, Union +from typing import Iterable, List, Union from urllib.parse import urlparse from uuid import uuid4 import datetime @@ -37,6 +37,12 @@ class Agent: >>> completion = agent.completion([{'question': 'What is your name?', 'answer': None}]) >>> print(completion.content) + Query an agent with streaming: + + >>> completion = agent.completion_stream([{'question': 'What is your name?', 'answer': None}]) + >>> for chunk in completion: + print(chunk.choices[0].delta.content) + List all agents: >>> agents = agents.list() @@ -81,6 +87,9 @@ def __init__( def completion(self, messages: List[dict]) -> AgentCompletion: return self.collection.completion(self.name, messages) + def completion_stream(self, messages: List[dict]) -> Iterable[object]: + return self.collection.completion_stream(self.name, messages) + def add_files(self, file_paths: List[str], description: str, knowledge_base: str = None): """ Add a list of files to the agent for retrieval. @@ -195,6 +204,17 @@ def completion(self, name: str, messages: List[dict]) -> AgentCompletion: data = self.api.agent_completion(self.project, name, messages) return AgentCompletion(data['message']['content']) + def completion_stream(self, name, messages: List[dict]) -> Iterable[object]: + """ + Queries the agent for a completion and streams the response as an iterable object. + + :param name: Name of the agent + :param messageS: List of messages to be sent to the agent + + :return: iterable of completion chunks from querying the agent. + """ + return self.api.agent_completion_stream(self.project, name, messages) + def _create_default_knowledge_base(self, agent: Agent, name: str) -> KnowledgeBase: # Make sure default ML engine for embeddings exists. try: diff --git a/mindsdb_sdk/connectors/rest_api.py b/mindsdb_sdk/connectors/rest_api.py index 67bf089..68c819d 100644 --- a/mindsdb_sdk/connectors/rest_api.py +++ b/mindsdb_sdk/connectors/rest_api.py @@ -1,11 +1,13 @@ from functools import wraps from typing import List, Union import io +import json import requests import pandas as pd from mindsdb_sdk import __about__ +from sseclient import SSEClient def _try_relogin(fnc): @@ -260,6 +262,15 @@ def agent_completion(self, project: str, name: str, messages: List[dict]): return r.json() + @_try_relogin + def agent_completion_stream(self, project: str, name: str, messages: List[dict]): + url = self.url + f'/api/projects/{project}/agents/{name}/completions/stream' + stream = requests.post(url, json={'messages': messages}, stream=True) + client = SSEClient(stream) + for chunk in client.events(): + # Stream objects loaded from SSE events 'data' param. + yield json.loads(chunk.data) + @_try_relogin def create_agent(self, project: str, name: str, model: str, skills: List[str] = None, params: dict = None): url = self.url + f'/api/projects/{project}/agents' diff --git a/requirements.txt b/requirements.txt index b68f773..d3404e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ mindsdb-sql >= 0.17.0, < 1.0.0 docstring-parser >= 0.7.3 tenacity >= 8.0.1 openai >= 1.15.0 +sseclient-py >= 1.8.0