diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/client.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/client.py index a28d03aaf8..93519d00ad 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/client.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/client.py @@ -14,7 +14,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from google.auth import default, transport +from google import auth from openai import OpenAI as _OpenAI @@ -26,10 +26,10 @@ def __new__(cls, **openai_params) -> _OpenAI: location = openai_params.get('location') project_id = openai_params.get('project_id') if project_id: - credentials, _ = default() + credentials, _ = auth.default() else: - credentials, project_id = default() + credentials, project_id = auth.default() - credentials.refresh(transport.requests.Request()) + credentials.refresh(auth.transport.requests.Request()) base_url = f'https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi' return _OpenAI(api_key=credentials.token, base_url=base_url) diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py index e03a801d8a..7b522412ed 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py @@ -23,8 +23,7 @@ ) from genkit.plugins.compat_oai.models.model import OpenAIModel from genkit.plugins.compat_oai.typing import OpenAIConfig - -from .client import OpenAIClient +from genkit.plugins.vertex_ai.model_garden.client import OpenAIClient OPENAI_COMPAT = 'openai-compat' MODELGARDEN_PLUGIN_NAME = 'modelgarden' diff --git a/py/plugins/vertex-ai/tests/model_garden/test_client.py b/py/plugins/vertex-ai/tests/model_garden/test_client.py new file mode 100644 index 0000000000..58a7fe9371 --- /dev/null +++ b/py/plugins/vertex-ai/tests/model_garden/test_client.py @@ -0,0 +1,68 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unittests for VertexAI Model Garden OpenAI Client.""" + +from unittest.mock import MagicMock, patch + +from genkit.plugins.vertex_ai.model_garden.client import OpenAIClient + + +@patch('google.auth.default') +@patch('google.auth.transport.requests.Request') +@patch('openai.OpenAI') +def test_client_initialization_with_explicit_project_id(mock_openai_cls, mock_request_cls, mock_default_auth): + """Unittests for init client.""" + mock_location = 'location' + mock_project_id = 'project_id' + mock_token = 'token' + + mock_credentials = MagicMock() + mock_credentials.token = mock_token + + mock_default_auth.return_value = (mock_credentials, 'project_id') + + client_instance = OpenAIClient(location=mock_location, project_id=mock_project_id) + + mock_default_auth.assert_called_once() + mock_credentials.refresh.assert_called_once() + mock_request_cls.assert_called_once() + + assert client_instance is not None + + +@patch('google.auth.default') +@patch('google.auth.transport.requests.Request') +@patch('openai.OpenAI') +def test_client_initialization_without_explicit_project_id(mock_openai_cls, mock_request_cls, mock_default_auth): + """Unittests for init client.""" + mock_location = 'location' + mock_token = 'token' + + mock_credentials = MagicMock() + mock_credentials.token = mock_token + + mock_default_auth.return_value = (mock_credentials, 'project_id') + + client_instance = OpenAIClient( + location=mock_location, + ) + + mock_default_auth.assert_called_once() + mock_credentials.refresh.assert_called_once() + mock_request_cls.assert_called_once() + + assert client_instance is not None diff --git a/py/plugins/vertex-ai/tests/model_garden/test_model_garden.py b/py/plugins/vertex-ai/tests/model_garden/test_model_garden.py new file mode 100644 index 0000000000..e71ed181d4 --- /dev/null +++ b/py/plugins/vertex-ai/tests/model_garden/test_model_garden.py @@ -0,0 +1,87 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unittests for VertexAI Model Garden Models.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from genkit.ai import GenkitRegistry +from genkit.plugins.vertex_ai.model_garden.model_garden import ModelGarden + + +@pytest.fixture +@patch('genkit.plugins.vertex_ai.model_garden.model_garden.OpenAIClient') +def model_garden_instance(client): + """Model Garden fixture.""" + return ModelGarden( + model='test', location='us-central1', project_id='project', registry=MagicMock(spec=GenkitRegistry) + ) + + +@pytest.mark.parametrize( + 'model_name, expected', + [ + ( + 'meta/llama-3.1-405b-instruct-maas', + { + 'name': 'ModelGarden - Meta - llama-3.1', + 'supports': { + 'constrained': None, + 'content_type': None, + 'context': None, + 'multiturn': True, + 'media': False, + 'tools': True, + 'system_role': True, + 'output': [ + 'json_mode', + 'text', + ], + 'tool_choice': None, + }, + }, + ), + ( + 'meta/lazaro-model-pro-max', + { + 'name': 'ModelGarden - meta/lazaro-model-pro-max', + 'supports': { + 'constrained': None, + 'content_type': None, + 'context': None, + 'multiturn': True, + 'media': True, + 'tools': True, + 'system_role': True, + 'output': [ + 'json_mode', + 'text', + ], + 'tool_choice': None, + }, + }, + ), + ], +) +def test_get_model_info(model_name, expected, model_garden_instance): + """Unittest for get_model_info.""" + model_garden_instance.name = model_name + + result = model_garden_instance.get_model_info() + + assert result == expected