Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support passing OPENAI base url w/ env variable #332

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/ragas/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ragas.exceptions import AzureOpenAIKeyNotFound, OpenAIKeyNotFound
from ragas.llms.base import RagasLLM
from ragas.llms.langchain import _compute_token_usage_langchain
from ragas.utils import NO_KEY, get_debug_mode
from ragas.utils import NO_KEY, NO_BASE_URL, get_debug_mode

if t.TYPE_CHECKING:
from langchain.callbacks.base import Callbacks
Expand Down Expand Up @@ -95,7 +95,7 @@ def _before_sleep(retry_state: RetryCallState) -> None:


class OpenAIBase(RagasLLM):
def __init__(self, model: str, _api_key_env_var: str, timeout: int = 60) -> None:
def __init__(self, model: str, _api_key_env_var: str, _base_url_env_var: str, timeout: int = 60) -> None:
self.model = model
self._api_key_env_var = _api_key_env_var
self.timeout = timeout
Expand All @@ -106,6 +106,12 @@ def __init__(self, model: str, _api_key_env_var: str, timeout: int = 60) -> None
self.api_key = key_from_env
else:
self.api_key = self.api_key

# base url
base_url_from_env = os.getenv(self._base_url_env_var, NO_BASE_URL)
if base_url_from_env != NO_BASE_URL:
self.base_url = base_url_from_env

self._client: AsyncClient

@abstractmethod
Expand Down Expand Up @@ -182,13 +188,14 @@ class OpenAI(OpenAIBase):
model: str = "gpt-3.5-turbo-16k"
api_key: str = field(default=NO_KEY, repr=False)
_api_key_env_var: str = "OPENAI_API_KEY"
_base_url_env_var: str = "OPENAI_API_BASE"

def __post_init__(self):
super().__init__(model=self.model, _api_key_env_var=self._api_key_env_var)
super().__init__(model=self.model, _api_key_env_var=self._api_key_env_var, _base_url_env_var=self._base_url_env_var)
self._client_init()

def _client_init(self):
self._client = AsyncOpenAI(api_key=self.api_key, timeout=self.timeout)
self._client = AsyncOpenAI(api_key=self.api_key, base_url = self.base_url, timeout=self.timeout)

def validate_api_key(self):
# before validating, check if the api key is already set
Expand Down
1 change: 1 addition & 0 deletions src/ragas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
DEBUG_ENV_VAR = "RAGAS_DEBUG"
# constant to tell us that there is no key passed to the llm/embeddings
NO_KEY = "no-key"
NO_BASE_URL = "no-base-url"


@lru_cache(maxsize=1)
Expand Down