Skip to content

Commit e55a6a6

Browse files
MarkDaoustcopybara-github
authored andcommitted
feat: Add credential support for live music.
PiperOrigin-RevId: 838331631
1 parent 99058b6 commit e55a6a6

File tree

7 files changed

+638
-144
lines changed

7 files changed

+638
-144
lines changed

google/genai/_api_client.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -679,10 +679,29 @@ def __init__(
679679
)
680680
self._http_options.api_version = 'v1beta1'
681681
else: # Implicit initialization or missing arguments.
682-
if not self.api_key:
682+
if env_api_key and api_key:
683+
# Explicit credentials take precedence over implicit api_key.
684+
logger.info(
685+
'The client initialiser api_key argument takes '
686+
'precedence over the API key from the environment variable.'
687+
)
688+
if credentials:
689+
if api_key:
690+
raise ValueError(
691+
'Credentials and API key are mutually exclusive in the client'
692+
' initializer.'
693+
)
694+
elif env_api_key:
695+
logger.info(
696+
'The user `credentials` argument will take precedence over the'
697+
' api key from the environment variables.'
698+
)
699+
self.api_key = None
700+
701+
if not self.api_key and not credentials:
683702
raise ValueError(
684703
'Missing key inputs argument! To use the Google AI API,'
685-
' provide (`api_key`) arguments. To use the Google Cloud API,'
704+
' provide (`api_key` or `credentials`) arguments. To use the Google Cloud API,'
686705
' provide (`vertexai`, `project` & `location`) arguments.'
687706
)
688707
self._http_options.base_url = 'https://generativelanguage.googleapis.com/'
@@ -1162,20 +1181,21 @@ def _request_once(
11621181
stream: bool = False,
11631182
) -> HttpResponse:
11641183
data: Optional[Union[str, bytes]] = None
1165-
# If using proj/location, fetch ADC
1166-
if self.vertexai and (self.project or self.location):
1184+
1185+
uses_vertex_creds = self.vertexai and (self.project or self.location)
1186+
uses_mldev_creds = not self.vertexai and self._credentials
1187+
if (uses_vertex_creds or uses_mldev_creds):
11671188
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
11681189
if self._credentials and self._credentials.quota_project_id:
11691190
http_request.headers['x-goog-user-project'] = (
11701191
self._credentials.quota_project_id
11711192
)
1172-
data = json.dumps(http_request.data) if http_request.data else None
1173-
else:
1174-
if http_request.data:
1175-
if not isinstance(http_request.data, bytes):
1176-
data = json.dumps(http_request.data) if http_request.data else None
1177-
else:
1178-
data = http_request.data
1193+
1194+
if http_request.data:
1195+
if not isinstance(http_request.data, bytes):
1196+
data = json.dumps(http_request.data) if http_request.data else None
1197+
else:
1198+
data = http_request.data
11791199

11801200
if stream:
11811201
httpx_request = self._httpx_client.build_request(
@@ -1228,22 +1248,22 @@ async def _async_request_once(
12281248
) -> HttpResponse:
12291249
data: Optional[Union[str, bytes]] = None
12301250

1231-
# If using proj/location, fetch ADC
1232-
if self.vertexai and (self.project or self.location):
1251+
uses_vertex_creds = self.vertexai and (self.project or self.location)
1252+
uses_mldev_creds = not self.vertexai and self._credentials
1253+
if (uses_vertex_creds or uses_mldev_creds):
12331254
http_request.headers['Authorization'] = (
12341255
f'Bearer {await self._async_access_token()}'
12351256
)
12361257
if self._credentials and self._credentials.quota_project_id:
12371258
http_request.headers['x-goog-user-project'] = (
12381259
self._credentials.quota_project_id
12391260
)
1240-
data = json.dumps(http_request.data) if http_request.data else None
1241-
else:
1242-
if http_request.data:
1243-
if not isinstance(http_request.data, bytes):
1244-
data = json.dumps(http_request.data) if http_request.data else None
1245-
else:
1246-
data = http_request.data
1261+
1262+
if http_request.data:
1263+
if not isinstance(http_request.data, bytes):
1264+
data = json.dumps(http_request.data) if http_request.data else None
1265+
else:
1266+
data = http_request.data
12471267

12481268
if stream:
12491269
if self._use_aiohttp():

google/genai/_extra_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,20 @@
1616
"""Extra utils depending on types that are shared between sync and async modules."""
1717

1818
import asyncio
19+
from collections.abc import Callable, MutableMapping
1920
import inspect
2021
import io
2122
import logging
2223
import sys
2324
import typing
24-
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
25+
from typing import Any, Optional, Union, get_args, get_origin
2526
import mimetypes
2627
import os
2728
import pydantic
2829

30+
import google.auth.transport.requests
31+
32+
2933
from . import _common
3034
from . import _mcp_utils
3135
from . import _transformers as t
@@ -674,3 +678,18 @@ def prepare_resumable_upload(
674678
http_options.headers = {}
675679
http_options.headers['X-Goog-Upload-File-Name'] = os.path.basename(file)
676680
return http_options, size_bytes, mime_type
681+
682+
683+
async def _maybe_update_and_insert_auth_token(
684+
headers:MutableMapping[str, str],
685+
creds: google.auth.credentials.Credentials) -> None:
686+
# Refresh credentials to ensure token is valid
687+
if not (creds.token and creds.valid):
688+
try:
689+
auth_req = google.auth.transport.requests.Request() # type: ignore[no-untyped-call]
690+
await asyncio.to_thread(creds.refresh, auth_req)
691+
except Exception as e:
692+
raise ConnectionError(f"Failed to refresh credentials") from e
693+
694+
if not headers.get('Authorization'):
695+
headers['Authorization'] = f'Bearer {creds.token}'

0 commit comments

Comments
 (0)