Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions yeagerai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#from langchain.llms.base import BaseLLM
import langchain.llms

# Registry mapping class name -> class for every completion-style LLM that
# ``langchain.llms`` exports (e.g. "OpenAI", "GPT4All").
llm_classes = {
    cls.__name__: cls
    for cls in langchain.llms.__dict__.values()
    if isinstance(cls, type) and issubclass(cls, langchain.llms.BaseLLM)
}

# BUG FIX: "ChatOpenAI" -- the default model_type used throughout this
# project -- lives in ``langchain.chat_models`` and derives from
# BaseChatModel, *not* from ``langchain.llms.BaseLLM``, so the scan above
# never found it and SimpleLLMFactory("ChatOpenAI") raised ValueError.
# Register the chat-model classes as well; guarded so a langchain layout
# without this module degrades to the previous behaviour instead of
# crashing at import time.
try:
    import langchain.chat_models
    from langchain.chat_models.base import BaseChatModel

    llm_classes.update(
        {
            cls.__name__: cls
            for cls in langchain.chat_models.__dict__.values()
            if isinstance(cls, type) and issubclass(cls, BaseChatModel)
        }
    )
except ImportError:  # pragma: no cover - depends on installed langchain
    pass

# Per-type constructor defaults applied by SimpleLLMFactory; explicit
# keyword arguments supplied by the caller override these.
llm_defaults = {
    "ChatOpenAI": {},
    "GPT4All": {
        # NOTE(review): relative path -- resolves against the process CWD,
        # not this package directory; confirm callers run from the repo root.
        "model": "./models/ggml-gpt4all-j.bin",
        "n_ctx": 512,
        "n_threads": 8,
    },
}

def SimpleLLMFactory(llm_type: str, **kwargs) -> langchain.llms.BaseLLM:
    """Return a new instance of any LLM registered in ``llm_classes``.

    Args:
        llm_type (str): Class name of the LLM to build (e.g. ``"GPT4All"``).
        **kwargs: Constructor arguments for the LLM; they override any
            entry in ``llm_defaults`` for the same key. For convenience the
            call style ``SimpleLLMFactory(t, kwargs=my_dict)`` used by
            callers in this repo is also accepted and unpacked.

    Raises:
        ValueError: ``llm_type`` was not found. Check spelling.

    Returns:
        BaseLLM: Instantiated with the defaults merged with ``kwargs``.
    """
    # BUG FIX: every caller in this code base passes a plain dict as a
    # keyword literally named ``kwargs`` (``SimpleLLMFactory(t, kwargs=d)``)
    # instead of splatting it (``**d``). Without this unpacking the LLM
    # constructor would receive one bogus keyword called "kwargs" and none
    # of the intended arguments.
    if set(kwargs) == {"kwargs"} and isinstance(kwargs["kwargs"], dict):
        kwargs = kwargs["kwargs"]

    # Keep the try body minimal: only the registry lookup may raise
    # KeyError here. The original wrapped the constructor call too, so a
    # KeyError raised *inside* an LLM's __init__ was misreported as an
    # unknown type. ``from None`` suppresses the confusing chained
    # KeyError traceback. (No ``global`` needed -- reads of module-level
    # names do not require it.)
    try:
        llm_cls = llm_classes[llm_type]
    except KeyError:
        raise ValueError(f"Unknown LLM type: {llm_type}") from None

    # Explicit caller kwargs win over the per-type defaults.
    return llm_cls(**{**llm_defaults.get(llm_type, {}), **kwargs})
