diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index bd3071d39..6182a45e8 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -679,10 +679,29 @@ def __init__( ) self._http_options.api_version = 'v1beta1' else: # Implicit initialization or missing arguments. - if not self.api_key: + if env_api_key and api_key: + # Explicit credentials take precedence over implicit api_key. + logger.info( + 'The client initialiser api_key argument takes ' + 'precedence over the API key from the environment variable.' + ) + if credentials: + if api_key: + raise ValueError( + 'Credentials and API key are mutually exclusive in the client' + ' initializer.' + ) + elif env_api_key: + logger.info( + 'The user `credentials` argument will take precedence over the' + ' api key from the environment variables.' + ) + self.api_key = None + + if not self.api_key and not credentials: raise ValueError( 'Missing key inputs argument! To use the Google AI API,' - ' provide (`api_key`) arguments. To use the Google Cloud API,' + ' provide (`api_key` or `credentials`) arguments. To use the Google Cloud API,' ' provide (`vertexai`, `project` & `location`) arguments.' ) self._http_options.base_url = 'https://generativelanguage.googleapis.com/' @@ -1162,20 +1181,21 @@ def _request_once( stream: bool = False, ) -> HttpResponse: data: Optional[Union[str, bytes]] = None - # If using proj/location, fetch ADC - if self.vertexai and (self.project or self.location): + + uses_vertex_creds = self.vertexai and (self.project or self.location) + uses_mldev_creds = not self.vertexai and self._credentials + if (uses_vertex_creds or uses_mldev_creds): http_request.headers['Authorization'] = f'Bearer {self._access_token()}' if self._credentials and self._credentials.quota_project_id: http_request.headers['x-goog-user-project'] = ( self._credentials.quota_project_id ) - data = json.dumps(http_request.data) if http_request.data else None - else: - if http_request.data: - if not isinstance(http_request.data, bytes): - data = json.dumps(http_request.data) if http_request.data else None - else: - data = http_request.data + + if http_request.data: + if not isinstance(http_request.data, bytes): + data = json.dumps(http_request.data) if http_request.data else None + else: + data = http_request.data if stream: httpx_request = self._httpx_client.build_request( @@ -1228,8 +1248,9 @@ async def _async_request_once( ) -> HttpResponse: data: Optional[Union[str, bytes]] = None - # If using proj/location, fetch ADC - if self.vertexai and (self.project or self.location): + uses_vertex_creds = self.vertexai and (self.project or self.location) + uses_mldev_creds = not self.vertexai and self._credentials + if (uses_vertex_creds or uses_mldev_creds): http_request.headers['Authorization'] = ( f'Bearer {await self._async_access_token()}' ) @@ -1237,13 +1258,12 @@ async def _async_request_once( http_request.headers['x-goog-user-project'] = ( self._credentials.quota_project_id ) - data = json.dumps(http_request.data) if http_request.data else None - else: - if http_request.data: - if not isinstance(http_request.data, bytes): - data = json.dumps(http_request.data) if http_request.data else None - else: - data = http_request.data + + if http_request.data: + if not isinstance(http_request.data, bytes): + data = json.dumps(http_request.data) if http_request.data else None + else: + data = http_request.data if stream: if self._use_aiohttp(): diff --git a/google/genai/_extra_utils.py b/google/genai/_extra_utils.py index e0fb9c105..bbc70eed8 100644 --- a/google/genai/_extra_utils.py +++ b/google/genai/_extra_utils.py @@ -16,16 +16,20 @@ """Extra utils depending on types that are shared between sync and async modules.""" import asyncio +from collections.abc import Callable, MutableMapping import inspect import io import logging import sys import typing -from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin +from typing import Any, Optional, Union, get_args, get_origin import mimetypes import os import pydantic +import google.auth.transport.requests + + from . import _common from . import _mcp_utils from . import _transformers as t @@ -674,3 +678,18 @@ def prepare_resumable_upload( http_options.headers = {} http_options.headers['X-Goog-Upload-File-Name'] = os.path.basename(file) return http_options, size_bytes, mime_type + + +async def _maybe_update_and_insert_auth_token( + headers:MutableMapping[str, str], + creds: google.auth.credentials.Credentials) -> None: + # Refresh credentials to ensure token is valid + if not (creds.token and creds.valid): + try: + auth_req = google.auth.transport.requests.Request() # type: ignore[no-untyped-call] + await asyncio.to_thread(creds.refresh, auth_req) + except Exception as e: + raise ConnectionError(f"Failed to refresh credentials") from e + + if not headers.get('Authorization'): + headers['Authorization'] = f'Bearer {creds.token}' diff --git a/google/genai/live.py b/google/genai/live.py index 0d8453fc5..0c2b875e2 100644 --- a/google/genai/live.py +++ b/google/genai/live.py @@ -29,6 +29,7 @@ from websockets import ConnectionClosed from . import _api_module +from . import _extra_utils from . import _common from . import _live_converters as live_converters from . import _mcp_utils @@ -929,17 +930,95 @@ async def connect( base_url = self._api_client._websocket_base_url() if isinstance(base_url, bytes): base_url = base_url.decode('utf-8') - transformed_model = t.t_model(self._api_client, model) # type: ignore parameter_model = await _t_live_connect_config(self._api_client, config) - if self._api_client.api_key and not self._api_client.vertexai: - version = self._api_client._http_options.api_version - api_key = self._api_client.api_key - method = 'BidiGenerateContent' - original_headers = self._api_client._http_options.headers - headers = original_headers.copy() if original_headers is not None else {} + if self._api_client.vertexai: + uri, headers, request = await self._prepare_connection_vertex( + base_url=base_url, model=model, parameter_model=parameter_model + ) + else: + uri, headers, request = await self._prepare_connection_mldev( + base_url=base_url, model=model, parameter_model=parameter_model + ) + + if parameter_model.tools and _mcp_utils.has_mcp_tool_usage( + parameter_model.tools + ): + if headers is None: + headers = {} + _mcp_utils.set_mcp_usage_header(headers) + + async with ws_connect( + uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx + ) as ws: + await ws.send(request) + try: + # websockets 14.0+ + raw_response = await ws.recv(decode=False) + except TypeError: + raw_response = await ws.recv() # type: ignore[assignment] + if raw_response: + try: + response = json.loads(raw_response) + except json.decoder.JSONDecodeError as e: + raise ValueError(f'Failed to parse response: {raw_response!r}') from e + else: + response = {} + + if self._api_client.vertexai: + response_dict = live_converters._LiveServerMessage_from_vertex(response) + else: + response_dict = response + + setup_response = types.LiveServerMessage._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + if setup_response.setup_complete: + session_id = setup_response.setup_complete.session_id + else: + session_id = None + yield AsyncSession( + api_client=self._api_client, + websocket=ws, + session_id=session_id, + ) + + async def _prepare_connection_mldev( + self, *, + base_url: str, + model: str, + parameter_model: types.LiveConnectConfig, + ) -> tuple[str, _common.StringDict, str]: + """Prepares live connection parameters for the MLDev API. + + Constructs the WebSocket URI, headers, and request body necessary + to establish a connection with the MLDev backend. + + Args: + base_url: The base URL for the WebSocket connection. + model: The name of the model to use. + parameter_model: Configuration parameters for the connection. + + Returns: + A tuple containing: + - uri: The WebSocket connection URI. + - headers: A dictionary of headers for the connection. + - request: The JSON-serialized request body. + + Raises: + ValueError: If an API key is not provided. + """ + transformed_model = t.t_model(self._api_client, model) # type: ignore + version = self._api_client._http_options.api_version + method = 'BidiGenerateContent' + original_headers = self._api_client._http_options.headers + headers = original_headers.copy() if original_headers is not None else {} + + if api_key := self._api_client.api_key: if api_key.startswith('auth_tokens/'): + method = 'BidiGenerateContentConstrained' + headers['Authorization'] = f'Token {api_key}' warnings.warn( message=( "The SDK's ephemeral token support is experimental, and may" @@ -947,8 +1026,6 @@ async def connect( ), category=errors.ExperimentalWarning, ) - method = 'BidiGenerateContentConstrained' - headers['Authorization'] = f'Token {api_key}' if version != 'v1alpha': warnings.warn( message=( @@ -959,46 +1036,67 @@ async def connect( ), category=errors.ExperimentalWarning, ) - uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}' - - request_dict = _common.convert_to_dict( - live_converters._LiveConnectParameters_to_mldev( - api_client=self._api_client, - from_object=types.LiveConnectParameters( - model=transformed_model, - config=parameter_model, - ).model_dump(exclude_none=True), - ) - ) - del request_dict['config'] + elif creds := self._api_client._credentials: + await _extra_utils._maybe_update_and_insert_auth_token(headers, creds) + else: + # this shouldn't happen. + raise ValueError('Genai live connection requires credentials or API key provided.') + + uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}' + + request_dict = _common.convert_to_dict( + live_converters._LiveConnectParameters_to_mldev( + api_client=self._api_client, + from_object=types.LiveConnectParameters( + model=transformed_model, + config=parameter_model, + ).model_dump(exclude_none=True), + ) + ) + del request_dict['config'] - setv(request_dict, ['setup', 'model'], transformed_model) + setv(request_dict, ['setup', 'model'], transformed_model) - request = json.dumps(request_dict) - elif self._api_client.api_key and self._api_client.vertexai: - # Headers already contains api key for express mode. - api_key = self._api_client.api_key - version = self._api_client._http_options.api_version - uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent' - original_headers = self._api_client._http_options.headers - headers = original_headers.copy() if original_headers is not None else {} - - request_dict = _common.convert_to_dict( - live_converters._LiveConnectParameters_to_vertex( - api_client=self._api_client, - from_object=types.LiveConnectParameters( - model=transformed_model, - config=parameter_model, - ).model_dump(exclude_none=True), - ) - ) - del request_dict['config'] + return uri, headers, json.dumps(request_dict) + + + async def _prepare_connection_vertex( + self, *, + base_url: str, + model: str, + parameter_model: types.LiveConnectConfig, + ) -> tuple[str, _common.StringDict, str]: + """Prepares live connection parameters for the Vertex AI API. - setv(request_dict, ['setup', 'model'], transformed_model) + Constructs the WebSocket URI, headers, and request body necessary + to establish a connection with the Vertex AI backend. Handles + authentication using either an API key or default credentials. - request = json.dumps(request_dict) + Args: + base_url: The base URL for the WebSocket connection. + model: The name of the model to use. + parameter_model: Configuration parameters for the connection. + + Returns: + A tuple containing: + - uri: The WebSocket connection URI. + - headers: A dictionary of headers for the connection. + - request: The JSON-serialized request body. + + Raises: + ValueError: If project and location are not provided when + default credentials are used. + """ + transformed_model = t.t_model(self._api_client, model) # type: ignore + version = self._api_client._http_options.api_version + original_headers = self._api_client._http_options.headers + headers = ( + original_headers.copy() if original_headers is not None else {} + ) + if api_key := self._api_client.api_key: + # Headers already contains api key + uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent' else: - version = self._api_client._http_options.api_version has_sufficient_auth = ( self._api_client.project and self._api_client.location ) @@ -1026,17 +1124,8 @@ async def connect( creds = self._api_client._credentials # creds.valid is False, and creds.token is None # Need to refresh credentials to populate those - if not (creds.token and creds.valid): - auth_req = google.auth.transport.requests.Request() # type: ignore - creds.refresh(auth_req) - bearer_token = creds.token + await _extra_utils._maybe_update_and_insert_auth_token(headers, creds) - original_headers = self._api_client._http_options.headers - headers = ( - original_headers.copy() if original_headers is not None else {} - ) - if not headers.get('Authorization'): - headers['Authorization'] = f'Bearer {bearer_token}' location = self._api_client.location project = self._api_client.project @@ -1044,17 +1133,22 @@ async def connect( transformed_model = ( f'projects/{project}/locations/{location}/' + transformed_model ) - request_dict = _common.convert_to_dict( - live_converters._LiveConnectParameters_to_vertex( - api_client=self._api_client, - from_object=types.LiveConnectParameters( - model=transformed_model, - config=parameter_model, - ).model_dump(exclude_none=True), - ) - ) - del request_dict['config'] + request_dict = _common.convert_to_dict( + live_converters._LiveConnectParameters_to_vertex( + api_client=self._api_client, + from_object=types.LiveConnectParameters( + model=transformed_model, + config=parameter_model, + ).model_dump(exclude_none=True), + ) + ) + del request_dict['config'] + + if api_key is None: + # Refactor note: I'm surprised the two paths are different, you'd have + # to test every model to be sure. The goal of this refactor is to not + # change any behavior so leaving it as is. if ( getv( request_dict, ['setup', 'generationConfig', 'responseModalities'] @@ -1067,49 +1161,10 @@ async def connect( ['AUDIO'], ) - request = json.dumps(request_dict) + return uri, headers, json.dumps(request_dict) - if parameter_model.tools and _mcp_utils.has_mcp_tool_usage( - parameter_model.tools - ): - if headers is None: - headers = {} - _mcp_utils.set_mcp_usage_header(headers) - async with ws_connect( - uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx - ) as ws: - await ws.send(request) - try: - # websockets 14.0+ - raw_response = await ws.recv(decode=False) - except TypeError: - raw_response = await ws.recv() # type: ignore[assignment] - if raw_response: - try: - response = json.loads(raw_response) - except json.decoder.JSONDecodeError: - raise ValueError(f'Failed to parse response: {raw_response!r}') - else: - response = {} - if self._api_client.vertexai: - response_dict = live_converters._LiveServerMessage_from_vertex(response) - else: - response_dict = response - - setup_response = types.LiveServerMessage._from_response( - response=response_dict, kwargs=parameter_model.model_dump() - ) - if setup_response.setup_complete: - session_id = setup_response.setup_complete.session_id - else: - session_id = None - yield AsyncSession( - api_client=self._api_client, - websocket=ws, - session_id=session_id, - ) async def _t_live_connect_config( diff --git a/google/genai/live_music.py b/google/genai/live_music.py index 2f739d5b6..0bde96e52 100644 --- a/google/genai/live_music.py +++ b/google/genai/live_music.py @@ -21,6 +21,7 @@ from typing import AsyncIterator from . import _api_module +from . import _extra_utils from . import _common from . import _live_converters as live_converters from . import _transformers as t @@ -156,31 +157,42 @@ class AsyncLiveMusic(_api_module.BaseModule): @contextlib.asynccontextmanager async def connect(self, *, model: str) -> AsyncIterator[AsyncMusicSession]: """[Experimental] Connect to the live music server.""" + if self._api_client.vertexai: + raise NotImplementedError('Live music generation is not supported in Vertex AI.') + base_url = self._api_client._websocket_base_url() if isinstance(base_url, bytes): base_url = base_url.decode('utf-8') transformed_model = t.t_model(self._api_client, model) + version = self._api_client._http_options.api_version + original_headers = self._api_client._http_options.headers + headers = original_headers.copy() if original_headers is not None else {} + if self._api_client.api_key: - api_key = self._api_client.api_key - version = self._api_client._http_options.api_version - uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic?key={api_key}' - headers = self._api_client._http_options.headers - - # Only mldev supported - request_dict = _common.convert_to_dict( - live_converters._LiveMusicConnectParameters_to_mldev( - from_object=types.LiveMusicConnectParameters( - model=transformed_model, - ).model_dump(exclude_none=True) - ) - ) + # API key is already included in headers. + pass + elif creds := self._api_client._credentials: + await _extra_utils._maybe_update_and_insert_auth_token(headers, creds) + else: + # This shouldn't happen. + raise ValueError('Genai live music connection requires credentials or API key provided.') + + uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic' + + # Only mldev supported + request_dict = _common.convert_to_dict( + live_converters._LiveMusicConnectParameters_to_mldev( + from_object=types.LiveMusicConnectParameters( + model=transformed_model, + ).model_dump(exclude_none=True) + ) + ) - setv(request_dict, ['setup', 'model'], transformed_model) + setv(request_dict, ['setup', 'model'], transformed_model) + + request = json.dumps(request_dict) - request = json.dumps(request_dict) - else: - raise NotImplementedError('Live music generation is not supported in Vertex AI.') try: async with connect(uri, additional_headers=headers) as ws: diff --git a/google/genai/tests/client/test_client_initialization.py b/google/genai/tests/client/test_client_initialization.py index abeb780b7..58aceb15d 100644 --- a/google/genai/tests/client/test_client_initialization.py +++ b/google/genai/tests/client/test_client_initialization.py @@ -47,6 +47,28 @@ ) +class FakeCredentials(credentials.Credentials): + def __init__(self, token="fake_token", expired=False, quota_project_id=None): + super().__init__() + self.token = token + self._expired = expired + self._quota_project_id = quota_project_id + self.refresh_count = 0 + + @property + def expired(self): + return self._expired + + @property + def quota_project_id(self): + return self._quota_project_id + + def refresh(self, request): + self.refresh_count += 1 + self.token = "refreshed_token" + self._expired = False + + @pytest.fixture(autouse=True) def reset_has_aiohttp(): yield @@ -1685,3 +1707,141 @@ async def test_get_aiohttp_session(): assert initial_session is not None session = await client._api_client._get_aiohttp_session() assert session is initial_session + + +def test_missing_api_key_and_credentials(monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "") + with pytest.raises(ValueError, match="Missing key inputs argument!"): + Client() + + +auth_precedence_test_cases = [ + # client_args, env_vars, expected_headers + ( + {"credentials": FakeCredentials()}, + {"GOOGLE_API_KEY": "env_api_key"}, + {"Authorization": "Bearer fake_token"} + ), + ( + {"credentials": FakeCredentials(quota_project_id="quota-proj")}, + {"GOOGLE_API_KEY": "env_api_key"}, + { + "Authorization": "Bearer fake_token", + "x-goog-user-project": "quota-proj" + } + ), + ( + {"api_key": "test_api_key"}, + {"GOOGLE_API_KEY": "env_api_key"}, + {"x-goog-api-key": "test_api_key"} + ), + ( + {}, + {"GOOGLE_API_KEY": "env_api_key"}, + {"x-goog-api-key": "env_api_key"} + ), +] + + +@pytest.mark.parametrize( + ["client_kwargs", "env_vars", "expected_headers"], + auth_precedence_test_cases, +) +@mock.patch.object(httpx.Client, "send", autospec=True) +def test_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers): + for key, value in env_vars.items(): + monkeypatch.setenv(key, value) + + client = Client(**client_kwargs) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]} + ) + client.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + + for key, value in expected_headers.items(): + assert key in request.headers + assert request.headers[key] == value + + if "Authorization" in expected_headers: + assert "x-goog-api-key" not in request.headers + if "x-goog-api-key" in expected_headers: + assert "Authorization" not in request.headers + if "x-goog-user-project" not in expected_headers: + assert "x-goog-user-project" not in request.headers + assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}' + +@pytest.mark.parametrize( + ["client_kwargs", "env_vars", "expected_headers"], + auth_precedence_test_cases, +) +@pytest.mark.asyncio +@mock.patch.object(httpx.AsyncClient, "send", autospec=True) +async def test_async_auth_precedence_mldev(mock_send, monkeypatch, client_kwargs, env_vars, expected_headers): + for key, value in env_vars.items(): + monkeypatch.setenv(key, value) + + client = Client(**client_kwargs) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]} + ) + await client.aio.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + + for key, value in expected_headers.items(): + assert key in request.headers + assert request.headers[key] == value + + if "Authorization" in expected_headers: + assert "x-goog-api-key" not in request.headers + if "x-goog-api-key" in expected_headers: + assert "Authorization" not in request.headers + if "x-goog-user-project" not in expected_headers: + assert "x-goog-user-project" not in request.headers + assert request.content == b'{"contents": [{"parts": [{"text": "hello?"}], "role": "user"}]}' + + +async def test_both_credentials_mldev(): + with pytest.raises(ValueError, match="mutually exclusive"): + creds = FakeCredentials(expired=True) + client = Client(credentials=creds, api_key="test-api-key") + + +@mock.patch.object(httpx.Client, "send", autospec=True) +def test_refresh_credentials_mldev(mock_send): + creds = FakeCredentials(expired=True) + client = Client(credentials=creds) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}, + ) + client.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + assert "Authorization" in request.headers + assert request.headers["Authorization"] == "Bearer refreshed_token" + assert "x-goog-api-key" not in request.headers + assert creds.refresh_count == 1 + + +@requires_aiohttp +@pytest.mark.asyncio +@mock.patch.object(httpx.AsyncClient, "send", autospec=True) +async def test_async_refresh_credentials_mldev(mock_send): + creds = FakeCredentials(expired=True) + client = Client(credentials=creds) + mock_send.return_value = httpx.Response( + status_code=200, + json={"candidates": [{"content": {"parts": [{"text": "response"}]}}]}, + ) + await client.aio.models.generate_content(model="test", contents="hello?") + mock_send.assert_called_once() + request = mock_send.call_args[0][1] + assert "Authorization" in request.headers + assert request.headers["Authorization"] == "Bearer refreshed_token" + assert "x-goog-api-key" not in request.headers + assert creds.refresh_count == 1 diff --git a/google/genai/tests/live/test_live.py b/google/genai/tests/live/test_live.py index 9ffebb726..cd2f9c6c4 100644 --- a/google/genai/tests/live/test_live.py +++ b/google/genai/tests/live/test_live.py @@ -29,6 +29,8 @@ import warnings import certifi +import google.auth +from google.auth.transport import requests from google.oauth2.credentials import Credentials import pytest from websockets import client @@ -40,6 +42,9 @@ from ... import client as gl_client from ... import live from ... import types +from ... import _extra_utils +from google.auth import credentials + try: import aiohttp AIOHTTP_NOT_INSTALLED = False @@ -85,6 +90,23 @@ }] +class FakeCredentials(Credentials): + def __init__(self, token='fake_token', valid=True): + super().__init__(token='placeholder') + self.token = token + self._valid = valid + self.refresh_called = False + + def refresh(self, request): + self.token = 'refreshed_token' + self._valid = True + self.refresh_called = True + + @property + def valid(self): + return self._valid + + def get_current_weather(location: str, unit: str): """Get the current weather in a city.""" return 15 if unit == 'C' else 59 @@ -2073,3 +2095,122 @@ async def mock_connect(uri, additional_headers=None, **kwargs): assert 'x-goog-api-key' in capture['headers'], "x-goog-api-key is missing from headers" assert capture['headers']['x-goog-api-key'] == 'TEST_API_KEY' assert 'BidiGenerateContent' in capture['uri'] + + + +@pytest.mark.asyncio +async def test_prepare_connection_vertex_with_api_key(mock_websocket): + # Test the branch where api_key is present in vertexai + client = Client(vertexai=True, api_key="test_api_key") + capture = {} + + @contextlib.asynccontextmanager + async def mock_ws_connect(uri, additional_headers=None, **kwargs): + capture['uri'] = uri + capture['headers'] = additional_headers + yield mock_websocket + + with patch.object(live, 'ws_connect', new=mock_ws_connect): + live_module = client.aio.live + async with live_module.connect(model='test-model'): + pass + + headers = capture['headers'] + uri = capture['uri'] + assert 'x-goog-api-key' in headers + assert headers['x-goog-api-key'] == "test_api_key" + # Authorization header should not be added by this method if api_key is used + assert 'Authorization' not in headers + assert "BidiGenerateContent" in uri + + +@pytest.mark.asyncio +async def test_prepare_connection_vertex_refresh_creds(mock_websocket): + # Test the branch where credentials need refreshing + fake_creds = FakeCredentials(token=None, valid=False) + capture = {} + + @contextlib.asynccontextmanager + async def mock_ws_connect(uri, additional_headers=None, **kwargs): + capture['uri'] = uri + capture['headers'] = additional_headers + yield mock_websocket + + with ( + patch.object(google.auth, 'default', return_value=(fake_creds, "test-project")), + patch.object(requests, 'Request', return_value=Mock()), + patch.object(live, 'ws_connect', new=mock_ws_connect) + ): + client = Client(vertexai=True, project="test-project", + location="us-central1") + live_module = client.aio.live + async with live_module.connect(model='test-model'): + pass + + headers = capture['headers'] + uri = capture['uri'] + assert fake_creds.refresh_called + assert 'Authorization' in headers + assert headers['Authorization'] == f'Bearer refreshed_token' + assert "BidiGenerateContent" in uri + + +@pytest.mark.asyncio +async def test_async_live_connect_with_api_key(mock_websocket): + client = api_client.BaseApiClient(api_key='test_api_key') + async_live = live.AsyncLive(client) + capture = {} + + @contextlib.asynccontextmanager + async def mock_connect(uri, additional_headers=None, **kwargs): + capture['headers'] = additional_headers + yield mock_websocket + + with mock.patch.object(live, 'ws_connect', new=mock_connect): + async with async_live.connect(model='models/test-model'): + pass + + assert 'headers' in capture + headers = capture['headers'] + + assert headers['x-goog-api-key'] == 'test_api_key' + + assert 'Authorization' not in headers + +@pytest.mark.parametrize( + "creds, existing_headers, expected_auth, expect_refresh", + [ + (FakeCredentials(), {}, 'Bearer fake_token', False), + (FakeCredentials(valid=False), {}, 'Bearer refreshed_token', True), + (FakeCredentials(token=None, valid=False), {}, 'Bearer refreshed_token', True), + (FakeCredentials(token='existing_token', valid=True), {}, 'Bearer existing_token', False), + (FakeCredentials(token='new_token', valid=True), {'Authorization': 'Bearer old_token'}, 'Bearer old_token', False), + ], +) +@pytest.mark.asyncio +async def test_async_live_connect_with_credentials( + mock_websocket, creds, existing_headers, expected_auth, expect_refresh +): + client = api_client.BaseApiClient(credentials=creds) + if existing_headers: + client._http_options.headers = existing_headers + async_live = live.AsyncLive(client) + capture = {} + + @contextlib.asynccontextmanager + async def mock_connect(uri, additional_headers=None, **kwargs): + capture['headers'] = additional_headers + yield mock_websocket + + with ( + mock.patch.object(live, 'ws_connect', new=mock_connect), + mock.patch.object(requests, 'Request', autospec=True) + ): + async with async_live.connect(model='models/test-model'): + pass + + assert 'headers' in capture + headers = capture['headers'] + assert headers.get('Authorization') == expected_auth + assert 'x-goog-api-key' not in headers + assert creds.refresh_called == expect_refresh diff --git a/google/genai/tests/live/test_live_music.py b/google/genai/tests/live/test_live_music.py index f51f8248d..7cee2e6f4 100644 --- a/google/genai/tests/live/test_live_music.py +++ b/google/genai/tests/live/test_live_music.py @@ -36,6 +36,7 @@ from ... import live_music from ... import types from .. import pytest_helper + try: import aiohttp AIOHTTP_NOT_INSTALLED = False @@ -49,10 +50,27 @@ ) -def mock_api_client(vertexai=False, credentials=None): +class FakeCredentials(Credentials): + def __init__(self, token='fake_token', valid=True): + super().__init__(token='placeholder') + self.token = token + self._valid = valid + self.refresh_called = False + + def refresh(self, request): + self.token = 'refreshed_token' + self._valid = True + self.refresh_called = True + + @property + def valid(self): + return self._valid + + +def mock_api_client(vertexai=False, credentials=None, api_key='TEST_API_KEY'): api_client = mock.MagicMock(spec=gl_client.BaseApiClient) if not vertexai: - api_client.api_key = 'TEST_API_KEY' + api_client.api_key = api_key api_client.location = None api_client.project = None else: @@ -67,6 +85,7 @@ def mock_api_client(vertexai=False, credentials=None): ) # Ensure headers exist api_client.vertexai = vertexai api_client._api_client = api_client + api_client._websocket_base_url = lambda: 'wss://test.com' return api_client @@ -142,6 +161,7 @@ def test_mldev_from_env(monkeypatch): assert not client.aio.live.music._api_client.vertexai assert client.aio.live.music._api_client.api_key == api_key assert isinstance(client.aio.live._api_client, api_client.BaseApiClient) + assert client.aio.live._api_client._http_options.headers['x-goog-api-key'] == api_key @requires_aiohttp @@ -360,3 +380,70 @@ async def test_setup_to_api(vertexai): else: expected_result['setup']['model'] = 'models/test_model' assert result == expected_result + +@pytest.mark.asyncio +async def test_connect_with_api_key(mock_websocket): + client = Client(api_key='TEST_API_KEY', http_options={'api_version': 'v1test'}) + client._api_client._websocket_base_url = lambda: 'wss://test.com' + live_module = client.aio.live.music + capture = {} + + @contextlib.asynccontextmanager + async def mock_connect(uri, additional_headers=None, **kwargs): + capture['uri'] = uri + capture['headers'] = additional_headers + yield mock_websocket + + with patch('google.genai.live_music.connect', new=mock_connect): + async with live_module.connect(model='test-model'): + pass + + assert capture['uri'] == 'wss://test.com/ws/google.ai.generativelanguage.v1test.GenerativeService.BidiGenerateMusic' + assert capture['headers']['x-goog-api-key'] == 'TEST_API_KEY' + assert 'Authorization' not in capture['headers'] + +@pytest.mark.parametrize( + "creds, existing_headers, expected_auth, expect_refresh", + [ + (FakeCredentials(), {}, 'Bearer fake_token', False), + (FakeCredentials(valid=False), {}, 'Bearer refreshed_token', True), + (FakeCredentials(token=None, valid=False), {}, 'Bearer refreshed_token', True), + (FakeCredentials(token='existing_token', valid=True), {}, 'Bearer existing_token', False), + (FakeCredentials(token='new_token', valid=True), {'Authorization': 'Bearer old_token'}, 'Bearer old_token', False), + ], +) +@pytest.mark.asyncio +async def test_connect_with_credentials( + mock_websocket, creds, existing_headers, expected_auth, expect_refresh +): + client = api_client.BaseApiClient(credentials=creds, http_options={'api_version': 'v1test'}) + if existing_headers: + client._http_options.headers = existing_headers + client._websocket_base_url = lambda: 'wss://test.com' + live_module = live_music.AsyncLiveMusic(client) + capture = {} + + @contextlib.asynccontextmanager + async def mock_connect(uri, additional_headers=None, **kwargs): + capture['uri'] = uri + capture['headers'] = additional_headers + yield mock_websocket + + with patch('google.genai.live_music.connect', new=mock_connect): + with patch('google.auth.transport.requests.Request', autospec=True): + async with live_module.connect(model='test-model'): + pass + + assert capture['uri'] == 'wss://test.com/ws/google.ai.generativelanguage.v1test.GenerativeService.BidiGenerateMusic' + headers = capture['headers'] + assert headers.get('Authorization') == expected_auth + assert 'x-goog-api-key' not in headers + assert creds.refresh_called == expect_refresh + +@pytest.mark.asyncio +async def test_connect_vertex_unsupported(mock_websocket): + client = Client(vertexai=True, project='test', location='us-central1') + live_module = client.aio.live.music + with pytest.raises(NotImplementedError): + async with live_module.connect(model='test-model'): + pass