diff --git a/libs/partners/openai/langchain_openai/chat_models/_client_utils.py b/libs/partners/openai/langchain_openai/chat_models/_client_utils.py
index 3eba5c7309be5..4a0efce9e1bfe 100644
--- a/libs/partners/openai/langchain_openai/chat_models/_client_utils.py
+++ b/libs/partners/openai/langchain_openai/chat_models/_client_utils.py
@@ -9,11 +9,14 @@
 from __future__ import annotations
 
 import asyncio
+import inspect
 import os
+from collections.abc import Awaitable, Callable
 from functools import lru_cache
-from typing import Any
+from typing import Any, cast
 
 import openai
+from pydantic import SecretStr
 
 
 class _SyncHttpxClientWrapper(openai.DefaultHttpxClient):
@@ -107,3 +110,33 @@ def _get_default_async_httpx_client(
         return _build_async_httpx_client(base_url, timeout)
     else:
         return _cached_async_httpx_client(base_url, timeout)
+
+
+def _resolve_sync_and_async_api_keys(
+    api_key: SecretStr | Callable[[], str] | Callable[[], Awaitable[str]],
+) -> tuple[str | None | Callable[[], str], str | Callable[[], Awaitable[str]]]:
+    """Resolve sync and async API key values.
+
+    Because OpenAI and AsyncOpenAI clients support either sync or async callables for
+    the API key, we need to resolve separate values here.
+    """
+    if isinstance(api_key, SecretStr):
+        sync_api_key_value: str | None | Callable[[], str] = api_key.get_secret_value()
+        async_api_key_value: str | Callable[[], Awaitable[str]] = (
+            api_key.get_secret_value()
+        )
+    elif callable(api_key):
+        if inspect.iscoroutinefunction(api_key):
+            async_api_key_value = api_key
+            sync_api_key_value = None
+        else:
+            sync_api_key_value = cast(Callable, api_key)
+
+            async def async_api_key_wrapper() -> str:
+                return await asyncio.get_running_loop().run_in_executor(
+                    None, cast(Callable, api_key)
+                )
+
+            async_api_key_value = async_api_key_wrapper
+
+    return sync_api_key_value, async_api_key_value
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index a3eee17aa5899..8303242ebbc8c 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -10,7 +10,14 @@
 import ssl
 import sys
 import warnings
-from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
+from collections.abc import (
+    AsyncIterator,
+    Awaitable,
+    Callable,
+    Iterator,
+    Mapping,
+    Sequence,
+)
 from functools import partial
 from io import BytesIO
 from json import JSONDecodeError
@@ -109,6 +116,7 @@
 from langchain_openai.chat_models._client_utils import (
     _get_default_async_httpx_client,
     _get_default_httpx_client,
+    _resolve_sync_and_async_api_keys,
 )
 from langchain_openai.chat_models._compat import (
     _convert_from_v1_to_chat_completions,
@@ -465,9 +473,64 @@ class BaseChatOpenAI(BaseChatModel):
     """What sampling temperature to use."""
     model_kwargs: dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
-    openai_api_key: SecretStr | None = Field(
+    openai_api_key: (
+        SecretStr | None | Callable[[], str] | Callable[[], Awaitable[str]]
+    ) = Field(
         alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
     )
+    """API key to use.
+
+    Can be inferred from the `OPENAI_API_KEY` environment variable, specified as a
+    string, or provided as a sync or async callable that returns a string.
+
+    ??? example "Specify with environment variable"
+
+        ```bash
+        export OPENAI_API_KEY=...
+        ```
+
+        ```python
+        from langchain_openai import ChatOpenAI
+
+        model = ChatOpenAI(model="gpt-5-nano")
+        ```
+
+    ??? example "Specify with a string"
+
+        ```python
+        from langchain_openai import ChatOpenAI
+
+        model = ChatOpenAI(model="gpt-5-nano", api_key="...")
+        ```
+
+    ??? example "Specify with a sync callable"
+
+        ```python
+        from langchain_openai import ChatOpenAI
+
+        def get_api_key() -> str:
+            # Custom logic to retrieve API key
+            return "..."
+
+        model = ChatOpenAI(model="gpt-5-nano", api_key=get_api_key)
+        ```
+
+    ??? example "Specify with an async callable"
+
+        ```python
+        from langchain_openai import ChatOpenAI
+
+        async def get_api_key() -> str:
+            # Custom async logic to retrieve API key
+            return "..."
+
+        model = ChatOpenAI(model="gpt-5-nano", api_key=get_api_key)
+        ```
+
+    !!! note
+        When an async callable is provided, only async methods (e.g. `ainvoke`,
+        `astream`) can be used; sync methods such as `invoke` will raise an error.
+    """
     openai_api_base: str | None = Field(default=None, alias="base_url")
     """Base URL path for API requests, leave blank if not using a proxy or service emulator."""  # noqa: E501
     openai_organization: str | None = Field(default=None, alias="organization")
@@ -776,10 +832,18 @@ def validate_environment(self) -> Self:
         ):
             self.stream_usage = True
 
+        # Resolve API key from SecretStr or Callable
+        sync_api_key_value: str | Callable[[], str] | None = None
+        async_api_key_value: str | Callable[[], Awaitable[str]] | None = None
+
+        if self.openai_api_key is not None:
+            # Because OpenAI and AsyncOpenAI clients support either sync or async
+            # callables for the API key, we need to resolve separate values here.
+            sync_api_key_value, async_api_key_value = _resolve_sync_and_async_api_keys(
+                self.openai_api_key
+            )
+
         client_params: dict = {
-            "api_key": (
-                self.openai_api_key.get_secret_value() if self.openai_api_key else None
-            ),
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
@@ -800,24 +864,33 @@ def validate_environment(self) -> Self:
             )
             raise ValueError(msg)
         if not self.client:
-            if self.openai_proxy and not self.http_client:
-                try:
-                    import httpx
-                except ImportError as e:
-                    msg = (
-                        "Could not import httpx python package. "
-                        "Please install it with `pip install httpx`."
+            if sync_api_key_value is None:
+                # No valid sync API key, leave client as None and raise informative
+                # error on invocation.
+                self.client = None
+                self.root_client = None
+            else:
+                if self.openai_proxy and not self.http_client:
+                    try:
+                        import httpx
+                    except ImportError as e:
+                        msg = (
+                            "Could not import httpx python package. "
+                            "Please install it with `pip install httpx`."
+ ) + raise ImportError(msg) from e + self.http_client = httpx.Client( + proxy=self.openai_proxy, verify=global_ssl_context ) - raise ImportError(msg) from e - self.http_client = httpx.Client( - proxy=self.openai_proxy, verify=global_ssl_context - ) - sync_specific = { - "http_client": self.http_client - or _get_default_httpx_client(self.openai_api_base, self.request_timeout) - } - self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type] - self.client = self.root_client.chat.completions + sync_specific = { + "http_client": self.http_client + or _get_default_httpx_client( + self.openai_api_base, self.request_timeout + ), + "api_key": sync_api_key_value, + } + self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type] + self.client = self.root_client.chat.completions if not self.async_client: if self.openai_proxy and not self.http_async_client: try: @@ -835,7 +908,8 @@ def validate_environment(self) -> Self: "http_client": self.http_async_client or _get_default_async_httpx_client( self.openai_api_base, self.request_timeout - ) + ), + "api_key": async_api_key_value, } self.root_async_client = openai.AsyncOpenAI( **client_params, @@ -965,6 +1039,16 @@ def _convert_chunk_to_generation_chunk( message=message_chunk, generation_info=generation_info or None ) + def _ensure_sync_client_available(self) -> None: + """Check that sync client is available, raise error if not.""" + if self.client is None: + msg = ( + "Sync client is not available. This happens when an async callable " + "was provided for the API key. Use async methods (ainvoke, astream) " + "instead, or provide a string or sync callable for the API key." + ) + raise ValueError(msg) + def _stream_responses( self, messages: list[BaseMessage], @@ -972,6 +1056,7 @@ def _stream_responses( run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: + self._ensure_sync_client_available() kwargs["stream"] = True payload = self._get_request_payload(messages, stop=stop, **kwargs) if self.include_response_headers: @@ -1101,6 +1186,7 @@ def _stream( stream_usage: bool | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: + self._ensure_sync_client_available() kwargs["stream"] = True stream_usage = self._should_stream_usage(stream_usage, **kwargs) if stream_usage: @@ -1169,6 +1255,7 @@ def _generate( run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: + self._ensure_sync_client_available() payload = self._get_request_payload(messages, stop=stop, **kwargs) generation_info = None raw_response = None diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index b0b5c9f770d75..f53640b02ed40 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -4,7 +4,7 @@ import logging import warnings -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Iterable, Mapping, Sequence from typing import Any, Literal, cast import openai @@ -15,6 +15,8 @@ from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator from typing_extensions import Self +from langchain_openai.chat_models._client_utils import _resolve_sync_and_async_api_keys + logger = logging.getLogger(__name__) @@ -189,7 +191,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) embedding_ctx_length: int = 8191 """The maximum number of 
tokens to embed at once."""
-    openai_api_key: SecretStr | None = Field(
+    openai_api_key: (
+        SecretStr | None | Callable[[], str] | Callable[[], Awaitable[str]]
+    ) = Field(
         alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None)
     )
     """Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
@@ -292,10 +296,19 @@ def validate_environment(self) -> Self:
                 "If you are using Azure, please use the `AzureOpenAIEmbeddings` class."
             )
             raise ValueError(msg)
+
+        # Resolve API key from SecretStr or Callable
+        sync_api_key_value: str | Callable[[], str] | None = None
+        async_api_key_value: str | Callable[[], Awaitable[str]] | None = None
+
+        if self.openai_api_key is not None:
+            # Because OpenAI and AsyncOpenAI clients support either sync or async
+            # callables for the API key, we need to resolve separate values here.
+            sync_api_key_value, async_api_key_value = _resolve_sync_and_async_api_keys(
+                self.openai_api_key
+            )
+
         client_params: dict = {
-            "api_key": (
-                self.openai_api_key.get_secret_value() if self.openai_api_key else None
-            ),
             "organization": self.openai_organization,
             "base_url": self.openai_api_base,
             "timeout": self.request_timeout,
@@ -315,18 +328,26 @@ def validate_environment(self) -> Self:
             )
             raise ValueError(msg)
         if not self.client:
-            if self.openai_proxy and not self.http_client:
-                try:
-                    import httpx
-                except ImportError as e:
-                    msg = (
-                        "Could not import httpx python package. "
-                        "Please install it with `pip install httpx`."
-                    )
-                    raise ImportError(msg) from e
-                self.http_client = httpx.Client(proxy=self.openai_proxy)
-            sync_specific = {"http_client": self.http_client}
-            self.client = openai.OpenAI(**client_params, **sync_specific).embeddings  # type: ignore[arg-type]
+            if sync_api_key_value is None:
+                # No valid sync API key, leave client as None and raise informative
+                # error on invocation.
+                self.client = None
+            else:
+                if self.openai_proxy and not self.http_client:
+                    try:
+                        import httpx
+                    except ImportError as e:
+                        msg = (
+                            "Could not import httpx python package. "
+                            "Please install it with `pip install httpx`."
+                        )
+                        raise ImportError(msg) from e
+                    self.http_client = httpx.Client(proxy=self.openai_proxy)
+                sync_specific = {
+                    "http_client": self.http_client,
+                    "api_key": sync_api_key_value,
+                }
+                self.client = openai.OpenAI(**client_params, **sync_specific).embeddings  # type: ignore[arg-type]
         if not self.async_client:
             if self.openai_proxy and not self.http_async_client:
                 try:
@@ -338,7 +359,10 @@ def validate_environment(self) -> Self:
                     )
                     raise ImportError(msg) from e
                 self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
-            async_specific = {"http_client": self.http_async_client}
+            async_specific = {
+                "http_client": self.http_async_client,
+                "api_key": async_api_key_value,
+            }
             self.async_client = openai.AsyncOpenAI(
                 **client_params,
                 **async_specific,  # type: ignore[arg-type]
@@ -352,6 +376,17 @@ def _invocation_params(self) -> dict[str, Any]:
             params["dimensions"] = self.dimensions
         return params
 
+    def _ensure_sync_client_available(self) -> None:
+        """Check that sync client is available, raise error if not."""
+        if self.client is None:
+            msg = (
+                "Sync client is not available. This happens when an async callable "
+                "was provided for the API key. Use async methods (aembed_documents, "
+                "aembed_query) instead, or provide a string or sync callable for "
+                "the API key."
+ ) + raise ValueError(msg) + def _tokenize( self, texts: list[str], chunk_size: int ) -> tuple[Iterable[int], list[list[int] | str], list[int]]: @@ -571,6 +605,7 @@ def embed_documents( Returns: List of embeddings, one for each text. """ + self._ensure_sync_client_available() chunk_size_ = chunk_size or self.chunk_size client_kwargs = {**self._invocation_params, **kwargs} if not self.check_embedding_ctx_length: @@ -635,6 +670,7 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]: Returns: Embedding for the text. """ + self._ensure_sync_client_available() return self.embed_documents([text], **kwargs)[0] async def aembed_query(self, text: str, **kwargs: Any) -> list[float]: diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index ed456ec89df39..f983f3982390d 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -4,7 +4,7 @@ import logging import sys -from collections.abc import AsyncIterator, Collection, Iterator, Mapping +from collections.abc import AsyncIterator, Callable, Collection, Iterator, Mapping from typing import Any, Literal import openai @@ -186,7 +186,7 @@ class BaseOpenAI(BaseLLM): """Generates best_of completions server-side and returns the "best".""" model_kwargs: dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - openai_api_key: SecretStr | None = Field( + openai_api_key: SecretStr | None | Callable[[], str] = Field( alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) ) """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" @@ -276,10 +276,16 @@ def validate_environment(self) -> Self: msg = "Cannot stream results when best_of > 1." 
raise ValueError(msg) + # Resolve API key from SecretStr or Callable + api_key_value: str | Callable[[], str] | None = None + if self.openai_api_key is not None: + if isinstance(self.openai_api_key, SecretStr): + api_key_value = self.openai_api_key.get_secret_value() + elif callable(self.openai_api_key): + api_key_value = self.openai_api_key + client_params: dict = { - "api_key": ( - self.openai_api_key.get_secret_value() if self.openai_api_key else None - ), + "api_key": api_key_value, "organization": self.openai_organization, "base_url": self.openai_api_base, "timeout": self.request_timeout, diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 945e83124d324..6c0193de58f94 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -2,6 +2,7 @@ import base64 import json +import os from collections.abc import AsyncIterator from pathlib import Path from textwrap import dedent @@ -64,6 +65,57 @@ def test_chat_openai_model() -> None: assert chat.model_name == "bar" +def test_callable_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + original_key = os.environ["OPENAI_API_KEY"] + + calls = {"sync": 0} + + def get_openai_api_key() -> str: + calls["sync"] += 1 + return original_key + + monkeypatch.delenv("OPENAI_API_KEY") + + model = ChatOpenAI(model="gpt-4.1-mini", api_key=get_openai_api_key) + response = model.invoke("hello") + assert isinstance(response, AIMessage) + assert calls["sync"] == 1 + + +async def test_callable_api_key_async(monkeypatch: pytest.MonkeyPatch) -> None: + original_key = os.environ["OPENAI_API_KEY"] + + calls = {"sync": 0, "async": 0} + + def get_openai_api_key() -> str: + calls["sync"] += 1 + return original_key + + async def get_openai_api_key_async() -> str: + calls["async"] += 1 + return original_key + + monkeypatch.delenv("OPENAI_API_KEY") + + model = ChatOpenAI(model="gpt-4.1-mini", api_key=get_openai_api_key) + response = model.invoke("hello") + assert isinstance(response, AIMessage) + assert calls["sync"] == 1 + + response = await model.ainvoke("hello") + assert isinstance(response, AIMessage) + assert calls["sync"] == 2 + + model = ChatOpenAI(model="gpt-4.1-mini", api_key=get_openai_api_key_async) + async_response = await model.ainvoke("hello") + assert isinstance(async_response, AIMessage) + assert calls["async"] == 1 + + with pytest.raises(ValueError): + # We do not create a sync callable from an async one + _ = model.invoke("hello") + + @pytest.mark.parametrize("use_responses_api", [False, True]) def test_chat_openai_system_message(use_responses_api: bool) -> None: """Test ChatOpenAI wrapper with system message.""" diff --git a/libs/partners/openai/tests/integration_tests/embeddings/test_base.py b/libs/partners/openai/tests/integration_tests/embeddings/test_base.py index 321edcfc0fb82..95a0385945581 100644 --- a/libs/partners/openai/tests/integration_tests/embeddings/test_base.py +++ b/libs/partners/openai/tests/integration_tests/embeddings/test_base.py @@ -1,7 +1,10 @@ """Test OpenAI embeddings.""" +import os + import numpy as np import openai +import pytest from langchain_openai.embeddings.base import OpenAIEmbeddings @@ -67,3 +70,56 @@ def test_langchain_openai_embeddings_dimensions_large_num() -> None: output = embedding.embed_documents(documents) assert len(output) == 2000 assert len(output[0]) == 128 + + +def test_callable_api_key(monkeypatch: 
pytest.MonkeyPatch) -> None: + original_key = os.environ["OPENAI_API_KEY"] + + calls = {"sync": 0} + + def get_openai_api_key() -> str: + calls["sync"] += 1 + return original_key + + monkeypatch.delenv("OPENAI_API_KEY") + + model = OpenAIEmbeddings( + model="text-embedding-3-small", dimensions=128, api_key=get_openai_api_key + ) + _ = model.embed_query("hello") + assert calls["sync"] == 1 + + +async def test_callable_api_key_async(monkeypatch: pytest.MonkeyPatch) -> None: + original_key = os.environ["OPENAI_API_KEY"] + + calls = {"sync": 0, "async": 0} + + def get_openai_api_key() -> str: + calls["sync"] += 1 + return original_key + + async def get_openai_api_key_async() -> str: + calls["async"] += 1 + return original_key + + monkeypatch.delenv("OPENAI_API_KEY") + + model = OpenAIEmbeddings( + model="text-embedding-3-small", dimensions=128, api_key=get_openai_api_key + ) + _ = model.embed_query("hello") + assert calls["sync"] == 1 + + _ = await model.aembed_query("hello") + assert calls["sync"] == 2 + + model = OpenAIEmbeddings( + model="text-embedding-3-small", dimensions=128, api_key=get_openai_api_key_async + ) + _ = await model.aembed_query("hello") + assert calls["async"] == 1 + + with pytest.raises(ValueError): + # We do not create a sync callable from an async one + _ = model.embed_query("hello") diff --git a/libs/partners/openai/tests/unit_tests/test_secrets.py b/libs/partners/openai/tests/unit_tests/test_secrets.py index aa1484058e0e2..27d69bed92ce1 100644 --- a/libs/partners/openai/tests/unit_tests/test_secrets.py +++ b/libs/partners/openai/tests/unit_tests/test_secrets.py @@ -187,6 +187,18 @@ def test_openai_uses_actual_secret_value_from_secretstr(model_class: type) -> No assert cast(SecretStr, model.openai_api_key).get_secret_value() == "secret-api-key" +@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings]) +def test_openai_api_key_accepts_callable(model_class: type) -> None: + """Test that the API key can be passed as a callable.""" + + def get_api_key() -> str: + return "secret-api-key-from-callable" + + model = model_class(openai_api_key=get_api_key) + assert callable(model.openai_api_key) + assert model.openai_api_key() == "secret-api-key-from-callable" + + @pytest.mark.parametrize("model_class", [AzureChatOpenAI, AzureOpenAI]) def test_azure_serialized_secrets(model_class: type) -> None: """Test that the actual secret value is correctly retrieved."""
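
---

For reference, here is a minimal sketch of how the callable API key path added by this diff is meant to be exercised once it is applied. The `get_key` and `get_key_async` helpers are hypothetical stand-ins for real credential logic (a vault lookup, a token refresh, etc.), and the placeholder key must be replaced with a valid one for the calls to succeed:

```python
import asyncio

from langchain_openai import ChatOpenAI


def get_key() -> str:
    # Hypothetical stand-in for real credential logic; must return a
    # valid OpenAI API key.
    return "sk-..."


async def get_key_async() -> str:
    # Async variant of the same lookup.
    return "sk-..."


# Sync callable: both sync and async methods work, because
# _resolve_sync_and_async_api_keys wraps the sync callable in an async
# wrapper for the AsyncOpenAI client.
model = ChatOpenAI(model="gpt-4.1-mini", api_key=get_key)
model.invoke("hello")
asyncio.run(model.ainvoke("hello"))

# Async callable: the sync client is left as None, so sync methods raise
# a ValueError; only async methods such as ainvoke/astream are usable.
async_model = ChatOpenAI(model="gpt-4.1-mini", api_key=get_key_async)
asyncio.run(async_model.ainvoke("hello"))
```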