Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openai[patch]: instantiate clients lazily #29926

Closed
wants to merge 15 commits into from
12 changes: 1 addition & 11 deletions libs/partners/deepseek/langchain_deepseek/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def validate_environment(self) -> Self:
self.api_key and self.api_key.get_secret_value()
):
raise ValueError("If using default api base, DEEPSEEK_API_KEY must be set.")
client_params: dict = {
self._client_params: dict = {
k: v
for k, v in {
"api_key": self.api_key.get_secret_value() if self.api_key else None,
Expand All @@ -191,16 +191,6 @@ def validate_environment(self) -> Self:
if v is not None
}

if not (self.client or None):
sync_specific: dict = {"http_client": self.http_client}
self.client = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
if not (self.async_client or None):
async_specific: dict = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI(
**client_params, **async_specific
).chat.completions
return self

def _create_chat_result(
Expand Down
48 changes: 32 additions & 16 deletions libs/partners/openai/langchain_openai/chat_models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def validate_environment(self) -> Self:
"Or you can equivalently specify:\n\n"
'base_url="https://xxx.openai.azure.com/openai/deployments/my-deployment"'
)
client_params: dict = {
self._client_params: dict = {
"api_version": self.openai_api_version,
"azure_endpoint": self.azure_endpoint,
"azure_deployment": self.deployment_name,
Expand All @@ -650,26 +650,42 @@ def validate_environment(self) -> Self:
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries
self._client_params["max_retries"] = self.max_retries

if not self.client:
if self.azure_ad_async_token_provider:
self._client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)

return self

@property
def root_client(self) -> Any:
if self._root_client is None:
sync_specific = {"http_client": self.http_client}
self.root_client = openai.AzureOpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
self.client = self.root_client.chat.completions
if not self.async_client:
async_specific = {"http_client": self.http_async_client}
self._root_client = openai.AzureOpenAI(
**self._client_params,
**sync_specific, # type: ignore[call-overload]
)
return self._root_client

if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)
    @root_client.setter
    def root_client(self, value: openai.AzureOpenAI) -> None:
        # Store an externally supplied sync client, overriding lazy construction.
        self._root_client = value

self.root_async_client = openai.AsyncAzureOpenAI(
**client_params,
**async_specific, # type: ignore[arg-type]
@property
def root_async_client(self) -> Any:
if self._root_async_client is None:
async_specific = {"http_client": self.http_async_client}
self._root_async_client = openai.AsyncAzureOpenAI(
**self._client_params,
**async_specific, # type: ignore[call-overload]
)
self.async_client = self.root_async_client.chat.completions
return self
return self._root_async_client

    @root_async_client.setter
    def root_async_client(self, value: openai.AsyncAzureOpenAI) -> None:
        # Store an externally supplied async client, overriding lazy construction.
        self._root_async_client = value

@property
def _identifying_params(self) -> Dict[str, Any]:
Expand Down
189 changes: 151 additions & 38 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import json
import logging
import os
import ssl
import sys
import warnings
from io import BytesIO
from math import ceil
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Expand All @@ -31,6 +33,7 @@
)
from urllib.parse import urlparse

import certifi
import openai
import tiktoken
from langchain_core._api.deprecation import deprecated
Expand Down Expand Up @@ -91,12 +94,25 @@
is_basemodel_subclass,
)
from langchain_core.utils.utils import _build_model_kwargs, from_env, secret_from_env
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
SecretStr,
model_validator,
)
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import Self

if TYPE_CHECKING:
import httpx

logger = logging.getLogger(__name__)

# This SSL context is equivalent to the default `verify=True`.
global_ssl_context = ssl.create_default_context(cafile=certifi.where())


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Expand Down Expand Up @@ -385,10 +401,10 @@ class _AllReturnType(TypedDict):


