Merge pull request #965 from superagent-ai/refactor/agent-classes
Native function calling
elisalimli authored Apr 29, 2024
2 parents 5b76bc1 + cb09d29 commit bb29b68
Showing 17 changed files with 730 additions and 438 deletions.
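
As context for the commit title: "native function calling" generally means passing tool/function schemas directly to the model provider's chat API and receiving structured tool calls back, rather than asking the model to describe a call in free-form text and parsing it afterwards. The snippet below is a generic illustration of that mechanism using the openai Python client; it is not code from this PR, and the get_weather tool is invented for the example.

# Generic illustration of provider-native function calling (not code from this PR).
import json

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is the weather in Oslo?"}],
    tools=tools,
)

# The model replies with a structured tool call instead of free text.
tool_call = response.choices[0].message.tool_calls[0]
print(tool_call.function.name, json.loads(tool_call.function.arguments))
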
libs/superagent/app/agents/base.py: 179 changes (125 additions, 54 deletions)
@@ -1,106 +1,177 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional

from app.models.request import LLMParams
from langchain.agents import AgentExecutor
from pydantic import BaseModel

from app.models.request import LLMParams as LLMParamsRequest
from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
from prisma.enums import AgentType
from prisma.models import Agent
from prisma.enums import AgentType, LLMProvider
from prisma.models import LLM, Agent


class LLMParams(BaseModel):
temperature: Optional[float] = 0.1
max_tokens: Optional[int]
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
aws_region_name: Optional[str] = None


DEFAULT_PROMPT = (
"You are a helpful AI Assistant, answer the users questions to "
"the best of your ability."
)
class LLMData(BaseModel):
llm: LLM
params: LLMParams
model: str


class AgentBase:
class AgentBase(ABC):
_input: str
_messages: list = []
prompt: Any
tools: Any
session_id: str
enable_streaming: bool
output_schema: str
callbacks: List[CustomAsyncIteratorCallbackHandler]
agent_data: Agent
llm_data: LLMData

def __init__(
self,
agent_id: str,
session_id: str = None,
session_id: str,
enable_streaming: bool = False,
output_schema: str = None,
callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
llm_params: Optional[LLMParams] = {},
agent_config: Agent = None,
llm_data: LLMData = None,
agent_data: Agent = None,
):
self.agent_id = agent_id
self.session_id = session_id
self.enable_streaming = enable_streaming
self.output_schema = output_schema
self.callbacks = callbacks
self.llm_params = llm_params
self.agent_config = agent_config
self.llm_data = llm_data
self.agent_data = agent_data

async def _get_tools(
self,
) -> List:
raise NotImplementedError
@property
def input(self):
return self._input

async def _get_llm(
self,
) -> Any:
raise NotImplementedError
@input.setter
def input(self, value: str):
self._input = value

async def _get_prompt(
@property
def messages(self):
return self._messages

@messages.setter
def messages(self, value: list):
self._messages = value

@property
@abstractmethod
def prompt(self) -> Any:
...

@property
@abstractmethod
def tools(self) -> Any:
...

@abstractmethod
def get_agent(self) -> AgentExecutor:
...
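
The refactored AgentBase above is now an ABC: the concrete agents imported further down (OpenAiAssistant in app.agents.openai, LLMAgent in app.agents.llm, LangchainAgent in app.agents.langchain) implement the prompt and tools properties and get_agent. The sketch below only illustrates that subclass contract; EchoAgent and its bodies are invented and are not part of this PR.

# Hypothetical minimal subclass, illustrating the AgentBase contract only.
class EchoAgent(AgentBase):
    @property
    def prompt(self) -> Any:
        # Concrete agents typically build this from the prompt stored on self.agent_data.
        return "You are a helpful assistant."

    @property
    def tools(self) -> Any:
        # Concrete agents map the agent's configured tools to callable/tool objects.
        return []

    def get_agent(self) -> AgentExecutor:
        # A real implementation returns an executor wired to self.llm_data;
        # building one is beyond this sketch.
        raise NotImplementedError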


class AgentFactory:
def __init__(
self,
) -> str:
raise NotImplementedError
session_id: str = None,
enable_streaming: bool = False,
output_schema: str = None,
callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
llm_params: Optional[LLMParamsRequest] = {},
agent_data: Agent = None,
):
self.session_id = session_id
self.enable_streaming = enable_streaming
self.output_schema = output_schema
self.callbacks = callbacks
self.api_llm_params = llm_params
self.agent_data = agent_data

@property
def llm_data(self):
llm = self.agent_data.llms[0].llm
params = self.api_llm_params.dict() if self.api_llm_params else {}

options = {
**(self.agent_data.metadata or {}),
**(llm.options or {}),
**(params),
}

async def _get_memory(self) -> List:
raise NotImplementedError
params = LLMParams(
temperature=options.get("temperature"),
max_tokens=options.get("max_tokens"),
aws_access_key_id=(
options.get("aws_access_key_id")
if llm.provider == LLMProvider.BEDROCK
else None
),
aws_secret_access_key=(
options.get("aws_secret_access_key")
if llm.provider == LLMProvider.BEDROCK
else None
),
aws_region_name=(
options.get("aws_region_name")
if llm.provider == LLMProvider.BEDROCK
else None
),
)

return LLMData(
llm=llm,
params=LLMParams.parse_obj(options),
model=self.agent_data.llmModel or self.agent_data.metadata.get("model"),
)
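
The llm_data property above merges settings from three sources, and because later dict unpacking wins, request-level llm_params override the stored LLM's options, which in turn override the agent's metadata. A self-contained illustration of that precedence, with invented values:

# Standalone illustration of the option-merge precedence used in llm_data.
# Values are invented; later dicts win on key collisions.
agent_metadata = {"temperature": 0.0, "model": "gpt-3.5-turbo-0125"}
llm_options = {"temperature": 0.5, "max_tokens": 1024}
request_params = {"temperature": 0.9}

options = {**agent_metadata, **llm_options, **request_params}
assert options == {"temperature": 0.9, "model": "gpt-3.5-turbo-0125", "max_tokens": 1024}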

async def get_agent(self):
if self.agent_config.type == AgentType.OPENAI_ASSISTANT:
if self.agent_data.type == AgentType.OPENAI_ASSISTANT:
from app.agents.openai import OpenAiAssistant

agent = OpenAiAssistant(
agent_id=self.agent_id,
session_id=self.session_id,
enable_streaming=self.enable_streaming,
output_schema=self.output_schema,
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
llm_data=self.llm_data,
agent_data=self.agent_data,
)

elif self.agent_config.type == AgentType.LLM:
elif self.agent_data.type == AgentType.LLM:
from app.agents.llm import LLMAgent

agent = LLMAgent(
agent_id=self.agent_id,
session_id=self.session_id,
enable_streaming=self.enable_streaming,
output_schema=self.output_schema,
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
llm_data=self.llm_data,
agent_data=self.agent_data,
)

else:
from app.agents.langchain import LangchainAgent

agent = LangchainAgent(
agent_id=self.agent_id,
session_id=self.session_id,
enable_streaming=self.enable_streaming,
output_schema=self.output_schema,
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
llm_data=self.llm_data,
agent_data=self.agent_data,
)

return await agent.get_agent()

def get_input(self, input: str, agent_type: AgentType):
agent_input = {
"input": input,
}

if agent_type == AgentType.OPENAI_ASSISTANT:
agent_input = {
"content": input,
}

if agent_type == AgentType.LLM:
agent_input = input

return agent_input
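
Putting the pieces together, a caller now builds an AgentFactory, lets its llm_data property resolve provider settings from the stored agent, and awaits get_agent for a concrete agent. The sketch below is a hypothetical usage flow: the run function, the session id, and the ainvoke call are assumptions for illustration, not verified against the rest of this PR.

# Hypothetical wiring of the refactored classes (not code from this PR).
async def run(agent_record, user_message: str):
    # agent_record: a prisma Agent fetched with its `llms` relation included,
    # so the factory's llm_data property can read agent_record.llms[0].llm.
    factory = AgentFactory(
        session_id="session-123",
        enable_streaming=False,
        agent_data=agent_record,
    )

    executor = await factory.get_agent()

    # get_input shapes the payload by agent type:
    #   OPENAI_ASSISTANT -> {"content": ...}, LLM -> the raw string,
    #   anything else    -> {"input": ...}
    payload = factory.get_input(user_message, agent_type=agent_record.type)

    # For a LangChain-backed agent the executor is an AgentExecutor and supports
    # ainvoke; other agent types may expose a different call surface.
    return await executor.ainvoke(payload)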

0 comments on commit bb29b68
