"""Factory helpers for building langchain LLM instances from a class name."""

import langchain.llms

# Map every concrete BaseLLM subclass exported by langchain.llms
# (e.g. "OpenAI", "GPT4All") onto its class object so an instance can be
# requested with a plain string identifier.
llm_classes = {
    cls.__name__: cls
    for cls in langchain.llms.__dict__.values()
    if isinstance(cls, type) and issubclass(cls, langchain.llms.BaseLLM)
}

# Chat models (e.g. ChatOpenAI) live in langchain.chat_models, NOT
# langchain.llms, yet the rest of this project selects them through this
# factory ("ChatOpenAI" is even the project default).  Merge them in so
# those lookups resolve instead of raising ValueError.
try:
    import langchain.chat_models
    from langchain.chat_models.base import BaseChatModel

    llm_classes.update(
        {
            cls.__name__: cls
            for cls in langchain.chat_models.__dict__.values()
            if isinstance(cls, type) and issubclass(cls, BaseChatModel)
        }
    )
except ImportError:  # older langchain without the chat_models package
    pass

# Per-type constructor defaults; explicit caller arguments override these.
llm_defaults = {
    "ChatOpenAI": {},
    "GPT4All": {
        "model": "./models/ggml-gpt4all-j.bin",
        "n_ctx": 512,
        "n_threads": 8,
    },
}


def SimpleLLMFactory(llm_type: str, **kwargs) -> langchain.llms.BaseLLM:
    """Return a new instance of any LLM class known to the factory.

    Args:
        llm_type (str): Class name of the LLM to build (e.g. "ChatOpenAI",
            "GPT4All").
        **kwargs: Constructor arguments overriding ``llm_defaults``.  A
            mapping passed under the literal keyword ``kwargs`` — the call
            style used throughout this project,
            ``SimpleLLMFactory(t, kwargs={...})`` — is unpacked and merged
            too, instead of being forwarded verbatim as a bogus ``kwargs=``
            constructor argument.

    Raises:
        ValueError: ``llm_type`` does not name a known LLM class.

    Returns:
        BaseLLM: The instantiated model.
    """
    # Support both calling conventions.  Precedence (lowest to highest):
    # registered defaults < dict passed as kwargs= < direct keywords.
    extra = kwargs.pop("kwargs", {})
    merged = {**llm_defaults.get(llm_type, {}), **extra, **kwargs}
    try:
        llm_cls = llm_classes[llm_type]
    except KeyError:
        raise ValueError(f"Unknown LLM type: {llm_type}") from None
    return llm_cls(**merged)