class BaseChatOpenAI(BaseChatModel):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
root_client: Any = Field(default=None, exclude=True) #: :meta private:
root_async_client: Any = Field(default=None, exclude=True) #: :meta private:
_client: Any = PrivateAttr(default=None) #: :meta private:
_async_client: Any = PrivateAttr(default=None) #: :meta private:
_root_client: Any = PrivateAttr(default=None) #: :meta private:
_root_async_client: Any = PrivateAttr(default=None) #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""
temperature: Optional[float] = None
Expand Down Expand Up @@ -460,11 +476,11 @@ class BaseChatOpenAI(BaseChatModel):
default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = Field(default=None, exclude=True)
_http_client: Union[Any, None] = PrivateAttr(default=None)
"""Optional httpx.Client. Only used for sync invocations. Must specify
http_async_client as well if you'd like a custom client for async invocations.
"""
http_async_client: Union[Any, None] = Field(default=None, exclude=True)
_http_async_client: Union[Any, None] = PrivateAttr(default=None)
"""Optional httpx.AsyncClient. Only used for async invocations. Must specify
http_client as well if you'd like a custom client for sync invocations."""
stop: Optional[Union[List[str], str]] = Field(default=None, alias="stop_sequences")
Expand All @@ -491,6 +507,7 @@ class BaseChatOpenAI(BaseChatModel):
However this does not prevent a user from directly passed in the parameter during
invocation.
"""
_client_params: Dict[str, Any] = PrivateAttr(default_factory=dict)

model_config = ConfigDict(populate_by_name=True)

Expand All @@ -511,6 +528,24 @@ def validate_temperature(cls, values: Dict[str, Any]) -> Any:
values["temperature"] = 1
return values

def __init__(
self,
client: Optional[Any] = None,
async_client: Optional[Any] = None,
root_client: Optional[Any] = None,
async_root_client: Optional[Any] = None,
http_client: Optional[Any] = None,
http_async_client: Optional[Any] = None,
Comment on lines +533 to +538
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ooc are these untyped on purpose?

**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._client = client
self._async_client = async_client
self._root_client = root_client
self._async_root_client = async_root_client
self._http_client = http_client
self._http_async_client = http_async_client

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
Expand All @@ -526,7 +561,7 @@ def validate_environment(self) -> Self:
or os.getenv("OPENAI_ORGANIZATION")
)
self.openai_api_base = self.openai_api_base or os.getenv("OPENAI_API_BASE")
client_params: dict = {
self._client_params: dict = {
"api_key": (
self.openai_api_key.get_secret_value() if self.openai_api_key else None
),
Expand All @@ -537,47 +572,122 @@ def validate_environment(self) -> Self:
"default_query": self.default_query,
}
if self.max_retries is not None:
client_params["max_retries"] = self.max_retries
self._client_params["max_retries"] = self.max_retries

if self.openai_proxy and (self.http_client or self.http_async_client):
if self.openai_proxy and (self._http_client or self._http_async_client):
openai_proxy = self.openai_proxy
http_client = self.http_client
http_async_client = self.http_async_client
http_client = self._http_client
http_async_client = self._http_async_client
raise ValueError(
"Cannot specify 'openai_proxy' if one of "
"'http_client'/'http_async_client' is already specified. Received:\n"
f"{openai_proxy=}\n{http_client=}\n{http_async_client=}"
)
if not self.client:
if self.openai_proxy and not self.http_client:
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
self.http_client = httpx.Client(proxy=self.openai_proxy)

return self

@property
def http_client(self) -> Optional[httpx.Client]:
"""Optional httpx.Client. Only used for sync invocations.

Must specify http_async_client as well if you'd like a custom client for
async invocations.
"""
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more
# details.
if self._http_client is None:
if not self.openai_proxy:
return None
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
self._http_client = httpx.Client(
proxy=self.openai_proxy, verify=global_ssl_context
)
return self._http_client

    @http_client.setter
    def http_client(self, value: Optional[httpx.Client]) -> None:
        # Store an externally supplied sync httpx client, overriding lazy
        # proxy-based construction.
        self._http_client = value

@property
def http_async_client(self) -> Optional[httpx.AsyncClient]:
"""Optional httpx.AsyncClient. Only used for async invocations.

