Skip to content

Commit 9409b9c

Browse files
author
Abraham Lazaro Martinez
committed
fix: add more tests
1 parent 8946bd2 commit 9409b9c

File tree

2 files changed

+80
-17
lines changed

2 files changed

+80
-17
lines changed

py/plugins/vertex-ai/tests/model_garden/test_client.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,18 @@
2424
@patch('google.auth.default')
2525
@patch('google.auth.transport.requests.Request')
2626
@patch('openai.OpenAI')
27-
def test_client_initialization_with_explicit_project_id(
28-
mock_openai_cls, mock_request_cls, mock_default_auth
29-
):
27+
def test_client_initialization_with_explicit_project_id(mock_openai_cls, mock_request_cls, mock_default_auth):
3028
"""Unittests for init client."""
31-
mock_location = "location"
32-
mock_project_id = "project_id"
33-
mock_token = "token"
29+
mock_location = 'location'
30+
mock_project_id = 'project_id'
31+
mock_token = 'token'
3432

3533
mock_credentials = MagicMock()
3634
mock_credentials.token = mock_token
3735

38-
mock_default_auth.return_value = (mock_credentials, "project_id")
36+
mock_default_auth.return_value = (mock_credentials, 'project_id')
3937

40-
client_instance = OpenAIClient(
41-
location=mock_location,
42-
project_id=mock_project_id
43-
)
38+
client_instance = OpenAIClient(location=mock_location, project_id=mock_project_id)
4439

4540
mock_default_auth.assert_called_once()
4641
mock_credentials.refresh.assert_called_once()
@@ -52,17 +47,15 @@ def test_client_initialization_with_explicit_project_id(
5247
@patch('google.auth.default')
5348
@patch('google.auth.transport.requests.Request')
5449
@patch('openai.OpenAI')
55-
def test_client_initialization_without_explicit_project_id(
56-
mock_openai_cls, mock_request_cls, mock_default_auth
57-
):
50+
def test_client_initialization_without_explicit_project_id(mock_openai_cls, mock_request_cls, mock_default_auth):
5851
"""Unittests for init client."""
59-
mock_location = "location"
60-
mock_token = "token"
52+
mock_location = 'location'
53+
mock_token = 'token'
6154

6255
mock_credentials = MagicMock()
6356
mock_credentials.token = mock_token
6457

65-
mock_default_auth.return_value = (mock_credentials, "project_id")
58+
mock_default_auth.return_value = (mock_credentials, 'project_id')
6659

6760
client_instance = OpenAIClient(
6861
location=mock_location,

py/plugins/vertex-ai/tests/model_garden/test_model_garden.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,73 @@
1616

1717
"""Unittests for VertexAI Model Garden Models."""
1818

19+
from unittest.mock import MagicMock, patch
20+
21+
import pytest
22+
23+
from genkit.ai import GenkitRegistry
24+
from genkit.plugins.compat_oai.typing import SupportedOutputFormat
25+
from genkit.plugins.vertex_ai.model_garden.model_garden import ModelGarden
26+
27+
28+
@pytest.fixture
29+
@patch('genkit.plugins.vertex_ai.model_garden.client.OpenAIClient')
30+
def model_garden_instance(client):
31+
"""Model Garden fixture."""
32+
return ModelGarden(
33+
model='test', location='us-central1', project_id='project', registry=MagicMock(spec=GenkitRegistry)
34+
)
35+
36+
37+
@pytest.mark.parametrize(
38+
'model_name, expected',
39+
[
40+
(
41+
'meta/llama-3.1-405b-instruct-maas',
42+
{
43+
'name': 'ModelGarden - Meta - llama-3.1',
44+
'supports': {
45+
'constrained': None,
46+
'content_type': None,
47+
'context': None,
48+
'multiturn': True,
49+
'media': False,
50+
'tools': True,
51+
'system_role': True,
52+
'output': [
53+
'json_mode',
54+
'text',
55+
],
56+
'tool_choice': None,
57+
},
58+
},
59+
),
60+
(
61+
'meta/lazaro-model-pro-max',
62+
{
63+
'name': 'ModelGarden - meta/lazaro-model-pro-max',
64+
'supports': {
65+
'constrained': None,
66+
'content_type': None,
67+
'context': None,
68+
'multiturn': True,
69+
'media': True,
70+
'tools': True,
71+
'system_role': True,
72+
'output': [
73+
'json_mode',
74+
'text',
75+
],
76+
'tool_choice': None,
77+
},
78+
},
79+
),
80+
],
81+
)
82+
def test_get_model_info(model_name, expected, model_garden_instance):
83+
"""Unittest for get_model_info."""
84+
model_garden_instance.name = model_name
85+
86+
result = model_garden_instance.get_model_info()
87+
88+
assert result == expected

0 commit comments

Comments
 (0)