diff --git a/Writer/Interface/Wrapper.py b/Writer/Interface/Wrapper.py
index 89b93035..14a19880 100644
--- a/Writer/Interface/Wrapper.py
+++ b/Writer/Interface/Wrapper.py
@@ -39,7 +39,9 @@ def LoadModels(self, Models: list):
                 Provider, ProviderModel, ModelHost, ModelOptions = (
                     self.GetModelAndProvider(Model)
                 )
-                print(f"DEBUG: Loading Model {ProviderModel} from {Provider}@{ModelHost}")
+                print(
+                    f"DEBUG: Loading Model {ProviderModel} from {Provider}@{ModelHost}"
+                )
 
                 if Provider == "ollama":
                     # Get ollama models (only once)
@@ -96,7 +98,30 @@ def LoadModels(self, Models: list):
                     )
 
                 elif Provider == "openai":
-                    raise NotImplementedError("OpenAI API not supported")
+                    # Validate OpenAI API Key
+                    if (
+                        not "OPENAI_API_KEY" in os.environ
+                        or os.environ["OPENAI_API_KEY"] == ""
+                    ):
+                        raise Exception(
+                            "OPENAI_API_KEY not found in environment variables. Add dummy if using local models"
+                        )
+                    self.ensure_package_is_installed("openai")
+                    import openai
+
+                    if ModelHost and not "://" in ModelHost:
+                        raise ValueError(
+                            f"Invalid Model Host URL {ModelHost}. Make sure to include the protocol (http/https)"
+                        )
+
+                    self.Clients[Model] = openai.OpenAI(
+                        api_key=os.environ["OPENAI_API_KEY"],
+                        base_url=(
+                            "https://api.openai.com/v1"
+                            if ModelHost is None
+                            else ModelHost
+                        ),
+                    )
 
                 elif Provider == "openrouter":
                     if (
@@ -126,23 +151,27 @@ def SafeGenerateText(
         _Model: str,
         _SeedOverride: int = -1,
         _Format: str = None,
-        _MinWordCount: int = 1
-        ):
+        _MinWordCount: int = 1,
+    ):
         """
         This function guarantees that the output will not be whitespace.
         """
 
-        NewMsg = self.ChatAndStreamResponse(_Logger, _Messages, _Model, _SeedOverride, _Format)
+        NewMsg = self.ChatAndStreamResponse(
+            _Logger, _Messages, _Model, _SeedOverride, _Format
+        )
 
-        while (self.GetLastMessageText(NewMsg).isspace()) or (len(self.GetLastMessageText(NewMsg).split(" ")) < _MinWordCount):
+        while (self.GetLastMessageText(NewMsg).isspace()) or (
+            len(self.GetLastMessageText(NewMsg).split(" ")) < _MinWordCount
+        ):
             _Logger.Log("Generation Failed, Reattempting Output", 7)
-            del _Messages[-1] # Remove failed attempt
-            NewMsg = self.ChatAndStreamResponse(_Logger, _Messages, _Model, random.randint(0, 99999), _Format)
+            del _Messages[-1]  # Remove failed attempt
+            NewMsg = self.ChatAndStreamResponse(
+                _Logger, _Messages, _Model, random.randint(0, 99999), _Format
+            )
 
         return NewMsg
 
-
-
     def ChatAndStreamResponse(
         self,
         _Logger,
@@ -310,7 +339,30 @@ def ChatAndStreamResponse(
                     m["role"] = "assistant"
 
         elif Provider == "openai":
-            raise NotImplementedError("OpenAI API not supported")
+
+            while True:
+                try:
+                    Stream = self.Clients[_Model].chat.completions.create(
+                        messages=_Messages,
+                        stream=True,
+                        model=ProviderModel,
+                    )
+                    _Messages.append(self.StreamResponse(Stream, Provider))
+                    break
+                except Exception as e:
+                    if MaxRetries > 0:
+                        _Logger.Log(
+                            f"Exception During Generation '{e}', {MaxRetries} Retries Remaining",
+                            7,
+                        )
+                        MaxRetries -= 1
+                    else:
+                        _Logger.Log(
+                            f"Max Retries Exceeded During Generation, Aborting!", 7
+                        )
+                        raise Exception(
+                            "Generation Failed, Max Retries Exceeded, Aborting"
+                        )
 
         elif Provider == "openrouter":
 
@@ -375,6 +427,8 @@ def StreamResponse(self, _Stream, _Provider: str):
                 ChunkText = chunk["message"]["content"]
             elif _Provider == "google":
                 ChunkText = chunk.text
+            elif _Provider == "openai":
+                ChunkText = chunk.choices[0].delta.content or ""
             else:
                 raise ValueError(f"Unsupported provider: {_Provider}")
 
@@ -405,23 +459,25 @@ def GetModelAndProvider(self, _Model: str):
             print(parsed)
             Provider = parsed.scheme
 
-            if "@" in parsed.netloc:
-                Model, Host = parsed.netloc.split("@")
+            FullPath = parsed.netloc + parsed.path
+
+            if "@" in FullPath:
+                Model, Host = FullPath.split("@")
 
             elif Provider == "openrouter":
-                Model = f"{parsed.netloc}{parsed.path}"
+                Model = FullPath
                 Host = None
 
             elif "ollama" in _Model:
                 if "@" in parsed.path:
-                    Model = parsed.netloc + parsed.path.split("@")[0]
-                    Host = parsed.path.split("@")[1]
+                    Model = FullPath.split("@")[0]
+                    Host = FullPath.split("@")[1]
                 else:
-                    Model = parsed.netloc
+                    Model = FullPath
                     Host = "localhost:11434"
 
             else:
-                Model = parsed.netloc
+                Model = FullPath
                 Host = None
 
             QueryParams = parse_qs(parsed.query)
diff --git a/requirements.txt b/requirements.txt
index 2b8ef125..c0965fc6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 ollama
 termcolor
 google.generativeai
-dotenv-python
\ No newline at end of file
+python-dotenv
\ No newline at end of file