Must specify http_client as well if you'd like a custom client for sync
invocations.
"""
if self._http_async_client is None:
if not self.openai_proxy:
return None
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
self._http_async_client = httpx.AsyncClient(
proxy=self.openai_proxy, verify=global_ssl_context
)
return self._http_async_client

    @http_async_client.setter
    def http_async_client(self, value: Optional[httpx.AsyncClient]) -> None:
        # Store an externally supplied async httpx client, overriding lazy
        # proxy-based construction.
        self._http_async_client = value

@property
def root_client(self) -> Any:
if self._root_client is None:
sync_specific = {"http_client": self.http_client}
self.root_client = openai.OpenAI(**client_params, **sync_specific) # type: ignore[arg-type]
self.client = self.root_client.chat.completions
if not self.async_client:
if self.openai_proxy and not self.http_async_client:
try:
import httpx
except ImportError as e:
raise ImportError(
"Could not import httpx python package. "
"Please install it with `pip install httpx`."
) from e
self.http_async_client = httpx.AsyncClient(proxy=self.openai_proxy)
self._root_client = openai.OpenAI(
**self._client_params,
**sync_specific, # type: ignore[arg-type]
)
return self._root_client

    @root_client.setter
    def root_client(self, value: openai.OpenAI) -> None:
        # Store an externally supplied sync client, overriding lazy construction.
        self._root_client = value

@property
def root_async_client(self) -> Any:
if self._root_async_client is None:
async_specific = {"http_client": self.http_async_client}
self.root_async_client = openai.AsyncOpenAI(
**client_params,
self._root_async_client = openai.AsyncOpenAI(
**self._client_params,
**async_specific, # type: ignore[arg-type]
)
self.async_client = self.root_async_client.chat.completions
return self
return self._root_async_client

    @root_async_client.setter
    def root_async_client(self, value: openai.AsyncOpenAI) -> None:
        # Store an externally supplied async client, overriding lazy construction.
        self._root_async_client = value

@property
def client(self) -> Any:
if self._client is None:
self._client = self.root_client.chat.completions
return self._client

    @client.setter
    def client(self, value: Any) -> None:
        # Inject a pre-built completions client, bypassing lazy derivation.
        self._client = value

@property
def async_client(self) -> Any:
if self._async_client is None:
self._async_client = self.root_async_client.chat.completions
return self._async_client

    @async_client.setter
    def async_client(self, value: Any) -> None:
        # Inject a pre-built async completions client, bypassing lazy derivation.
        self._async_client = value

@property
def _default_params(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -1961,6 +2071,9 @@ class Joke(BaseModel):
max_tokens: Optional[int] = Field(default=None, alias="max_completion_tokens")
"""Maximum number of tokens to generate."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # NOTE(review): no-op override — delegates straight to the parent
        # initializer. Presumably kept to pin the signature; confirm it is
        # still needed, otherwise it can be deleted.
        super().__init__(*args, **kwargs)

@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,16 @@ def test_openai_proxy() -> None:
assert proxy.host == b"localhost"
assert proxy.port == 8080

http_async_client = httpx.AsyncClient(proxy="http://localhost:8081")
chat_openai = ChatOpenAI(http_async_client=http_async_client)
mounts = chat_openai.async_client._client._client._mounts
assert len(mounts) == 1
for key, value in mounts.items():
proxy = value._pool._proxy_url.origin
assert proxy.scheme == b"http"
assert proxy.host == b"localhost"
assert proxy.port == 8081


def test_openai_response_headers() -> None:
"""Test ChatOpenAI response headers."""
Expand Down
Loading
Loading