self.model_type, + kwargs = llm_args.get(self.model_type,{}) + ) + self.llm_chain = LLMChain( - llm=ChatOpenAI( - temperature=0.2, - model_name=self.model_name, - request_timeout=self.request_timeout, - ), + llm=llm, prompt=self.prompt, memory=self.context.chat_buffer_memory, callback_manager=CallbackManager(self.callbacks), @@ -61,12 +72,10 @@ def __init__( self.output_parser = YeagerAIOutputParser() - tool_names = [tool.name for tool in self.yeager_kit.get_tools()] self.agent = LLMSingleActionAgent( llm_chain=self.llm_chain, output_parser=self.output_parser, stop=["\nObservation:"], - allowed_tools=tool_names, ) self.agent_executor = AgentExecutor.from_agent_and_tools( agent=self.agent, diff --git a/yeagerai/interfaces/callbacks/local_file_system_n_git.py b/yeagerai/interfaces/callbacks/local_file_system_n_git.py index 59ed9d4..889b8cc 100644 --- a/yeagerai/interfaces/callbacks/local_file_system_n_git.py +++ b/yeagerai/interfaces/callbacks/local_file_system_n_git.py @@ -11,15 +11,17 @@ ChatPromptTemplate, HumanMessagePromptTemplate, ) +from yeagerai import SimpleLLMFactory class GitLocalRepoCallbackHandler(BaseCallbackHandler): """Callback Handler that creates a local git repo and commits changes.""" - def __init__(self, username: str, session_path: str) -> None: + def __init__(self, username: str, session_path: str, model_type: str) -> None: """Initialize callback handler.""" super().__init__() self.username = username + self.model_type = model_type self.session_path = session_path self.openai_api_key = os.getenv("OPENAI_API_KEY") @@ -45,10 +47,18 @@ def _get_gpt_commit_message(self, repo: Repo) -> str: # Create a prompt template prompt_template = "Explain the following changes in a Git commit message:\n\n{diff_output}\n\nCommit message:" - # Initialize ChatOpenAI with API key and model name - chat = ChatOpenAI( - openai_api_key=self.openai_api_key, model_name="gpt-3.5-turbo" - ) + # Initialize LLM + llm_args = { + "ChatOpenAI": { + "model_name": 
"gpt-3.5-turbo", + "openai_api_key" : self.openai_api_key + } + } + + chat = SimpleLLMFactory( + self.model_type, + kwargs = llm_args.get(self.model_type,{}) + ) # Create a PromptTemplate instance with the read template master_prompt = PromptTemplate( diff --git a/yeagerai/interfaces/cli.py b/yeagerai/interfaces/cli.py index 7322004..618d9e0 100644 --- a/yeagerai/interfaces/cli.py +++ b/yeagerai/interfaces/cli.py @@ -77,6 +77,7 @@ def pre_load(): def chat_interface( username, model_name, + model_type, request_timeout, streaming, session_id, @@ -94,6 +95,7 @@ def chat_interface( agent = YeagerAIAgent( username=username, model_name=model_name, + model_type=model_type, request_timeout=request_timeout, streaming=streaming, session_id=session_id, @@ -118,6 +120,7 @@ def main(): print("Exiting...") return + model_type = "ChatOpenAI" model_name = "gpt-4" # you can switch to gpt-3.5-turbo but is not tested request_timeout = 300 streaming = True @@ -128,7 +131,7 @@ def main(): # build callbacks callbacks = [ KageBunshinNoJutsu(y_context), - GitLocalRepoCallbackHandler(username=username, session_path=session_path), + GitLocalRepoCallbackHandler(username=username, session_path=session_path, model_type=model_type), ] # toolkit @@ -138,6 +141,7 @@ def main(): api_wrapper=DesignSolutionSketchAPIWrapper( session_path=session_path, model_name=model_name, + model_type=model_type, request_timeout=request_timeout, streaming=streaming, ) @@ -148,6 +152,7 @@ def main(): api_wrapper=CreateToolMockedTestsAPIWrapper( session_path=session_path, model_name=model_name, + model_type=model_type, request_timeout=request_timeout, streaming=streaming, ) @@ -158,6 +163,7 @@ def main(): api_wrapper=CreateToolSourceAPIWrapper( session_path=session_path, model_name=model_name, + model_type=model_type, request_timeout=request_timeout, streaming=streaming, ) @@ -169,6 +175,7 @@ def main(): api_wrapper=LoadNFixNewToolAPIWrapper( session_path=session_path, model_name=model_name, + model_type=model_type, 
request_timeout=request_timeout, streaming=streaming, toolkit=yeager_kit, diff --git a/yeagerai/interfaces/gradio_chat.py b/yeagerai/interfaces/gradio_chat.py index 22be8e2..3a4450a 100644 --- a/yeagerai/interfaces/gradio_chat.py +++ b/yeagerai/interfaces/gradio_chat.py @@ -53,7 +53,7 @@ def pre_load(): return has_api_key, openai_api_key, username, env_path, root_path -def set_session_variables(username, model_name, request_timeout, root_path, session_id): +def set_session_variables(username, model_name, model_type, request_timeout, root_path, session_id): session_path = os.path.join(root_path, session_id) # build context y_context = YeagerAIContext(username, session_id, session_path) @@ -61,7 +61,7 @@ def set_session_variables(username, model_name, request_timeout, root_path, sess # build callbacks callbacks = [ KageBunshinNoJutsu(y_context), - GitLocalRepoCallbackHandler(username=username, session_path=session_path), + GitLocalRepoCallbackHandler(username=username, session_path=session_path, model_type=model_type) ] # toolkit @@ -71,6 +71,7 @@ def set_session_variables(username, model_name, request_timeout, root_path, sess api_wrapper=DesignSolutionSketchAPIWrapper( session_path=session_path, model_name=model_name, + model_type=model_type, request_timeout=request_timeout, ) ), @@ -80,6 +81,7 @@ def set_session_variables(username, model_name, request_timeout, root_path, sess api_wrapper=CreateToolMockedTestsAPIWrapper( session_path=session_path, model_name=model_name, + model_type=model_type, request_timeout=request_timeout, ) ), @@ -89,6 +91,7 @@ def set_session_variables(username, model_name, request_timeout, root_path, sess api_wrapper=CreateToolSourceAPIWrapper( session_path=session_path, model_name=model_name, + model_type=model_type, request_timeout=request_timeout, ) ), @@ -99,6 +102,7 @@ def set_session_variables(username, model_name, request_timeout, root_path, sess api_wrapper=LoadNFixNewToolAPIWrapper( session_path=session_path, model_name=model_name, + 
model_type=model_type, request_timeout=request_timeout, toolkit=yeager_kit, ) @@ -112,9 +116,11 @@ def load_state(): session_id = str(uuid.uuid1())[:7] + "-" + username session_path = os.path.join(root_path, session_id) model_name = "gpt-4" + #model_type = "GPT4All" + model_type = "ChatOpenAI" request_timeout = 300 y_context, callbacks, yeager_kit = set_session_variables( - username, model_name, request_timeout, root_path, session_id + username, model_name, model_type, request_timeout, root_path, session_id ) return { "has_api_key": has_api_key, @@ -125,6 +131,7 @@ def load_state(): "session_id": session_id, "session_path": session_path, "model_name": model_name, + "model_type": model_type, "request_timeout": request_timeout, "y_context": y_context, "callbacks": callbacks, @@ -133,10 +140,11 @@ def load_state(): def update_state_from_settings( - session_id, model_name, request_timeout, openai_api_key, session_data + session_id, model_name, model_type, request_timeout, openai_api_key, session_data ): session_data["session_id"] = session_id session_data["model_name"] = model_name + session_data["model_type"] = model_type session_data["request_timeout"] = request_timeout if openai_api_key != session_data["openai_api_key"]: session_data["openai_api_key"] = openai_api_key @@ -157,6 +165,7 @@ def bot(history, session_data): agent = YeagerAIAgent( username=session_data["username"], model_name=session_data["model_name"], + model_type=session_data["model_type"], request_timeout=session_data["request_timeout"], session_id=session_data["session_id"], session_path=session_data["session_path"], @@ -219,6 +228,12 @@ def main(): name="model_name", value=session_data.value["model_name"], ) + model_type_dropdown = gr.Dropdown( + ["ChatOpenAI", "GPT4All"], + label="Model Type", + name="model_type", + value=session_data.value["model_type"], + ) request_timeout_input = gr.Number( name="request_timeout", label="Request Timeout", @@ -243,6 +258,7 @@ def main(): inputs=[ 
session_id_input, model_name_radio, + model_type_dropdown, request_timeout_input, api_key_input, session_data, diff --git a/yeagerai/models/put local model bin in here.txt b/yeagerai/models/put local model bin in here.txt new file mode 100644 index 0000000..2995a4d --- /dev/null +++ b/yeagerai/models/put local model bin in here.txt @@ -0,0 +1 @@ +dummy \ No newline at end of file diff --git a/yeagerai/toolkit/create_tool_mocked_tests/create_tool_mocked_tests.py b/yeagerai/toolkit/create_tool_mocked_tests/create_tool_mocked_tests.py index dc1e913..edf7c06 100644 --- a/yeagerai/toolkit/create_tool_mocked_tests/create_tool_mocked_tests.py +++ b/yeagerai/toolkit/create_tool_mocked_tests/create_tool_mocked_tests.py @@ -3,6 +3,7 @@ import re from typing import List from pydantic import BaseModel +from yeagerai import SimpleLLMFactory from yeagerai.toolkit.yeagerai_tool import YeagerAITool @@ -21,17 +22,26 @@ class CreateToolMockedTestsAPIWrapper(BaseModel): session_path: str model_name: str + model_type: str request_timeout: int - openai_api_key: str = os.getenv("OPENAI_API_KEY") + openai_api_key: str | None = os.getenv("OPENAI_API_KEY") def run(self, solution_sketch: str) -> str: - # Initialize ChatOpenAI with API key and model name - chat = ChatOpenAI( - openai_api_key=self.openai_api_key, - model_name=self.model_name, - request_timeout=self.request_timeout, - ) + # Initialize LLM + llm_args = { + "ChatOpenAI": { + "model_name": self.model_name, + "openai_api_key": self.openai_api_key, + "request_timeout": self.request_timeout + } + } + + chat = SimpleLLMFactory( + self.model_type, + kwargs = llm_args.get(self.model_type,{}) + ) + # Create a PromptTemplate instance with the read template y_tool_master_prompt = PromptTemplate( input_variables=["solution_sketch"], @@ -47,7 +57,8 @@ def run(self, solution_sketch: str) -> str: out = chain.run(solution_sketch) # Extract the name of the class from the code block - quick_llm = OpenAI(temperature=0) + quick_llm = 
OpenAI(temperature=0) # TODO How to deal with secondary networks (ChatOpenAI vs OpenAI) + #quick_llm = chat class_name = quick_llm( f"Which is the name of the class that is being tested here? Return only the class_name value like a python string, without any other explanation \n {out}" ).replace("\n", "") diff --git a/yeagerai/toolkit/create_tool_source/create_tool_source.py b/yeagerai/toolkit/create_tool_source/create_tool_source.py index 449044b..faf512e 100644 --- a/yeagerai/toolkit/create_tool_source/create_tool_source.py +++ b/yeagerai/toolkit/create_tool_source/create_tool_source.py @@ -3,6 +3,7 @@ import re from typing import List from pydantic import BaseModel +from yeagerai import SimpleLLMFactory from yeagerai.toolkit.yeagerai_tool import YeagerAITool @@ -20,8 +21,9 @@ class CreateToolSourceAPIWrapper(BaseModel): session_path: str model_name: str + model_type: str request_timeout: int - openai_api_key: str = os.getenv("OPENAI_API_KEY") + openai_api_key: str | None = os.getenv("OPENAI_API_KEY") def run(self, solution_sketch_n_tool_tests: str) -> str: # Split the solution sketch and tool tests @@ -34,11 +36,18 @@ def run(self, solution_sketch_n_tool_tests: str) -> str: )[1] except IndexError: return "You have not provided the split token ######SPLIT_TOKEN########, retry it providing it between the solution sketch and the tool tests." 
- # Initialize ChatOpenAI with API key and model name - chat = ChatOpenAI( - openai_api_key=self.openai_api_key, - model_name=self.model_name, - request_timeout=self.request_timeout, + # Initialize LLM + llm_args = { + "ChatOpenAI": { + "model_name": self.model_name, + "openai_api_key": self.openai_api_key, + "request_timeout": self.request_timeout + } + } + + chat = SimpleLLMFactory( + self.model_type, + kwargs = llm_args.get(self.model_type,{}) ) # Create a PromptTemplate instance with the read template diff --git a/yeagerai/toolkit/design_solution_sketch/design_solution_sketch.py b/yeagerai/toolkit/design_solution_sketch/design_solution_sketch.py index ca25f62..d325c85 100644 --- a/yeagerai/toolkit/design_solution_sketch/design_solution_sketch.py +++ b/yeagerai/toolkit/design_solution_sketch/design_solution_sketch.py @@ -2,6 +2,7 @@ import os from pydantic import BaseModel +from yeagerai import SimpleLLMFactory from yeagerai.toolkit.yeagerai_tool import YeagerAITool @@ -19,17 +20,26 @@ class DesignSolutionSketchAPIWrapper(BaseModel): session_path: str model_name: str + model_type: str request_timeout: int - openai_api_key: str = os.getenv("OPENAI_API_KEY") + openai_api_key: str | None = os.getenv("OPENAI_API_KEY") def run(self, tool_description_prompt: str) -> str: - # Initialize ChatOpenAI with API key and model name - chat = ChatOpenAI( - openai_api_key=self.openai_api_key, - model_name=self.model_name, - request_timeout=self.request_timeout, + # Initialize LLM + llm_args = { + "ChatOpenAI": { + "model_name": self.model_name, + "openai_api_key": self.openai_api_key, + "request_timeout": self.request_timeout + } + } + + chat = SimpleLLMFactory( + self.model_type, + kwargs = llm_args.get(self.model_type,{}) ) + # Create a PromptTemplate instance with the read template y_tool_master_prompt = PromptTemplate( input_variables=["tool_description_prompt"], diff --git a/yeagerai/toolkit/load_n_fix_new_tool/load_n_fix_new_tool.py 
b/yeagerai/toolkit/load_n_fix_new_tool/load_n_fix_new_tool.py index 55f0371..cc251ca 100644 --- a/yeagerai/toolkit/load_n_fix_new_tool/load_n_fix_new_tool.py +++ b/yeagerai/toolkit/load_n_fix_new_tool/load_n_fix_new_tool.py @@ -4,6 +4,7 @@ import re from typing import List from pydantic import BaseModel +from yeagerai import SimpleLLMFactory from yeagerai.toolkit.yeagerai_tool import YeagerAITool @@ -22,8 +23,9 @@ class LoadNFixNewToolAPIWrapper(BaseModel): session_path: str model_name: str + model_type: str request_timeout: int - openai_api_key: str = os.getenv("OPENAI_API_KEY") + openai_api_key: str | None = os.getenv("OPENAI_API_KEY") toolkit: YeagerAIToolkit class Config: @@ -44,6 +46,10 @@ def run(self, new_tool_path: str) -> str: try: spec = importlib.util.spec_from_file_location(class_name, new_tool_path) + if spec is None: + raise ImportError(f"Cannot load {class_name} from {new_tool_path}") + if spec.loader is None: + raise ImportError(f"Cannot load loader in {class_name} from {new_tool_path}") myfile = importlib.util.module_from_spec(spec) spec.loader.exec_module(myfile) @@ -53,11 +59,18 @@ def run(self, new_tool_path: str) -> str: self.toolkit.register_tool(class_run(api_wrapper=class_api_wrapper())) except Exception as traceback: - # Initialize ChatOpenAI with API key and model name - chat = ChatOpenAI( - openai_api_key=self.openai_api_key, - model_name=self.model_name, - request_timeout=self.request_timeout, + # Initialize LLM + llm_args = { + "ChatOpenAI": { + "model_name": self.model_name, + "openai_api_key": self.openai_api_key, + "request_timeout": self.request_timeout + } + } + + chat = SimpleLLMFactory( + self.model_type, + kwargs = llm_args.get(self.model_type,{}) ) # Create a PromptTemplate instance with the read template diff --git a/yeagerai/toolkit/yeagerai_toolkit.py b/yeagerai/toolkit/yeagerai_toolkit.py index 015abef..d4ea4f6 100644 --- a/yeagerai/toolkit/yeagerai_toolkit.py +++ b/yeagerai/toolkit/yeagerai_toolkit.py @@ -4,19 +4,19 @@ 
class YeagerAIToolkit(BaseToolkit):
    """Mutable collection of YeagerAITool instances, exposed as a langchain toolkit."""

    # NOTE(review): langchain's BaseToolkit is a pydantic model in some
    # versions; assigning an undeclared attribute from a custom __init__
    # may need a declared field there — confirm against the pinned version.

    def __init__(self) -> None:
        """Start with an empty toolkit."""
        # Registered tools, kept in registration order.
        self.tools_list: List[YeagerAITool] = []

    def get_tools(self) -> List[YeagerAITool]:
        """Return the registered tools (the live list, not a copy)."""
        return self.tools_list

    def register_tool(self, tool: YeagerAITool):
        """Append *tool* to the toolkit."""
        self.tools_list.append(tool)