Skip to content

Commit 21e7aae

Browse files
authored
Merge pull request #38 from bnarasimha21/track-usage-from-intagrations
feat(api): add user agent parameters to base_client and client classe…
2 parents 695cc57 + af7420c commit 21e7aae

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

src/gradient/_base_client.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,8 @@ def __init__(
376376
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
377377
custom_headers: Mapping[str, str] | None = None,
378378
custom_query: Mapping[str, object] | None = None,
379+
user_agent_package: str | None = None,
380+
user_agent_version: str | None = None,
379381
) -> None:
380382
self._version = version
381383
self._base_url = self._enforce_trailing_slash(URL(base_url))
@@ -386,6 +388,8 @@ def __init__(
386388
self._strict_response_validation = _strict_response_validation
387389
self._idempotency_header = None
388390
self._platform: Platform | None = None
391+
self._user_agent_package = user_agent_package
392+
self._user_agent_version = user_agent_version
389393

390394
if max_retries is None: # pyright: ignore[reportUnnecessaryComparison]
391395
raise TypeError(
@@ -671,7 +675,10 @@ def _validate_headers(
671675

672676
@property
673677
def user_agent(self) -> str:
674-
return f"{self.__class__.__name__}/Python/{self._version}"
678+
# Format: "Gradient/package/version"
679+
package = self._user_agent_package or "Python"
680+
version = self._user_agent_version if self._user_agent_package and self._user_agent_version else self._version
681+
return f"{self.__class__.__name__}/{package}/{version}"
675682

676683
@property
677684
def base_url(self) -> URL:
@@ -830,6 +837,8 @@ def __init__(
830837
custom_headers: Mapping[str, str] | None = None,
831838
custom_query: Mapping[str, object] | None = None,
832839
_strict_response_validation: bool,
840+
user_agent_package: str | None = None,
841+
user_agent_version: str | None = None,
833842
) -> None:
834843
if not is_given(timeout):
835844
# if the user passed in a custom http client with a non-default
@@ -858,6 +867,8 @@ def __init__(
858867
custom_query=custom_query,
859868
custom_headers=custom_headers,
860869
_strict_response_validation=_strict_response_validation,
870+
user_agent_package=user_agent_package,
871+
user_agent_version=user_agent_version,
861872
)
862873
self._client = http_client or SyncHttpxClientWrapper(
863874
base_url=base_url,
@@ -1360,6 +1371,8 @@ def __init__(
13601371
http_client: httpx.AsyncClient | None = None,
13611372
custom_headers: Mapping[str, str] | None = None,
13621373
custom_query: Mapping[str, object] | None = None,
1374+
user_agent_package: str | None = None,
1375+
user_agent_version: str | None = None,
13631376
) -> None:
13641377
if not is_given(timeout):
13651378
# if the user passed in a custom http client with a non-default
@@ -1388,6 +1401,8 @@ def __init__(
13881401
custom_query=custom_query,
13891402
custom_headers=custom_headers,
13901403
_strict_response_validation=_strict_response_validation,
1404+
user_agent_package=user_agent_package,
1405+
user_agent_version=user_agent_version,
13911406
)
13921407
self._client = http_client or AsyncHttpxClientWrapper(
13931408
base_url=base_url,

src/gradient/_client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ def __init__(
106106
# outlining your use-case to help us decide if it should be
107107
# part of our public interface in the future.
108108
_strict_response_validation: bool = False,
109+
# User agent tracking parameters
110+
user_agent_package: str | None = None,
111+
user_agent_version: str | None = None,
109112
) -> None:
110113
"""Construct a new synchronous Gradient client instance.
111114
@@ -169,6 +172,8 @@ def __init__(
169172
custom_headers=default_headers,
170173
custom_query=default_query,
171174
_strict_response_validation=_strict_response_validation,
175+
user_agent_package=user_agent_package,
176+
user_agent_version=user_agent_version,
172177
)
173178

174179
self._default_stream_cls = Stream
@@ -294,6 +299,8 @@ def copy(
294299
set_default_headers: Mapping[str, str] | None = None,
295300
default_query: Mapping[str, object] | None = None,
296301
set_default_query: Mapping[str, object] | None = None,
302+
user_agent_package: str | None = None,
303+
user_agent_version: str | None = None,
297304
_extra_kwargs: Mapping[str, Any] = {},
298305
) -> Self:
299306
"""
@@ -330,6 +337,8 @@ def copy(
330337
max_retries=max_retries if is_given(max_retries) else self.max_retries,
331338
default_headers=headers,
332339
default_query=params,
340+
user_agent_package=user_agent_package or self._user_agent_package,
341+
user_agent_version=user_agent_version or self._user_agent_version,
333342
**_extra_kwargs,
334343
)
335344
client._base_url_overridden = self._base_url_overridden or base_url is not None
@@ -410,6 +419,9 @@ def __init__(
410419
# outlining your use-case to help us decide if it should be
411420
# part of our public interface in the future.
412421
_strict_response_validation: bool = False,
422+
# User agent tracking parameters
423+
user_agent_package: str | None = None,
424+
user_agent_version: str | None = None,
413425
) -> None:
414426
"""Construct a new async AsyncGradient client instance.
415427
@@ -473,6 +485,8 @@ def __init__(
473485
custom_headers=default_headers,
474486
custom_query=default_query,
475487
_strict_response_validation=_strict_response_validation,
488+
user_agent_package=user_agent_package,
489+
user_agent_version=user_agent_version,
476490
)
477491

478492
self._default_stream_cls = AsyncStream
@@ -598,6 +612,8 @@ def copy(
598612
set_default_headers: Mapping[str, str] | None = None,
599613
default_query: Mapping[str, object] | None = None,
600614
set_default_query: Mapping[str, object] | None = None,
615+
user_agent_package: str | None = None,
616+
user_agent_version: str | None = None,
601617
_extra_kwargs: Mapping[str, Any] = {},
602618
) -> Self:
603619
"""
@@ -634,6 +650,8 @@ def copy(
634650
max_retries=max_retries if is_given(max_retries) else self.max_retries,
635651
default_headers=headers,
636652
default_query=params,
653+
user_agent_package=user_agent_package or self._user_agent_package,
654+
user_agent_version=user_agent_version or self._user_agent_version,
637655
**_extra_kwargs,
638656
)
639657
client._base_url_overridden = self._base_url_overridden or base_url is not None

0 commit comments

Comments
 (0)