From 74d7be46160fb389215d55308adf8367126d9d5b Mon Sep 17 00:00:00 2001 From: Yiannis Charalambous Date: Tue, 4 Feb 2025 14:05:52 +0000 Subject: [PATCH] Updated api_keys to be a dict --- esbmc_ai/ai_models.py | 28 +++++++++++----------------- esbmc_ai/config.py | 38 +++++++++++++++++++++++--------------- 2 files changed, 34 insertions(+), 32 deletions(-) diff --git a/esbmc_ai/ai_models.py b/esbmc_ai/ai_models.py index a1f66c3..476a19a 100644 --- a/esbmc_ai/ai_models.py +++ b/esbmc_ai/ai_models.py @@ -1,7 +1,7 @@ # Author: Yiannis Charalambous from abc import abstractmethod -from typing import Any, Iterable, Optional, Union +from typing import Any, Iterable, Union from enum import Enum from langchain_core.language_models import BaseChatModel from pydantic.types import SecretStr @@ -17,9 +17,6 @@ ) -from esbmc_ai.api_key_collection import APIKeyCollection - - class AIModel(object): """This base class represents an abstract AI model.""" @@ -37,7 +34,7 @@ def __init__( @abstractmethod def create_llm( self, - api_keys: APIKeyCollection, + api_keys: dict[str, str], temperature: float = 1.0, requests_max_tries: int = 5, requests_timeout: float = 60, @@ -143,15 +140,15 @@ class AIModelOpenAI(AIModel): @override def create_llm( self, - api_keys: APIKeyCollection, + api_keys: dict[str, str], temperature: float = 1.0, requests_max_tries: int = 5, requests_timeout: float = 60, ) -> BaseChatModel: - assert api_keys.openai, "No OpenAI api key has been specified..." + assert "openai" in api_keys, "No OpenAI api key has been specified..." 
return ChatOpenAI( model=self.name, - api_key=SecretStr(api_keys.openai), + api_key=SecretStr(api_keys["openai"]), max_tokens=None, temperature=temperature, max_retries=requests_max_tries, @@ -199,7 +196,7 @@ def __init__(self, name: str, tokens: int, url: str) -> None: @override def create_llm( self, - api_keys: APIKeyCollection, + api_keys: dict[str, str], temperature: float = 1, requests_max_tries: int = 5, requests_timeout: float = 60, @@ -222,7 +219,6 @@ class _AIModels(Enum): defined because they are fetched from the API.""" # FALCON_7B = OllamaAIModel(...) - pass _custom_ai_models: list[AIModel] = [] @@ -254,9 +250,9 @@ def add_custom_ai_model(ai_model: AIModel) -> None: _custom_ai_models.append(ai_model) -def download_openai_model_names(api_keys: APIKeyCollection) -> list[str]: +def download_openai_model_names(api_keys: dict[str, str]) -> list[str]: """Gets the open AI models from the API service and returns them.""" - assert api_keys and api_keys.openai + assert "openai" in api_keys from openai import Client "llm_requests.open_ai.model_refresh_seconds" @@ -264,15 +260,13 @@ def download_openai_model_names(api_keys: APIKeyCollection) -> list[str]: try: return [ str(model.id) - for model in Client(api_key=api_keys.openai).models.list().data + for model in Client(api_key=api_keys["openai"]).models.list().data ] except ImportError: return [] -def is_valid_ai_model( - ai_model: Union[str, AIModel], api_keys: Optional[APIKeyCollection] = None -) -> bool: +def is_valid_ai_model(ai_model: Union[str, AIModel]) -> bool: """Accepts both the AIModel object and the name as parameter. 
It checks the openai servers to see if a model is defined on their servers, if not, then it checks the internally defined AI models list.""" @@ -301,4 +295,4 @@ def get_ai_model_by_name(name: str) -> AIModel: if name == custom_ai.name: return custom_ai - raise Exception(f'The AI "{name}" was not found...') + raise ValueError(f'The AI "{name}" was not found...') diff --git a/esbmc_ai/config.py b/esbmc_ai/config.py index 5374eaa..0c300a0 100644 --- a/esbmc_ai/config.py +++ b/esbmc_ai/config.py @@ -33,7 +33,6 @@ OllamaAIModel, download_openai_model_names, ) -from .api_key_collection import APIKeyCollection @dataclass @@ -81,6 +80,14 @@ def __new__(cls): cls.instance = super(Config, cls).__new__(cls) return cls.instance + def __init__(self) -> None: + super().__init__() + self._args: argparse.Namespace + self.api_keys: dict[str, str] = {} + self.raw_conversation: bool = False + self.generate_patches: bool + self.output_dir: Optional[Path] = None + # Define some shortcuts for the values here (instead of having to use get_value) def get_ai_model(self) -> AIModel: @@ -92,7 +99,7 @@ def get_llm_requests_max_tries(self) -> int: return self.get_value("llm_requests.max_tries") def get_llm_requests_timeout(self) -> float: - """""" + """Max timeout for a request when prompting the LLM""" return self.get_value("llm_requests.timeout") def get_user_chat_initial(self) -> BaseMessage: @@ -120,11 +127,7 @@ def init(self, args: Any) -> None: """Will load the config from the args, the env file and then from config file. 
Call once to initialize.""" - self._args: argparse.Namespace = args - self.api_keys: APIKeyCollection - self.raw_conversation: bool = False - self.generate_patches: bool - self.output_dir: Optional[Path] = None + self._args = args self._load_envs() @@ -141,7 +144,7 @@ def init(self, args: Any) -> None: # Default is to refresh once a day default_value=self._load_openai_model_names(86400), validate=lambda v: isinstance(v, int), - on_load=lambda v: self._load_openai_model_names(v), + on_load=self._load_openai_model_names, error_message="Invalid value, needs to be an int in seconds", ), # This needs to be processed after ai_custom @@ -151,7 +154,7 @@ def init(self, args: Any) -> None: # Api keys are loaded from system env so they are already # available validate=lambda v: isinstance(v, str) and is_valid_ai_model(v), - on_load=lambda v: get_ai_model_by_name(v), + on_load=get_ai_model_by_name, ), ConfigField( name="temp_auto_clean", @@ -181,6 +184,12 @@ def init(self, args: Any) -> None: validate=lambda v: isinstance(v, str) and v in ["full", "single"], error_message="source_code_format can only be 'full' or 'single'", ), + # API Keys is a pseudo-entry, the value is fetched from the class + # itself rather than the config. 
+ ConfigField( + name="api_keys", + default_value=self.api_keys, + ), ConfigField( name="solution.filenames", default_value=[], @@ -279,7 +288,8 @@ def init(self, args: Any) -> None: name="fix_code.message_history", default_value="normal", validate=lambda v: v in ["normal", "latest_only", "reverse"], - error_message='fix_code.message_history can only be "normal", "latest_only", "reverse"', + error_message='fix_code.message_history can only be "normal", ' + + '"latest_only", "reverse"', ), ConfigField( name="prompt_templates.user_chat.initial", @@ -385,9 +395,7 @@ def get_env_vars() -> None: print(f"Error: No ${key} in environment.") sys.exit(1) - self.api_keys = APIKeyCollection( - openai=str(os.getenv("OPENAI_API_KEY")), - ) + self.api_keys["openai"] = str(os.getenv("OPENAI_API_KEY")) self.cfg_path: Path = Path( os.path.expanduser(os.path.expandvars(str(os.getenv(config_env_name)))) @@ -400,7 +408,7 @@ def _load_args(self) -> None: # AI Model -m if args.ai_model != "": - if is_valid_ai_model(args.ai_model, self.api_keys): + if is_valid_ai_model(args.ai_model): ai_model = get_ai_model_by_name(args.ai_model) self.set_value("ai_model", ai_model) else: @@ -527,7 +535,7 @@ def write_cache(cache: Path) -> list[str]: return models_list duration = timedelta(seconds=refresh_duration_seconds) - if self.api_keys and self.api_keys.openai: + if "openai" in self.api_keys: print("Loading OpenAI models list") models_list: list[str] = [] cache: Path = Path(user_cache_dir("esbmc-ai", "Yiannis Charalambous"))