92 changes: 74 additions & 18 deletions Writer/Interface/Wrapper.py
@@ -39,7 +39,9 @@ def LoadModels(self, Models: list):
Provider, ProviderModel, ModelHost, ModelOptions = (
self.GetModelAndProvider(Model)
)
print(f"DEBUG: Loading Model {ProviderModel} from {Provider}@{ModelHost}")
print(
f"DEBUG: Loading Model {ProviderModel} from {Provider}@{ModelHost}"
)

if Provider == "ollama":
# Get ollama models (only once)
@@ -96,7 +98,30 @@ def LoadModels(self, Models: list):
)

elif Provider == "openai":
raise NotImplementedError("OpenAI API not supported")
# Validate OpenAI API Key
if (
not "OPENAI_API_KEY" in os.environ
or os.environ["OPENAI_API_KEY"] == ""
):
raise Exception(
"OPENAI_API_KEY not found in environment variables. Add dummy if using local models"
)
self.ensure_package_is_installed("openai")
import openai

if ModelHost and "://" not in ModelHost:
raise ValueError(
f"Invalid Model Host URL {ModelHost}. Make sure to include the protocol (http/https)"
)

self.Clients[Model] = openai.OpenAI(
api_key=os.environ["OPENAI_API_KEY"],
base_url=(
"https://api.openai.com/v1"
if ModelHost is None
else ModelHost
),
)

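As a reference for this branch, a minimal standalone sketch of the same client setup, assuming the openai>=1.0 SDK; the localhost URL is a hypothetical example of an OpenAI-compatible server passed via the "@host" suffix of the model string:

import os

import openai

Host = "http://localhost:8000/v1"  # hypothetical; None selects the official API

if "OPENAI_API_KEY" not in os.environ or os.environ["OPENAI_API_KEY"] == "":
    raise Exception("OPENAI_API_KEY not found in environment variables")
if Host and "://" not in Host:
    raise ValueError(f"Invalid Model Host URL {Host}")

Client = openai.OpenAI(
    api_key=os.environ["OPENAI_API_KEY"],
    base_url="https://api.openai.com/v1" if Host is None else Host,
)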
elif Provider == "openrouter":
if (
@@ -126,23 +151,27 @@ def SafeGenerateText(
_Model: str,
_SeedOverride: int = -1,
_Format: str = None,
_MinWordCount: int = 1
):
_MinWordCount: int = 1,
):
"""
This function guarantees that the output will not be whitespace.
"""

NewMsg = self.ChatAndStreamResponse(_Logger, _Messages, _Model, _SeedOverride, _Format)
NewMsg = self.ChatAndStreamResponse(
_Logger, _Messages, _Model, _SeedOverride, _Format
)

while (self.GetLastMessageText(NewMsg).isspace()) or (len(self.GetLastMessageText(NewMsg).split(" ")) < _MinWordCount):
while (self.GetLastMessageText(NewMsg).isspace()) or (
len(self.GetLastMessageText(NewMsg).split(" ")) < _MinWordCount
):
_Logger.Log("Generation Failed, Reattempting Output", 7)
del _Messages[-1] # Remove failed attempt
NewMsg = self.ChatAndStreamResponse(_Logger, _Messages, _Model, random.randint(0, 99999), _Format)
del _Messages[-1] # Remove failed attempt
NewMsg = self.ChatAndStreamResponse(
_Logger, _Messages, _Model, random.randint(0, 99999), _Format
)

return NewMsg



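A hedged usage sketch of the updated SafeGenerateText signature; Interface, Logger, and the model string are hypothetical stand-ins for the caller's objects:

# Hypothetical call site: regenerates until the reply is non-whitespace
# and has at least _MinWordCount space-separated words.
Messages = [{"role": "user", "content": "Write a one-paragraph outline."}]
Messages = Interface.SafeGenerateText(
    Logger, Messages, "openai://gpt-4o-mini", _MinWordCount=50
)
print(Interface.GetLastMessageText(Messages))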
def ChatAndStreamResponse(
self,
_Logger,
@@ -310,7 +339,30 @@ def ChatAndStreamResponse(
m["role"] = "assistant"

elif Provider == "openai":
raise NotImplementedError("OpenAI API not supported")

while True:
try:
Stream = self.Clients[_Model].chat.completions.create(
messages=_Messages,
stream=True,
model=ProviderModel,
)
_Messages.append(self.StreamResponse(Stream, Provider))
break
except Exception as e:
if MaxRetries > 0:
_Logger.Log(
f"Exception During Generation '{e}', {MaxRetries} Retries Remaining",
7,
)
MaxRetries -= 1
else:
_Logger.Log(
"Max Retries Exceeded During Generation, Aborting!", 7
)
raise Exception(
"Generation Failed, Max Retries Exceeded, Aborting"
)

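A minimal sketch of what this retry loop does in isolation, assuming the openai>=1.0 SDK; the model name, prompt, and retry budget are hypothetical examples:

import openai

Client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
Retries = 3
while True:
    try:
        Stream = Client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "Say hello"}],
            stream=True,
        )
        break
    except Exception as e:
        if Retries > 0:
            Retries -= 1  # transient failure: retry with a fresh request
        else:
            raise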
elif Provider == "openrouter":

@@ -375,6 +427,8 @@ def StreamResponse(self, _Stream, _Provider: str):
ChunkText = chunk["message"]["content"]
elif _Provider == "google":
ChunkText = chunk.text
elif _Provider == "openai":
ChunkText = chunk.choices[0].delta.content or ""
else:
raise ValueError(f"Unsupported provider: {_Provider}")

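Continuing the sketch above, consuming the stream mirrors the new openai case added here: the SDK yields chunks whose choices[0].delta.content is None on role-only and final chunks, which is why the branch falls back to an empty string:

Text = ""
for chunk in Stream:
    Text += chunk.choices[0].delta.content or ""  # None on role-only/final chunks
print(Text)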
@@ -405,23 +459,25 @@ def GetModelAndProvider(self, _Model: str):
print(parsed)
Provider = parsed.scheme

if "@" in parsed.netloc:
Model, Host = parsed.netloc.split("@")
FullPath = parsed.netloc + parsed.path

if "@" in FullPath:
Model, Host = FullPath.split("@")

elif Provider == "openrouter":
Model = f"{parsed.netloc}{parsed.path}"
Model = FullPath
Host = None

elif "ollama" in _Model:
if "@" in parsed.path:
Model = parsed.netloc + parsed.path.split("@")[0]
Host = parsed.path.split("@")[1]
Model = FullPath.split("@")[0]
Host = FullPath.split("@")[1]
else:
Model = parsed.netloc
Model = FullPath
Host = "localhost:11434"

else:
Model = parsed.netloc
Model = FullPath
Host = None
QueryParams = parse_qs(parsed.query)

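The switch from parsed.netloc to FullPath = parsed.netloc + parsed.path matters because urlparse splits slashed model names across netloc and path. A small sketch, with hypothetical model strings in the format this function parses:

from urllib.parse import urlparse

parsed = urlparse("openrouter://meta-llama/llama-3.1-70b-instruct")
print(parsed.netloc)  # "meta-llama"
print(parsed.path)    # "/llama-3.1-70b-instruct"

parsed = urlparse("ollama://llama3:70b@192.168.1.100:11434")
FullPath = parsed.netloc + parsed.path
print(FullPath.split("@"))  # ['llama3:70b', '192.168.1.100:11434']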
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
ollama
termcolor
google.generativeai
dotenv-python
python-dotenv