25 changes: 17 additions & 8 deletions yeagerai/agent/yeagerai_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
username: str,
session_id: str,
session_path: str,
model_type: str,
model_name: str,
request_timeout: int,
callbacks: List[Callable],
Expand All @@ -33,6 +34,7 @@ def __init__(
self.session_id = session_id
self.session_path = session_path
self.model_name = model_name
self.model_type = model_type
self.request_timeout = request_timeout
self.callbacks = callbacks
self.context = context
Expand All @@ -46,13 +48,22 @@ def __init__(
input_variables=["input", "intermediate_steps"],
chat_history=self.context.chat_buffer_memory.chat_memory,
)


llm_args = {
"ChatOpenAI": {
"temperature":0.2,
"model_name": self.model_name,
"request_timeout": self.request_timeout,
}
}

llm = SimpleLLMFactory(
self.model_type,
kwargs = llm_args.get(self.model_type,{})
)

self.llm_chain = LLMChain(
llm=ChatOpenAI(
temperature=0.2,
model_name=self.model_name,
request_timeout=self.request_timeout,
),
llm=llm,
prompt=self.prompt,
memory=self.context.chat_buffer_memory,
callback_manager=CallbackManager(self.callbacks),
Expand All @@ -61,12 +72,10 @@ def __init__(

self.output_parser = YeagerAIOutputParser()

tool_names = [tool.name for tool in self.yeager_kit.get_tools()]
self.agent = LLMSingleActionAgent(
llm_chain=self.llm_chain,
output_parser=self.output_parser,
stop=["\nObservation:"],
allowed_tools=tool_names,
)
self.agent_executor = AgentExecutor.from_agent_and_tools(
agent=self.agent,
Expand Down
20 changes: 15 additions & 5 deletions yeagerai/interfaces/callbacks/local_file_system_n_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from yeagerai import SimpleLLMFactory


class GitLocalRepoCallbackHandler(BaseCallbackHandler):
"""Callback Handler that creates a local git repo and commits changes."""

def __init__(self, username: str, session_path: str) -> None:
def __init__(self, username: str, session_path: str, model_type: str) -> None:
"""Initialize callback handler."""
super().__init__()
self.username = username
self.model_type = model_type
self.session_path = session_path
self.openai_api_key = os.getenv("OPENAI_API_KEY")

Expand All @@ -45,10 +47,18 @@ def _get_gpt_commit_message(self, repo: Repo) -> str:
# Create a prompt template
prompt_template = "Explain the following changes in a Git commit message:\n\n{diff_output}\n\nCommit message:"

# Initialize ChatOpenAI with API key and model name
chat = ChatOpenAI(
openai_api_key=self.openai_api_key, model_name="gpt-3.5-turbo"
)
# Initialize LLM
llm_args = {
"ChatOpenAI": {
"model_name": "gpt-3.5-turbo",
"openai_api_key" : self.openai_api_key
}
}

chat = SimpleLLMFactory(
self.model_type,
kwargs = llm_args.get(self.model_type,{})
)

# Create a PromptTemplate instance with the read template
master_prompt = PromptTemplate(
Expand Down
9 changes: 8 additions & 1 deletion yeagerai/interfaces/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def pre_load():
def chat_interface(
username,
model_name,
model_type,
request_timeout,
streaming,
session_id,
Expand All @@ -94,6 +95,7 @@ def chat_interface(
agent = YeagerAIAgent(
username=username,
model_name=model_name,
model_type=model_type,
request_timeout=request_timeout,
streaming=streaming,
session_id=session_id,
Expand All @@ -118,6 +120,7 @@ def main():
print("Exiting...")
return

model_type = "ChatOpenAI"
model_name = "gpt-4" # you can switch to gpt-3.5-turbo but is not tested
request_timeout = 300
streaming = True
Expand All @@ -128,7 +131,7 @@ def main():
# build callbacks
callbacks = [
KageBunshinNoJutsu(y_context),
GitLocalRepoCallbackHandler(username=username, session_path=session_path),
GitLocalRepoCallbackHandler(username=username, session_path=session_path, model_type=model_type),
]

# toolkit
Expand All @@ -138,6 +141,7 @@ def main():
api_wrapper=DesignSolutionSketchAPIWrapper(
session_path=session_path,
model_name=model_name,
model_type=model_type,
request_timeout=request_timeout,
streaming=streaming,
)
Expand All @@ -148,6 +152,7 @@ def main():
api_wrapper=CreateToolMockedTestsAPIWrapper(
session_path=session_path,
model_name=model_name,
model_type=model_type,
request_timeout=request_timeout,
streaming=streaming,
)
Expand All @@ -158,6 +163,7 @@ def main():
api_wrapper=CreateToolSourceAPIWrapper(
session_path=session_path,
model_name=model_name,
model_type=model_type,
request_timeout=request_timeout,
streaming=streaming,
)
Expand All @@ -169,6 +175,7 @@ def main():
api_wrapper=LoadNFixNewToolAPIWrapper(
session_path=session_path,
model_name=model_name,
model_type=model_type,
request_timeout=request_timeout,
streaming=streaming,
toolkit=yeager_kit,
Expand Down
24 changes: 20 additions & 4 deletions yeagerai/interfaces/gradio_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ def pre_load():
return has_api_key, openai_api_key, username, env_path, root_path


def set_session_variables(username, model_name, request_timeout, root_path, session_id):
def set_session_variables(username, model_name, model_type, request_timeout, root_path, session_id):
session_path = os.path.join(root_path, session_id)
# build context
y_context = YeagerAIContext(username, session_id, session_path)

# build callbacks
callbacks = [
KageBunshinNoJutsu(y_context),
GitLocalRepoCallbackHandler(username=username, session_path=session_path),
GitLocalRepoCallbackHandler(username=username, session_path=session_path, model_type=model_type)
]

# toolkit
Expand All @@ -71,6 +71,7 @@ def set_session_variables(username, model_name, request_timeout, root_path, sess
api_wrapper=DesignSolutionSketchAPIWrapper(
session_path=session_path,
model_name=model_name,
model_type=model_type,
request_timeout=request_timeout,
)
),
Expand All @@ -80,6 +81,7 @@ def set_session_variables(username, model_name, request_timeout, root_path, sess
api_wrapper=CreateToolMockedTestsAPIWrapper(
session_path=session_path,
model_name=model_name,
model_type=model_type,
request_timeout=request_timeout,
)
),
Expand All @@ -89,6 +91,7 @@ def set_session_variables(username, model_name, request_timeout, root_path, sess
api_wrapper=CreateToolSourceAPIWrapper(
session_path=session_path,
model_name=model_name,
model_type=model_type,
request_timeout=request_timeout,
)
),
Expand All @@ -99,6 +102,7 @@ def set_session_variables(username, model_name, request_timeout, root_path, sess
api_wrapper=LoadNFixNewToolAPIWrapper(
session_path=session_path,
model_name=model_name,
model_type=model_type,
request_timeout=request_timeout,
toolkit=yeager_kit,
)
Expand All @@ -112,9 +116,11 @@ def load_state():
session_id = str(uuid.uuid1())[:7] + "-" + username
session_path = os.path.join(root_path, session_id)
model_name = "gpt-4"
#model_type = "GPT4All"
model_type = "ChatOpenAI"
request_timeout = 300
y_context, callbacks, yeager_kit = set_session_variables(
username, model_name, request_timeout, root_path, session_id
username, model_name, model_type, request_timeout, root_path, session_id
)
return {
"has_api_key": has_api_key,
Expand All @@ -125,6 +131,7 @@ def load_state():
"session_id": session_id,
"session_path": session_path,
"model_name": model_name,
"model_type": model_type,
"request_timeout": request_timeout,
"y_context": y_context,
"callbacks": callbacks,
Expand All @@ -133,10 +140,11 @@ def load_state():


def update_state_from_settings(
session_id, model_name, request_timeout, openai_api_key, session_data
session_id, model_name, model_type, request_timeout, openai_api_key, session_data
):
session_data["session_id"] = session_id
session_data["model_name"] = model_name
session_data["model_type"] = model_type
session_data["request_timeout"] = request_timeout
if openai_api_key != session_data["openai_api_key"]:
session_data["openai_api_key"] = openai_api_key
Expand All @@ -157,6 +165,7 @@ def bot(history, session_data):
agent = YeagerAIAgent(
username=session_data["username"],
model_name=session_data["model_name"],
model_type=session_data["model_type"],
request_timeout=session_data["request_timeout"],
session_id=session_data["session_id"],
session_path=session_data["session_path"],
Expand Down Expand Up @@ -219,6 +228,12 @@ def main():
name="model_name",
value=session_data.value["model_name"],
)
model_type_dropdown = gr.Dropdown(
["ChatOpenAI", "GPT4All"],
label="Model Type",
name="model_type",
value=session_data.value["model_type"],
)
request_timeout_input = gr.Number(
name="request_timeout",
label="Request Timeout",
Expand All @@ -243,6 +258,7 @@ def main():
inputs=[
session_id_input,
model_name_radio,
model_type_dropdown,
request_timeout_input,
api_key_input,
session_data,
Expand Down
1 change: 1 addition & 0 deletions yeagerai/models/put local model bin in here.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
dummy
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from typing import List
from pydantic import BaseModel
from yeagerai import SimpleLLMFactory

from yeagerai.toolkit.yeagerai_tool import YeagerAITool

Expand All @@ -21,17 +22,26 @@
class CreateToolMockedTestsAPIWrapper(BaseModel):
session_path: str
model_name: str
model_type: str
request_timeout: int
openai_api_key: str = os.getenv("OPENAI_API_KEY")
openai_api_key: str | None = os.getenv("OPENAI_API_KEY")

def run(self, solution_sketch: str) -> str:
# Initialize ChatOpenAI with API key and model name
chat = ChatOpenAI(
openai_api_key=self.openai_api_key,
model_name=self.model_name,
request_timeout=self.request_timeout,
)

# Initialize LLM
llm_args = {
"ChatOpenAI": {
"model_name": self.model_name,
"openai_api_key": self.openai_api_key,
"request_timeout": self.request_timeout
}
}

chat = SimpleLLMFactory(
self.model_type,
kwargs = llm_args.get(self.model_type,{})
)

# Create a PromptTemplate instance with the read template
y_tool_master_prompt = PromptTemplate(
input_variables=["solution_sketch"],
Expand All @@ -47,7 +57,8 @@ def run(self, solution_sketch: str) -> str:
out = chain.run(solution_sketch)

# Extract the name of the class from the code block
quick_llm = OpenAI(temperature=0)
quick_llm = OpenAI(temperature=0) # TODO How to deal with secondary networks (ChatOpenAI vs OpenAI)
#quick_llm = chat
class_name = quick_llm(
f"Which is the name of the class that is being tested here? Return only the class_name value like a python string, without any other explanation \n {out}"
).replace("\n", "")
Expand Down
Loading