Skip to content

Commit 1dd0e14

Browse files
committed
feat(openai): add callable support for openai_api_key parameter
1 parent 0c8cbfb commit 1dd0e14

File tree

4 files changed

+60
-36
lines changed

4 files changed

+60
-36
lines changed

libs/partners/openai/langchain_openai/chat_models/base.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,7 @@
1616
from json import JSONDecodeError
1717
from math import ceil
1818
from operator import itemgetter
19-
from typing import (
20-
TYPE_CHECKING,
21-
Any,
22-
Literal,
23-
TypeAlias,
24-
TypeVar,
25-
cast,
26-
)
19+
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypeVar, cast
2720
from urllib.parse import urlparse
2821

2922
import certifi
@@ -34,10 +27,7 @@
3427
CallbackManagerForLLMRun,
3528
)
3629
from langchain_core.language_models import LanguageModelInput
37-
from langchain_core.language_models.chat_models import (
38-
BaseChatModel,
39-
LangSmithParams,
40-
)
30+
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
4131
from langchain_core.messages import (
4232
AIMessage,
4333
AIMessageChunk,
@@ -96,13 +86,7 @@
9686
is_basemodel_subclass,
9787
)
9888
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
99-
from pydantic import (
100-
BaseModel,
101-
ConfigDict,
102-
Field,
103-
SecretStr,
104-
model_validator,
105-
)
89+
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
10690
from pydantic.v1 import BaseModel as BaseModelV1
10791
from typing_extensions import Self
10892

@@ -465,8 +449,11 @@ class BaseChatOpenAI(BaseChatModel):
465449
"""What sampling temperature to use."""
466450
model_kwargs: dict[str, Any] = Field(default_factory=dict)
467451
"""Holds any model parameters valid for `create` call not explicitly specified."""
468-
openai_api_key: SecretStr | None = Field(
469-
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
452+
openai_api_key: SecretStr | None | Callable[[], str] = Field(
453+
alias="api_key",
454+
default_factory=secret_from_env(
455+
["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None
456+
),
470457
)
471458
openai_api_base: str | None = Field(default=None, alias="base_url")
472459
"""Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # noqa: E501
@@ -776,10 +763,16 @@ def validate_environment(self) -> Self:
776763
):
777764
self.stream_usage = True
778765

766+
# Resolve API key from SecretStr or Callable
767+
api_key_value = None
768+
if self.openai_api_key is not None:
769+
if isinstance(self.openai_api_key, SecretStr):
770+
api_key_value = self.openai_api_key.get_secret_value()
771+
elif callable(self.openai_api_key):
772+
api_key_value = self.openai_api_key()
773+
779774
client_params: dict = {
780-
"api_key": (
781-
self.openai_api_key.get_secret_value() if self.openai_api_key else None
782-
),
775+
"api_key": api_key_value,
783776
"organization": self.openai_organization,
784777
"base_url": self.openai_api_base,
785778
"timeout": self.request_timeout,

libs/partners/openai/langchain_openai/embeddings/base.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import logging
66
import warnings
7-
from collections.abc import Iterable, Mapping, Sequence
7+
from collections.abc import Callable, Iterable, Mapping, Sequence
88
from typing import Any, Literal, cast
99

1010
import openai
@@ -189,8 +189,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
189189
)
190190
embedding_ctx_length: int = 8191
191191
"""The maximum number of tokens to embed at once."""
192-
openai_api_key: SecretStr | None = Field(
193-
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
192+
openai_api_key: SecretStr | None | Callable[[], str] = Field(
193+
alias="api_key",
194+
default_factory=secret_from_env(
195+
["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None
196+
),
194197
)
195198
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
196199
openai_organization: str | None = Field(
@@ -292,10 +295,17 @@ def validate_environment(self) -> Self:
292295
"If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
293296
)
294297
raise ValueError(msg)
298+
299+
# Resolve API key from SecretStr or Callable
300+
api_key_value = None
301+
if self.openai_api_key is not None:
302+
if isinstance(self.openai_api_key, SecretStr):
303+
api_key_value = self.openai_api_key.get_secret_value()
304+
elif callable(self.openai_api_key):
305+
api_key_value = self.openai_api_key()
306+
295307
client_params: dict = {
296-
"api_key": (
297-
self.openai_api_key.get_secret_value() if self.openai_api_key else None
298-
),
308+
"api_key": api_key_value,
299309
"organization": self.openai_organization,
300310
"base_url": self.openai_api_base,
301311
"timeout": self.request_timeout,

libs/partners/openai/langchain_openai/llms/base.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import logging
66
import sys
7-
from collections.abc import AsyncIterator, Collection, Iterator, Mapping
7+
from collections.abc import AsyncIterator, Callable, Collection, Iterator, Mapping
88
from typing import Any, Literal
99

1010
import openai
@@ -186,8 +186,11 @@ class BaseOpenAI(BaseLLM):
186186
"""Generates best_of completions server-side and returns the "best"."""
187187
model_kwargs: dict[str, Any] = Field(default_factory=dict)
188188
"""Holds any model parameters valid for `create` call not explicitly specified."""
189-
openai_api_key: SecretStr | None = Field(
190-
alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
189+
openai_api_key: SecretStr | None | Callable[[], str] = Field(
190+
alias="api_key",
191+
default_factory=secret_from_env(
192+
["OPENAI_API_KEY", "AZURE_OPENAI_API_KEY"], default=None
193+
),
191194
)
192195
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
193196
openai_api_base: str | None = Field(
@@ -276,10 +279,16 @@ def validate_environment(self) -> Self:
276279
msg = "Cannot stream results when best_of > 1."
277280
raise ValueError(msg)
278281

282+
# Resolve API key from SecretStr or Callable
283+
api_key_value = None
284+
if self.openai_api_key is not None:
285+
if isinstance(self.openai_api_key, SecretStr):
286+
api_key_value = self.openai_api_key.get_secret_value()
287+
elif callable(self.openai_api_key):
288+
api_key_value = self.openai_api_key()
289+
279290
client_params: dict = {
280-
"api_key": (
281-
self.openai_api_key.get_secret_value() if self.openai_api_key else None
282-
),
291+
"api_key": api_key_value,
283292
"organization": self.openai_organization,
284293
"base_url": self.openai_api_base,
285294
"timeout": self.request_timeout,

libs/partners/openai/tests/unit_tests/test_secrets.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,18 @@ def test_openai_uses_actual_secret_value_from_secretstr(model_class: type) -> No
187187
assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key"
188188

189189

190+
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
191+
def test_openai_api_key_accepts_callable(model_class: type) -> None:
192+
"""Test that the API key can be passed as a callable."""
193+
194+
def get_api_key() -> str:
195+
return "secret-api-key-from-callable"
196+
197+
model = model_class(openai_api_key=get_api_key)
198+
assert callable(model.openai_api_key)
199+
assert model.openai_api_key() == "secret-api-key-from-callable"
200+
201+
190202
@pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI])
191203
def test_azure_serialized_secrets(model_class: type) -> None:
192204
"""Test that the actual secret value is correctly retrieved."""

0 commit comments

Comments
 (0)