Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ def normalize_config(config: Any) -> OpenAIConfig:
return config

if isinstance(config, dict):
if config.get('topK'):
del config['topK']
if config.get('topP'):
config['top_p'] = config['topP']
del config['topP']
return OpenAIConfig(**config)

raise ValueError(f'Expected request.config to be a dict or OpenAIConfig, got {type(config).__name__}.')
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PluginSource(StrEnum):
GPT_4O_MINI = 'gpt-4o-mini'
O1_MINI = 'o1-mini'

LLAMA_3_1 = 'meta/llama3-405b-instruct-maas'
LLAMA_3_1 = 'meta/llama-3.1-405b-instruct-maas'
LLAMA_3_2 = 'meta/llama-3.2-90b-vision-instruct-maas'

SUPPORTED_OPENAI_MODELS: dict[str, ModelInfo] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@
"""ModelGarden API Compatible Plugin for Genkit."""

import os
from functools import cached_property

from genkit.ai import GenkitRegistry, Plugin
from genkit.blocks.model import model_action_metadata
from genkit.core.action.types import ActionKind
from genkit.plugins.compat_oai.models import SUPPORTED_OPENAI_COMPAT_MODELS
from genkit.plugins.compat_oai.typing import OpenAIConfig
from genkit.plugins.vertex_ai import constants as const

from .model_garden import MODELGARDEN_PLUGIN_NAME, ModelGarden
from .model_garden import MODELGARDEN_PLUGIN_NAME, ModelGarden, model_garden_name


class VertexAIModelGarden(Plugin):
Expand All @@ -35,8 +40,24 @@ class VertexAIModelGarden(Plugin):

name = MODELGARDEN_PLUGIN_NAME

def __init__(self, project_id: str | None = None, location: str | None = None, models: list[str] | None = None):
"""Initialize the plugin by registering actions with the registry."""
def __init__(
self,
project_id: str | None = None,
location: str | None = None,
models: list[str] | None = None,
) -> None:
"""Initializes the plugin and sets up its configuration.

This constructor prepares the plugin by assigning the Google Cloud project ID,
location, and a list of models to be used.

Args:
project_id: The Google Cloud project ID to use. If not provided, it attempts
to load from the `GCLOUD_PROJECT` environment variable.
location: The Google Cloud region to use for services. If not provided,
it defaults to `DEFAULT_REGION`.
models: An optional list of model names to register with the plugin.
"""
self.project_id = project_id if project_id is not None else os.getenv(const.GCLOUD_PROJECT)
self.location = location if location is not None else const.DEFAULT_REGION
self.models = models
Expand All @@ -55,3 +76,66 @@ def initialize(self, ai: GenkitRegistry) -> None:
registry=ai,
)
model_proxy.define_model()

def resolve_action(
    self,
    ai: GenkitRegistry,
    kind: ActionKind,
    name: str,
) -> None:
    """Resolve an action of the given kind by name.

    Only `ActionKind.MODEL` is handled here; any other kind is ignored and
    the call returns without doing anything.

    Args:
        ai: The Genkit registry.
        kind: The kind of action to resolve.
        name: The name of the action to resolve.
    """
    # Guard clause: nothing to do for non-model actions.
    if kind != ActionKind.MODEL:
        return
    self._resolve_model(ai=ai, name=name)

def _resolve_model(self, ai: GenkitRegistry, name: str) -> None:
    """Define a Vertex AI Model Garden model in the Genkit registry.

    The plugin qualifier (`MODELGARDEN_PLUGIN_NAME` followed by '/') is
    stripped from the front of `name` when present. The remaining model name
    is then wrapped in a `ModelGarden` proxy, configured with this plugin's
    location and project ID, and the proxy's `define_model()` is called.

    Args:
        ai: The Genkit AI registry instance to define the model in.
        name: The name of the model to resolve. It may or may not carry the
            plugin prefix.
    """
    # Strip only a *leading* '<plugin>/' qualifier. `str.replace` would also
    # remove later occurrences of the same substring inside the model name.
    clean_name = name.removeprefix(f'{MODELGARDEN_PLUGIN_NAME}/')

    model_proxy = ModelGarden(
        model=clean_name,
        location=self.location,
        project_id=self.project_id,
        registry=ai,
    )
    model_proxy.define_model()

@cached_property
def list_actions(self) -> list[dict[str, str]]:
    """Build the list of model actions this plugin can provide.

    One entry is produced per item in `SUPPORTED_OPENAI_COMPAT_MODELS`, each
    built by `model_action_metadata` from the Model Garden-qualified name, the
    dumped model info, and `OpenAIConfig` as the config schema. Because this
    is a `cached_property`, the list is computed on first access and reused
    afterwards for the same instance.

    Returns:
        The metadata entries, in the iteration order of
        `SUPPORTED_OPENAI_COMPAT_MODELS`.
    """
    return [
        model_action_metadata(
            name=model_garden_name(model_id),
            info=model_details.model_dump(),
            config_schema=OpenAIConfig,
        )
        for model_id, model_details in SUPPORTED_OPENAI_COMPAT_MODELS.items()
    ]
1 change: 0 additions & 1 deletion py/samples/model-garden/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
plugins=[
VertexAIModelGarden(
location='us-central1',
models=['meta/llama-3.2-90b-vision-instruct-maas'],
),
],
)
Expand Down
Loading