diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 2b5d7108f..f7204fd72 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -1,10 +1,8 @@ #!/usr/bin/env python # Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -import json import os import pathlib -import re from datetime import datetime, timedelta from threading import Lock from typing import Any, Dict, List, Optional, Set, Union @@ -39,12 +37,11 @@ generate_tei_cmd_var, get_artifact_path, get_hf_model_info, - get_preferred_compatible_family, list_os_files_with_extension, load_config, upload_folder, ) -from ads.aqua.config.container_config import AquaContainerConfig, Usage +from ads.aqua.config.container_config import AquaContainerConfig from ads.aqua.constants import ( AQUA_MODEL_ARTIFACT_CONFIG, AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME, @@ -79,11 +76,8 @@ AquaModelReadme, AquaModelSummary, ImportModelDetails, - ModelFileDescription, ModelValidationResult, ) -from ads.aqua.model.enums import MultiModelSupportedTaskType -from ads.aqua.model.utils import extract_fine_tune_artifacts_path from ads.common.auth import default_signer from ads.common.oci_resource import SEARCH_TYPE, OCIResource from ads.common.utils import UNKNOWN, get_console_link, is_path_exists, read_file @@ -102,6 +96,7 @@ ) from ads.model import DataScienceModel from ads.model.common.utils import MetadataArtifactPathType +from ads.model.datascience_model_group import DataScienceModelGroup from ads.model.model_metadata import ( MetadataCustomCategory, ModelCustomMetadata, @@ -235,13 +230,16 @@ def create( def create_multi( self, models: List[AquaMultiModelRef], + model_custom_metadata: ModelCustomMetadata, + model_group_display_name: str, + model_group_description: str, + tags: Dict, + combined_model_names: str, project_id: Optional[str] = None, compartment_id: Optional[str] = None, - freeform_tags: Optional[Dict] = None, defined_tags: Optional[Dict] = None, - source_models: Optional[Dict[str, DataScienceModel]] = None, **kwargs, # noqa: ARG002 - ) -> DataScienceModel: + ) -> DataScienceModelGroup: """ Creates a multi-model grouping using the provided model list. @@ -249,250 +247,32 @@ def create_multi( ---------- models : List[AquaMultiModelRef] List of AquaMultiModelRef instances for creating a multi-model group. + model_custom_metadata : ModelCustomMetadata + Custom metadata for creating model group. + All model group custom metadata, including 'multi_model_metadata' and 'MULTI_MODEL_CONFIG' will be translated as a + list of dict and placed under environment variable 'OCI_MODEL_GROUP_CUSTOM_METADATA' in model deployment. + model_group_display_name: str + The model group display name. + model_group_description: str + The model group description. + tags: Dict + The tags of model group. + combined_model_names: str + The name of models to be grouped and deployed. project_id : Optional[str] The project ID for the multi-model group. compartment_id : Optional[str] The compartment ID for the multi-model group. - freeform_tags : Optional[Dict] - Freeform tags for the model. defined_tags : Optional[Dict] Defined tags for the model. - source_models: Optional[Dict[str, DataScienceModel]] - A mapping of model OCIDs to their corresponding `DataScienceModel` objects. - This dictionary contains metadata for all models involved in the multi-model deployment, - including both base models and fine-tuned weights. Returns ------- - DataScienceModel - Instance of DataScienceModel object. + DataScienceModelGroup + Instance of DataScienceModelGroup object. """ - - if not models: - raise AquaValueError( - "Model list cannot be empty. Please provide at least one model for deployment." - ) - - display_name_list = [] - model_file_description_list: List[ModelFileDescription] = [] - model_custom_metadata = ModelCustomMetadata() - - service_inference_containers = ( - self.get_container_config().to_dict().get("inference") - ) - - supported_container_families = [ - container_config_item.family - for container_config_item in service_inference_containers - if any( - usage.upper() in container_config_item.usages - for usage in [Usage.MULTI_MODEL, Usage.OTHER] - ) - ] - - if not supported_container_families: - raise AquaValueError( - "Currently, there are no containers that support multi-model deployment." - ) - - selected_models_deployment_containers = set() - - if not source_models: - # Collect all unique model IDs (including fine-tuned models) - source_model_ids = list( - {model_id for model in models for model_id in model.all_model_ids()} - ) - logger.debug( - "Fetching source model metadata for model IDs: %s", source_model_ids - ) - - # Fetch source model metadata - source_models = self.get_multi_source(source_model_ids) or {} - - # Process each model in the input list - for model in models: - # Retrieve base model metadata - source_model: DataScienceModel = source_models.get(model.model_id) - if not source_model: - logger.error( - "Failed to fetch metadata for base model ID: %s", model.model_id - ) - raise AquaValueError( - f"Unable to retrieve metadata for base model ID: {model.model_id}." - ) - - # Use display name as fallback if model name not provided - model.model_name = model.model_name or source_model.display_name - - # Validate model file description - model_file_description = source_model.model_file_description - if not model_file_description: - logger.error( - "Model '%s' (%s) has no file description.", - source_model.display_name, - model.model_id, - ) - raise AquaValueError( - f"Model '{source_model.display_name}' (ID: {model.model_id}) has no file description. " - "Please register the model with a file description." - ) - - # Track model file description in a validated structure - model_file_description_list.append( - ModelFileDescription(**model_file_description) - ) - - # Ensure base model has a valid artifact - if not source_model.artifact: - logger.error( - "Base model '%s' (%s) has no artifact.", - model.model_name, - model.model_id, - ) - raise AquaValueError( - f"Model '{model.model_name}' (ID: {model.model_id}) has no registered artifacts. " - "Please register the model before deployment." - ) - - # Set base model artifact path - model.artifact_location = source_model.artifact - logger.debug( - "Model '%s' artifact path set to: %s", - model.model_name, - model.artifact_location, - ) - - display_name_list.append(model.model_name) - - # Extract model task metadata from source model - self._extract_model_task(model, source_model) - - # Process fine-tuned weights if provided - for ft_model in model.fine_tune_weights or []: - fine_tune_source_model: DataScienceModel = source_models.get( - ft_model.model_id - ) - if not fine_tune_source_model: - logger.error( - "Failed to fetch metadata for fine-tuned model ID: %s", - ft_model.model_id, - ) - raise AquaValueError( - f"Unable to retrieve metadata for fine-tuned model ID: {ft_model.model_id}." - ) - - # Validate model file description - ft_model_file_description = ( - fine_tune_source_model.model_file_description - ) - if not ft_model_file_description: - logger.error( - "Model '%s' (%s) has no file description.", - fine_tune_source_model.display_name, - ft_model.model_id, - ) - raise AquaValueError( - f"Model '{fine_tune_source_model.display_name}' (ID: {ft_model.model_id}) has no file description. " - "Please register the model with a file description." - ) - - # Track model file description in a validated structure - model_file_description_list.append( - ModelFileDescription(**ft_model_file_description) - ) - - # Extract fine-tuned model path - _, fine_tune_path = extract_fine_tune_artifacts_path( - fine_tune_source_model - ) - logger.debug( - "Resolved fine-tuned model path for '%s': %s", - ft_model.model_id, - fine_tune_path, - ) - ft_model.model_path = fine_tune_path - - # Use fallback name if needed - ft_model.model_name = ( - ft_model.model_name or fine_tune_source_model.display_name - ) - - display_name_list.append(ft_model.model_name) - - # Validate deployment container consistency - deployment_container = source_model.custom_metadata_list.get( - ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, - ModelCustomMetadataItem( - key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER - ), - ).value - - if deployment_container not in supported_container_families: - logger.error( - "Unsupported deployment container '%s' for model '%s'. Supported: %s", - deployment_container, - source_model.id, - supported_container_families, - ) - raise AquaValueError( - f"Unsupported deployment container '{deployment_container}' for model '{source_model.id}'. " - f"Only {supported_container_families} are supported for multi-model deployments." - ) - - selected_models_deployment_containers.add(deployment_container) - - if not selected_models_deployment_containers: - raise AquaValueError( - "None of the selected models are associated with a recognized container family. " - "Please review the selected models, or select a different group of models." - ) - - # Check if the all models in the group shares same container family - if len(selected_models_deployment_containers) > 1: - deployment_container = get_preferred_compatible_family( - selected_families=selected_models_deployment_containers - ) - if not deployment_container: - raise AquaValueError( - "The selected models are associated with different container families: " - f"{list(selected_models_deployment_containers)}." - "For multi-model deployment, all models in the group must belong to the same container " - "family or to compatible container families." - ) - else: - deployment_container = selected_models_deployment_containers.pop() - - # Generate model group details - timestamp = datetime.now().strftime("%Y%m%d") - model_group_display_name = f"model_group_{timestamp}" - combined_models = ", ".join(display_name_list) - model_group_description = f"Multi-model grouping using {combined_models}." - - # Add global metadata - model_custom_metadata.add( - key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, - value=deployment_container, - description=f"Inference container mapping for {model_group_display_name}", - category="Other", - ) - model_custom_metadata.add( - key=ModelCustomMetadataFields.MULTIMODEL_GROUP_COUNT, - value=str(len(models)), - description="Number of models in the group.", - category="Other", - ) - - # Combine tags. The `Tags.AQUA_TAG` has been excluded, because we don't want to show - # the models created for multi-model purpose in the AQUA models list. - tags = { - # Tags.AQUA_TAG: "active", - Tags.MULTIMODEL_TYPE_TAG: "true", - **(freeform_tags or {}), - } - - # Create multi-model group - custom_model = ( - DataScienceModel() + custom_model_group = ( + DataScienceModelGroup() .with_compartment_id(compartment_id) .with_project_id(project_id) .with_display_name(model_group_display_name) @@ -500,47 +280,23 @@ def create_multi( .with_freeform_tags(**tags) .with_defined_tags(**(defined_tags or {})) .with_custom_metadata_list(model_custom_metadata) + # TODO: add member model inference key + .with_member_models([{"model_id": model.model_id for model in models}]) ) - - # Update multi model file description to attach artifacts - custom_model.with_model_file_description( - json_dict=ModelFileDescription( - models=[ - models - for model_file_description in model_file_description_list - for models in model_file_description.models - ] - ).model_dump(by_alias=True) - ) - - # Finalize creation - custom_model.create(model_by_reference=True) + custom_model_group.create() logger.info( - f"Aqua Model '{custom_model.id}' created with models: {', '.join(display_name_list)}." - ) - - # Create custom metadata for multi model metadata - custom_model.create_custom_metadata_artifact( - metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA, - artifact_path_or_content=json.dumps( - [model.model_dump() for model in models] - ).encode(), - path_type=MetadataArtifactPathType.CONTENT, - ) - - logger.debug( - f"Multi model metadata uploaded for Aqua model: {custom_model.id}." + f"Aqua Model Group'{custom_model_group.id}' created with models: {combined_model_names}." ) # Track telemetry event self.telemetry.record_event_async( category="aqua/multimodel", action="create", - detail=combined_models, + detail=combined_model_names, ) - return custom_model + return custom_model_group @telemetry(entry_point="plugin=model&action=get", name="aqua") def get(self, model_id: str) -> "AquaModel": @@ -806,26 +562,6 @@ def edit_registered_model( else: raise AquaRuntimeError("Only registered unverified models can be edited.") - def _extract_model_task( - self, - model: AquaMultiModelRef, - source_model: DataScienceModel, - ) -> None: - """In a Multi Model Deployment, will set model_task parameter in AquaMultiModelRef from freeform tags or user""" - # user does not supply model task, we extract from model metadata - if not model.model_task: - model.model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN) - - task_tag = re.sub(r"-", "_", model.model_task).lower() - # re-visit logic when more model task types are supported - if task_tag in MultiModelSupportedTaskType: - model.model_task = task_tag - else: - raise AquaValueError( - f"Invalid or missing {task_tag} tag for selected model {source_model.display_name}. " - f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment." - ) - def _fetch_metric_from_metadata( self, custom_metadata_list: ModelCustomMetadata, diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index f86881b75..c000c9059 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -4,10 +4,11 @@ import json +import re import shlex import threading from datetime import datetime, timedelta -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from cachetools import TTLCache, cached from oci.data_science.models import ModelDeploymentShapeSummary @@ -29,6 +30,7 @@ get_container_params_type, get_ocid_substring, get_params_list, + get_preferred_compatible_family, get_resource_name, get_restricted_params_by_container, load_gpu_shapes_index, @@ -48,6 +50,7 @@ from ads.aqua.data import AquaResourceIdentifier from ads.aqua.model import AquaModelApp from ads.aqua.model.constants import AquaModelMetadataKeys, ModelCustomMetadataFields +from ads.aqua.model.enums import MultiModelSupportedTaskType from ads.aqua.model.utils import ( extract_base_model_from_ft, extract_fine_tune_artifacts_path, @@ -79,13 +82,14 @@ PROJECT_OCID, ) from ads.model.datascience_model import DataScienceModel +from ads.model.datascience_model_group import DataScienceModelGroup from ads.model.deployment import ( ModelDeployment, ModelDeploymentContainerRuntime, ModelDeploymentInfrastructure, ModelDeploymentMode, ) -from ads.model.model_metadata import ModelCustomMetadataItem +from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem from ads.telemetry import telemetry @@ -325,21 +329,352 @@ def create( f"Multi models ({source_model_ids}) provided. Delegating to multi model creation method." ) - aqua_model = model_app.create_multi( + ( + model_group_display_name, + model_group_description, + tags, + model_custom_metadata, + combined_model_names, + ) = self._build_model_group_configs( + models=create_deployment_details.models, + create_deployment_details=create_deployment_details, + model_config_summary=model_config_summary, + freeform_tags=freeform_tags, + source_models=source_models, + ) + + aqua_model_group = model_app.create_multi( models=create_deployment_details.models, + model_custom_metadata=model_custom_metadata, + model_group_display_name=model_group_display_name, + model_group_description=model_group_description, + tags=tags, + combined_model_names=combined_model_names, compartment_id=compartment_id, project_id=project_id, - freeform_tags=freeform_tags, defined_tags=defined_tags, - source_models=source_models, ) return self._create_multi( - aqua_model=aqua_model, - model_config_summary=model_config_summary, + aqua_model_group=aqua_model_group, create_deployment_details=create_deployment_details, container_config=container_config, ) + def _build_model_group_configs( + self, + models: List[AquaMultiModelRef], + create_deployment_details: CreateModelDeploymentDetails, + model_config_summary: ModelDeploymentConfigSummary, + freeform_tags: Optional[Dict] = None, + source_models: Optional[Dict[str, DataScienceModel]] = None, + **kwargs, # noqa: ARG002 + ) -> tuple: + """ + Builds configs for a multi-model grouping using the provided model list. + + Parameters + ---------- + models : List[AquaMultiModelRef] + List of AquaMultiModelRef instances for creating a multi-model group. + create_deployment_details : CreateModelDeploymentDetails + An instance of CreateModelDeploymentDetails containing all required and optional + fields for creating a model deployment via Aqua. + model_config_summary : ModelConfigSummary + Summary Model Deployment configuration for the group of models. + freeform_tags : Optional[Dict] + Freeform tags for the model. + source_models: Optional[Dict[str, DataScienceModel]] + A mapping of model OCIDs to their corresponding `DataScienceModel` objects. + This dictionary contains metadata for all models involved in the multi-model deployment, + including both base models and fine-tuned weights. + + Returns + ------- + tuple + A tuple of required metadata ('multi_model_metadata' and 'MULTI_MODEL_CONFIG') and strings to create model group. + """ + + if not models: + raise AquaValueError( + "Model list cannot be empty. Please provide at least one model for deployment." + ) + + display_name_list = [] + model_custom_metadata = ModelCustomMetadata() + + service_inference_containers = ( + self.get_container_config().to_dict().get("inference") + ) + + supported_container_families = [ + container_config_item.family + for container_config_item in service_inference_containers + if any( + usage.upper() in container_config_item.usages + for usage in [Usage.MULTI_MODEL, Usage.OTHER] + ) + ] + + if not supported_container_families: + raise AquaValueError( + "Currently, there are no containers that support multi-model deployment." + ) + + selected_models_deployment_containers = set() + + if not source_models: + # Collect all unique model IDs (including fine-tuned models) + source_model_ids = list( + {model_id for model in models for model_id in model.all_model_ids()} + ) + logger.debug( + "Fetching source model metadata for model IDs: %s", source_model_ids + ) + + # Fetch source model metadata + source_models = self.get_multi_source(source_model_ids) or {} + + # Process each model in the input list + for model in models: + # Retrieve base model metadata + source_model: DataScienceModel = source_models.get(model.model_id) + if not source_model: + logger.error( + "Failed to fetch metadata for base model ID: %s", model.model_id + ) + raise AquaValueError( + f"Unable to retrieve metadata for base model ID: {model.model_id}." + ) + + # Use display name as fallback if model name not provided + model.model_name = model.model_name or source_model.display_name + + # Validate model file description + model_file_description = source_model.model_file_description + if not model_file_description: + logger.error( + "Model '%s' (%s) has no file description.", + source_model.display_name, + model.model_id, + ) + raise AquaValueError( + f"Model '{source_model.display_name}' (ID: {model.model_id}) has no file description. " + "Please register the model with a file description." + ) + + # Ensure base model has a valid artifact + if not source_model.artifact: + logger.error( + "Base model '%s' (%s) has no artifact.", + model.model_name, + model.model_id, + ) + raise AquaValueError( + f"Model '{model.model_name}' (ID: {model.model_id}) has no registered artifacts. " + "Please register the model before deployment." + ) + + # Set base model artifact path + model.artifact_location = source_model.artifact + logger.debug( + "Model '%s' artifact path set to: %s", + model.model_name, + model.artifact_location, + ) + + display_name_list.append(model.model_name) + + # Extract model task metadata from source model + self._extract_model_task(model, source_model) + + # Process fine-tuned weights if provided + for ft_model in model.fine_tune_weights or []: + fine_tune_source_model: DataScienceModel = source_models.get( + ft_model.model_id + ) + if not fine_tune_source_model: + logger.error( + "Failed to fetch metadata for fine-tuned model ID: %s", + ft_model.model_id, + ) + raise AquaValueError( + f"Unable to retrieve metadata for fine-tuned model ID: {ft_model.model_id}." + ) + + # Validate model file description + ft_model_file_description = ( + fine_tune_source_model.model_file_description + ) + if not ft_model_file_description: + logger.error( + "Model '%s' (%s) has no file description.", + fine_tune_source_model.display_name, + ft_model.model_id, + ) + raise AquaValueError( + f"Model '{fine_tune_source_model.display_name}' (ID: {ft_model.model_id}) has no file description. " + "Please register the model with a file description." + ) + + # Extract fine-tuned model path + _, fine_tune_path = extract_fine_tune_artifacts_path( + fine_tune_source_model + ) + logger.debug( + "Resolved fine-tuned model path for '%s': %s", + ft_model.model_id, + fine_tune_path, + ) + ft_model.model_path = ( + ft_model.model_id + "/" + fine_tune_path.lstrip("/") + ) + + # Use fallback name if needed + ft_model.model_name = ( + ft_model.model_name or fine_tune_source_model.display_name + ) + + display_name_list.append(ft_model.model_name) + + # Validate deployment container consistency + deployment_container = source_model.custom_metadata_list.get( + ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, + ModelCustomMetadataItem( + key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER + ), + ).value + + if deployment_container not in supported_container_families: + logger.error( + "Unsupported deployment container '%s' for model '%s'. Supported: %s", + deployment_container, + source_model.id, + supported_container_families, + ) + raise AquaValueError( + f"Unsupported deployment container '{deployment_container}' for model '{source_model.id}'. " + f"Only {supported_container_families} are supported for multi-model deployments." + ) + + selected_models_deployment_containers.add(deployment_container) + + if not selected_models_deployment_containers: + raise AquaValueError( + "None of the selected models are associated with a recognized container family. " + "Please review the selected models, or select a different group of models." + ) + + # Check if the all models in the group shares same container family + if len(selected_models_deployment_containers) > 1: + deployment_container = get_preferred_compatible_family( + selected_families=selected_models_deployment_containers + ) + if not deployment_container: + raise AquaValueError( + "The selected models are associated with different container families: " + f"{list(selected_models_deployment_containers)}." + "For multi-model deployment, all models in the group must belong to the same container " + "family or to compatible container families." + ) + else: + deployment_container = selected_models_deployment_containers.pop() + + # Generate model group details + timestamp = datetime.now().strftime("%Y%m%d") + model_group_display_name = f"model_group_{timestamp}" + combined_model_names = ", ".join(display_name_list) + model_group_description = f"Multi-model grouping using {combined_model_names}." + + # Add global metadata + model_custom_metadata.add( + key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, + value=deployment_container, + description=f"Inference container mapping for {model_group_display_name}", + category="Other", + ) + model_custom_metadata.add( + key=ModelCustomMetadataFields.MULTIMODEL_GROUP_COUNT, + value=str(len(models)), + description="Number of models in the group.", + category="Other", + ) + model_custom_metadata.add( + key=AQUA_MULTI_MODEL_CONFIG, + value=self._build_model_group_config( + create_deployment_details=create_deployment_details, + model_config_summary=model_config_summary, + deployment_container=deployment_container, + ).model_dump_json(), + description="Configs required to deploy multi models.", + category="Other", + ) + model_custom_metadata.add( + key=ModelCustomMetadataFields.MULTIMODEL_METADATA, + value=json.dumps([model.model_dump() for model in models]), + description="Metadata to store user's multi model input.", + category="Other", + ) + + # Combine tags. The `Tags.AQUA_TAG` has been excluded, because we don't want to show + # the models created for multi-model purpose in the AQUA models list. + tags = { + # Tags.AQUA_TAG: "active", + Tags.MULTIMODEL_TYPE_TAG: "true", + **(freeform_tags or {}), + } + + return ( + model_group_display_name, + model_group_description, + tags, + model_custom_metadata, + combined_model_names, + ) + + def _extract_model_task( + self, + model: AquaMultiModelRef, + source_model: DataScienceModel, + ) -> None: + """In a Multi Model Deployment, will set model_task parameter in AquaMultiModelRef from freeform tags or user""" + # user does not supply model task, we extract from model metadata + if not model.model_task: + model.model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN) + + task_tag = re.sub(r"-", "_", model.model_task).lower() + # re-visit logic when more model task types are supported + if task_tag in MultiModelSupportedTaskType: + model.model_task = task_tag + else: + raise AquaValueError( + f"Invalid or missing {task_tag} tag for selected model {source_model.display_name}. " + f"Currently only `{MultiModelSupportedTaskType.values()}` models are supported for multi model deployment." + ) + + def _build_model_group_config( + self, + create_deployment_details, + model_config_summary, + deployment_container: str, + ) -> ModelGroupConfig: + """Builds model group config required to deploy multi models.""" + container_type_key = ( + create_deployment_details.container_family or deployment_container + ) + container_config = self.get_container_config_item(container_type_key) + container_spec = container_config.spec if container_config else UNKNOWN + + container_params = container_spec.cli_param if container_spec else UNKNOWN + + multi_model_config = ModelGroupConfig.from_create_model_deployment_details( + create_deployment_details, + model_config_summary, + container_type_key, + container_params, + ) + + return multi_model_config + def _create( self, aqua_model: DataScienceModel, @@ -562,8 +897,7 @@ def _create( def _create_multi( self, - aqua_model: DataScienceModel, - model_config_summary: ModelDeploymentConfigSummary, + aqua_model_group: DataScienceModelGroup, create_deployment_details: CreateModelDeploymentDetails, container_config: AquaContainerConfig, ) -> AquaDeployment: @@ -571,15 +905,14 @@ def _create_multi( Parameters ---------- - model_config_summary : model_config_summary - Summary Model Deployment configuration for the group of models. - aqua_model : DataScienceModel - An instance of Aqua data science model. + aqua_model_group : DataScienceModelGroup + An instance of Aqua data science model group. create_deployment_details : CreateModelDeploymentDetails An instance of CreateModelDeploymentDetails containing all required and optional fields for creating a model deployment via Aqua. container_config: Dict Container config dictionary. + Returns ------- AquaDeployment @@ -589,23 +922,12 @@ def _create_multi( env_var = {**(create_deployment_details.env_var or UNKNOWN_DICT)} container_type_key = self._get_container_type_key( - model=aqua_model, + model=aqua_model_group, container_family=create_deployment_details.container_family, ) container_config = self.get_container_config_item(container_type_key) container_spec = container_config.spec if container_config else UNKNOWN - container_params = container_spec.cli_param if container_spec else UNKNOWN - - multi_model_config = ModelGroupConfig.from_create_model_deployment_details( - create_deployment_details, - model_config_summary, - container_type_key, - container_params, - ) - - env_var.update({AQUA_MULTI_MODEL_CONFIG: multi_model_config.model_dump_json()}) - env_vars = container_spec.env_vars if container_spec else [] for env in env_vars: if isinstance(env, dict): @@ -614,7 +936,7 @@ def _create_multi( if key not in env_var: env_var.update(env) - logger.info(f"Env vars used for deploying {aqua_model.id} : {env_var}.") + logger.info(f"Env vars used for deploying {aqua_model_group.id} : {env_var}.") container_image_uri = ( create_deployment_details.container_image_uri @@ -627,7 +949,7 @@ def _create_multi( container_spec.health_check_port if container_spec else None ) tags = { - Tags.AQUA_MODEL_ID_TAG: aqua_model.id, + Tags.AQUA_MODEL_ID_TAG: aqua_model_group.id, Tags.MULTIMODEL_TYPE_TAG: "true", Tags.AQUA_TAG: "active", **(create_deployment_details.freeform_tags or UNKNOWN_DICT), @@ -637,7 +959,7 @@ def _create_multi( aqua_deployment = self._create_deployment( create_deployment_details=create_deployment_details, - aqua_model_id=aqua_model.id, + aqua_model_id=aqua_model_group.id, model_name=model_name, model_type=AQUA_MODEL_TYPE_MULTI, container_image_uri=container_image_uri, @@ -732,11 +1054,14 @@ def _create_deployment( .with_health_check_port(health_check_port) .with_env(env_var) .with_deployment_mode(ModelDeploymentMode.HTTPS) - .with_model_uri(aqua_model_id) .with_region(self.region) .with_overwrite_existing_artifact(True) .with_remove_existing_artifact(True) ) + if self._if_model_group(aqua_model_id): + container_runtime.with_model_group_id(aqua_model_id) + else: + container_runtime.with_model_uri(aqua_model_id) if cmd_var: container_runtime.with_cmd(cmd_var) @@ -794,7 +1119,9 @@ def _create_deployment( ) @staticmethod - def _get_container_type_key(model: DataScienceModel, container_family: str) -> str: + def _get_container_type_key( + model: Union[DataScienceModel, DataScienceModelGroup], container_family: str + ) -> str: container_type_key = UNKNOWN if container_family: container_type_key = container_family @@ -977,7 +1304,12 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": f"Invalid multi model deployment {model_deployment_id}." f"Make sure the {Tags.AQUA_MODEL_ID_TAG} tag is added to the deployment." ) - aqua_model = DataScienceModel.from_id(aqua_model_id) + + if self._if_model_group(aqua_model_id): + aqua_model = DataScienceModelGroup.from_id(aqua_model_id) + else: + aqua_model = DataScienceModel.from_id(aqua_model_id) + custom_metadata_list = aqua_model.custom_metadata_list multi_model_metadata_value = custom_metadata_list.get( ModelCustomMetadataFields.MULTIMODEL_METADATA, @@ -991,7 +1323,9 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": f"Ensure that the required custom metadata `{ModelCustomMetadataFields.MULTIMODEL_METADATA}` is added to the AQUA multi-model `{aqua_model.display_name}` ({aqua_model.id})." ) multi_model_metadata = json.loads( - aqua_model.dsc_model.get_custom_metadata_artifact( + multi_model_metadata_value + if isinstance(aqua_model, DataScienceModelGroup) + else aqua_model.dsc_model.get_custom_metadata_artifact( metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA ).decode("utf-8") ) @@ -1006,6 +1340,11 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": log=AquaResourceIdentifier(log_id, log_name, log_url), ) + @staticmethod + def _if_model_group(model_id: str) -> bool: + """Checks if it's model group id or not.""" + return "datasciencemodelgroup" in model_id.lower() + @telemetry( entry_point="plugin=deployment&action=get_deployment_config", name="aqua" ) diff --git a/ads/aqua/modeldeployment/entities.py b/ads/aqua/modeldeployment/entities.py index ebce26dc8..0b65bc213 100644 --- a/ads/aqua/modeldeployment/entities.py +++ b/ads/aqua/modeldeployment/entities.py @@ -10,6 +10,7 @@ from ads.aqua import logger from ads.aqua.common.entities import AquaMultiModelRef from ads.aqua.common.enums import Tags +from ads.aqua.common.errors import AquaValueError from ads.aqua.config.utils.serializer import Serializable from ads.aqua.constants import UNKNOWN_DICT from ads.aqua.data import AquaResourceIdentifier @@ -21,6 +22,7 @@ from ads.common.serializer import DataClassSerializable from ads.common.utils import UNKNOWN, get_console_link from ads.model.datascience_model import DataScienceModel +from ads.model.deployment.model_deployment import ModelDeploymentType from ads.model.model_metadata import ModelCustomMetadataItem @@ -147,13 +149,39 @@ def from_oci_model_deployment( AquaDeployment: The instance of the Aqua model deployment. """ - instance_configuration = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration + model_deployment_configuration_details = ( + oci_model_deployment.model_deployment_configuration_details + ) + if ( + model_deployment_configuration_details.deployment_type + == ModelDeploymentType.SINGLE_MODEL + ): + instance_configuration = model_deployment_configuration_details.model_configuration_details.instance_configuration + instance_count = model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count + model_id = model_deployment_configuration_details.model_configuration_details.model_id + elif ( + model_deployment_configuration_details.deployment_type + == ModelDeploymentType.MODEL_GROUP + ): + instance_configuration = model_deployment_configuration_details.infrastructure_configuration_details.instance_configuration + instance_count = model_deployment_configuration_details.infrastructure_configuration_details.scaling_policy.instance_count + model_id = model_deployment_configuration_details.model_group_configuration_details.model_group_id + else: + allowed_deployment_types = ", ".join( + [key for key in dir(ModelDeploymentType) if not key.startswith("__")] + ) + raise AquaValueError( + f"Invalid AQUA deployment with type {model_deployment_configuration_details.deployment_type}." + f"Only {allowed_deployment_types} are supported at this moment. Specify a different AQUA model deployment." + ) + instance_shape_config_details = ( instance_configuration.model_deployment_instance_shape_config_details ) - instance_count = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count - environment_variables = oci_model_deployment.model_deployment_configuration_details.environment_configuration_details.environment_variables - cmd = oci_model_deployment.model_deployment_configuration_details.environment_configuration_details.cmd + environment_variables = model_deployment_configuration_details.environment_configuration_details.environment_variables + cmd = ( + model_deployment_configuration_details.environment_configuration_details.cmd + ) shape_info = ShapeInfo( instance_shape=instance_configuration.instance_shape_name, instance_count=instance_count, @@ -168,7 +196,6 @@ def from_oci_model_deployment( else None ), ) - model_id = oci_model_deployment._model_deployment_configuration_details.model_configuration_details.model_id tags = {} tags.update(oci_model_deployment.freeform_tags or UNKNOWN_DICT) tags.update(oci_model_deployment.defined_tags or UNKNOWN_DICT) diff --git a/ads/aqua/modeldeployment/model_group_config.py b/ads/aqua/modeldeployment/model_group_config.py index e452ec7f5..99f7c2fee 100644 --- a/ads/aqua/modeldeployment/model_group_config.py +++ b/ads/aqua/modeldeployment/model_group_config.py @@ -4,7 +4,7 @@ from typing import List, Optional, Tuple, Union -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from typing_extensions import Self from ads.aqua import logger @@ -61,18 +61,19 @@ class BaseModelSpec(BaseModel): description="Optional list of fine-tuned model variants associated with this base model.", ) - @field_validator("model_path") @classmethod - def clean_model_path(cls, artifact_path_prefix: str) -> str: - """Validates and cleans the file path for model_path parameter.""" - if ObjectStorageDetails.is_oci_path(artifact_path_prefix): - os_path = ObjectStorageDetails.from_path(artifact_path_prefix) - artifact_path_prefix = os_path.filepath.rstrip("/") - return artifact_path_prefix - - raise AquaValueError( - "The base model path is not available in the model artifact." - ) + def build_model_path(cls, model_id: str, artifact_path_prefix: str) -> str: + """Cleans and builds the file path for model_path parameter + to format: / + """ + if not ObjectStorageDetails.is_oci_path(artifact_path_prefix): + raise AquaValueError( + "The base model path is not available in the model artifact." + ) + + os_path = ObjectStorageDetails.from_path(artifact_path_prefix) + artifact_path_prefix = os_path.filepath.rstrip("/") + return model_id + "/" + artifact_path_prefix.lstrip("/") @classmethod def dedup_lora_modules(cls, fine_tune_weights: List[LoraModuleSpec]): @@ -99,7 +100,7 @@ def from_aqua_multi_model_ref( return cls( model_id=model.model_id, - model_path=model.artifact_location, + model_path=cls.build_model_path(model.model_id, model.artifact_location), params=model_params, model_task=model.model_task, fine_tune_weights=cls.dedup_lora_modules(model.fine_tune_weights), diff --git a/ads/model/datascience_model_group.py b/ads/model/datascience_model_group.py new file mode 100644 index 000000000..cc32ffa9c --- /dev/null +++ b/ads/model/datascience_model_group.py @@ -0,0 +1,837 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import copy +from typing import Dict, List, Union + +from ads.common.utils import batch_convert_case +from ads.config import COMPARTMENT_OCID, PROJECT_OCID +from ads.jobs.builders.base import Builder +from ads.model.model_metadata import ModelCustomMetadata +from ads.model.service.oci_datascience_model_group import OCIDataScienceModelGroup + +try: + from oci.data_science.models import ( + CreateModelGroupDetails, + CustomMetadata, + HomogeneousModelGroupDetails, + MemberModelDetails, + MemberModelEntries, + ModelGroup, + ModelGroupDetails, + ModelGroupSummary, + UpdateModelGroupDetails, + ) +except ModuleNotFoundError as err: + raise ModuleNotFoundError( + "The oci model group module was not found. Please run `pip install oci` " + "to install the latest oci sdk." + ) from err + +DEFAULT_WAIT_TIME = 1200 +DEFAULT_POLL_INTERVAL = 10 +ALLOWED_CREATE_TYPES = ["CREATE", "CLONE"] +MODEL_GROUP_KIND = "datascienceModelGroup" + + +class DataScienceModelGroup(Builder): + """Represents a Data Science Model Group. + + Attributes + ---------- + id: str + Model group ID. + project_id: str + Project OCID. + compartment_id: str + Compartment OCID. + display_name: str + Model group name. + description: str + Model group description. + freeform_tags: Dict[str, str] + Model group freeform tags. + defined_tags: Dict[str, Dict[str, object]] + Model group defined tags. + custom_metadata_list: ModelCustomMetadata + Model group custom metadata. + model_group_version_history_name: str + Model group version history name + model_group_version_history_id: str + Model group version history ID + version_label: str + Model group version label + version_id: str + Model group version id + lifecycle_state: str + Model group lifecycle state + lifecycle_details: str + Model group lifecycle details + + Methods + ------- + activate(self, ...) -> "DataScienceModelGroup" + Activates model group. + create(self, ...) -> "DataScienceModelGroup" + Creates model group. + deactivate(self, ...) -> "DataScienceModelGroup" + Deactivates model group. + delete(self, ...) -> "DataScienceModelGroup": + Deletes model group. + to_dict(self) -> dict + Serializes model group to a dictionary. + from_id(cls, id: str) -> "DataScienceModelGroup" + Gets an existing model group by OCID. + from_dict(cls, config: dict) -> "DataScienceModelGroup" + Loads model group instance from a dictionary of configurations. + update(self, ...) -> "DataScienceModelGroup" + Updates datascience model group in model catalog. + list(cls, compartment_id: str = None, **kwargs) -> List["DataScienceModelGroup"] + Lists datascience model groups in a given compartment. + sync(self): + Sync up a datascience model group with OCI datascience model group. + with_project_id(self, project_id: str) -> "DataScienceModelGroup" + Sets the project ID. + with_description(self, description: str) -> "DataScienceModelGroup" + Sets the description. + with_compartment_id(self, compartment_id: str) -> "DataScienceModelGroup" + Sets the compartment ID. + with_display_name(self, name: str) -> "DataScienceModelGroup" + Sets the name. + with_freeform_tags(self, **kwargs: Dict[str, str]) -> "DataScienceModelGroup" + Sets freeform tags. + with_defined_tags(self, **kwargs: Dict[str, Dict[str, object]]) -> "DataScienceModelGroup" + Sets defined tags. + with_custom_metadata_list(self, metadata: Union[ModelCustomMetadata, Dict]) -> "DataScienceModelGroup" + Sets model group custom metadata. + with_model_group_version_history_id(self, model_group_version_history_id: str) -> "DataScienceModelGroup": + Sets the model group version history ID. + with_version_label(self, version_label: str) -> "DataScienceModelGroup": + Sets the model group version label. + with_base_model_id(self, base_model_id) -> "DataScienceModelGroup": + Sets the base model ID. + with_member_models(self, member_models: List[Dict[str, str]]) -> "DataScienceModelGroup": + Sets the list of member models to be grouped. + + + Examples + -------- + >>> ds_model_group = (DataScienceModelGroup() + ... .with_compartment_id(os.environ["NB_SESSION_COMPARTMENT_OCID"]) + ... .with_project_id(os.environ["PROJECT_OCID"]) + ... .with_display_name("TestModelGroup") + ... .with_description("Testing the model group") + ... .with_freeform_tags(tag1="val1", tag2="val2") + >>> ds_model_group.create() + >>> ds_model_group.with_description("new description").update() + >>> ds_model_group.delete() + >>> DataScienceModelGroup.list() + """ + + CONST_ID = "id" + CONST_CREATE_TYPE = "createType" + CONST_COMPARTMENT_ID = "compartmentId" + CONST_PROJECT_ID = "projectId" + CONST_DISPLAY_NAME = "displayName" + CONST_DESCRIPTION = "description" + CONST_FREEFORM_TAG = "freeformTags" + CONST_DEFINED_TAG = "definedTags" + CONST_MODEL_GROUP_DETAILS = "modelGroupDetails" + CONST_MEMBER_MODEL_ENTRIES = "memberModelEntries" + CONST_CUSTOM_METADATA_LIST = "customMetadataList" + CONST_BASE_MODEL_ID = "baseModelId" + CONST_MEMBER_MODELS = "memberModels" + CONST_MODEL_GROUP_VERSION_HISTORY_ID = "modelGroupVersionHistoryId" + CONST_MODEL_GROUP_VERSION_HISTORY_NAME = "modelGroupVersionHistoryName" + CONST_LIFECYCLE_STATE = "lifecycleState" + CONST_LIFECYCLE_DETAILS = "lifecycleDetails" + CONST_TIME_CREATED = "timeCreated" + CONST_TIME_UPDATED = "timeUpdated" + CONST_CREATED_BY = "createdBy" + CONST_VERSION_LABEL = "versionLabel" + CONST_VERSION_ID = "versionId" + + attribute_map = { + CONST_ID: "id", + CONST_COMPARTMENT_ID: "compartment_id", + CONST_PROJECT_ID: "project_id", + CONST_DISPLAY_NAME: "display_name", + CONST_DESCRIPTION: "description", + CONST_FREEFORM_TAG: "freeform_tags", + CONST_DEFINED_TAG: "defined_tags", + CONST_LIFECYCLE_STATE: "lifecycle_state", + CONST_LIFECYCLE_DETAILS: "lifecycle_details", + CONST_TIME_CREATED: "time_created", + CONST_TIME_UPDATED: "time_updated", + CONST_CREATED_BY: "created_by", + CONST_MODEL_GROUP_VERSION_HISTORY_ID: "model_group_version_history_id", + CONST_MODEL_GROUP_VERSION_HISTORY_NAME: "model_group_version_history_name", + CONST_VERSION_LABEL: "version_label", + CONST_VERSION_ID: "version_id", + } + + def __init__(self, spec=None, **kwargs): + """Initializes datascience model group. + + Parameters + ---------- + spec: (Dict, optional). Defaults to None. + Object specification. + + kwargs: Dict + Specification as keyword arguments. + If 'spec' contains the same key as the one in kwargs, + the value from kwargs will be used. + + - project_id: str + - compartment_id: str + - display_name: str + - description: str + - defined_tags: Dict[str, Dict[str, object]] + - freeform_tags: Dict[str, str] + - custom_metadata_list: Union[ModelCustomMetadata, Dict] + - base_model_id: str + - member_models: List[Dict[str, str]] + - model_group_version_history_id: str + - version_label: str + """ + super().__init__(spec, **kwargs) + self.dsc_model_group = OCIDataScienceModelGroup() + + @property + def kind(self) -> str: + """The kind of the model group as showing in a YAML.""" + return MODEL_GROUP_KIND + + @property + def id(self) -> str: + """The model group OCID.""" + return self.get_spec(self.CONST_ID) + + @property + def lifecycle_state(self) -> str: + """The model group lifecycle state.""" + return self.get_spec(self.CONST_LIFECYCLE_STATE) + + @property + def lifecycle_details(self) -> str: + """The model group lifecycle details.""" + return self.get_spec(self.CONST_LIFECYCLE_DETAILS) + + @property + def create_type(self) -> str: + """The model group create type.""" + return self.get_spec(self.CONST_CREATE_TYPE) + + @property + def model_group_version_history_name(self) -> str: + """The model group version history name.""" + return self.get_spec(self.CONST_MODEL_GROUP_VERSION_HISTORY_NAME) + + @property + def version_id(self) -> str: + """The model group version id.""" + return self.get_spec(self.CONST_VERSION_ID) + + def with_create_type(self, create_type: str) -> "DataScienceModelGroup": + """Sets the create type. + + Parameters + ---------- + create_type: str + The create type of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + if create_type not in ALLOWED_CREATE_TYPES: + raise ValueError( + f"Invalid create type. Allowed create type are {ALLOWED_CREATE_TYPES}." + ) + return self.set_spec(self.CONST_CREATE_TYPE, create_type) + + @property + def compartment_id(self) -> str: + """The model group compartment id.""" + return self.get_spec(self.CONST_COMPARTMENT_ID) + + def with_compartment_id(self, compartment_id: str) -> "DataScienceModelGroup": + """Sets the compartment OCID. + + Parameters + ---------- + compartment_id: str + The compartment id of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_COMPARTMENT_ID, compartment_id) + + @property + def project_id(self) -> str: + """The model group project id.""" + return self.get_spec(self.CONST_PROJECT_ID) + + def with_project_id(self, project_id: str) -> "DataScienceModelGroup": + """Sets the project OCID. + + Parameters + ---------- + project_id: str + The project id of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_PROJECT_ID, project_id) + + @property + def display_name(self) -> str: + """The model group display name.""" + return self.get_spec(self.CONST_DISPLAY_NAME) + + def with_display_name(self, display_name: str) -> "DataScienceModelGroup": + """Sets the display name. + + Parameters + ---------- + display_name: str + The display name of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_DISPLAY_NAME, display_name) + + @property + def description(self) -> str: + """The model group description.""" + return self.get_spec(self.CONST_DESCRIPTION) + + def with_description(self, description: str) -> "DataScienceModelGroup": + """Sets the description. + + Parameters + ---------- + description: str + The description of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_DESCRIPTION, description) + + @property + def freeform_tags(self) -> Dict[str, str]: + """The model group freeform tags.""" + return self.get_spec(self.CONST_FREEFORM_TAG) + + def with_freeform_tags(self, **kwargs) -> "DataScienceModelGroup": + """Sets the freeform tags. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_FREEFORM_TAG, kwargs) + + @property + def defined_tags(self) -> Dict[str, Dict[str, object]]: + """The model group defined tags.""" + return self.get_spec(self.CONST_DEFINED_TAG) + + def with_defined_tags(self, **kwargs) -> "DataScienceModelGroup": + """Sets the defined tags. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_DEFINED_TAG, kwargs) + + @property + def custom_metadata_list(self) -> ModelCustomMetadata: + """The model group custom metadata list.""" + return self.get_spec(self.CONST_CUSTOM_METADATA_LIST) + + def with_custom_metadata_list( + self, metadata: Union[ModelCustomMetadata, Dict] + ) -> "DataScienceModelGroup": + """Sets model group custom metadata. + + Parameters + ---------- + metadata: Union[ModelCustomMetadata, Dict] + The custom metadata. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + if metadata and isinstance(metadata, Dict): + metadata = ModelCustomMetadata.from_dict(metadata) + return self.set_spec(self.CONST_CUSTOM_METADATA_LIST, metadata) + + @property + def base_model_id(self) -> str: + """The model group base model id.""" + return self.get_spec(self.CONST_BASE_MODEL_ID) + + def with_base_model_id(self, base_model_id: str) -> "DataScienceModelGroup": + """Sets base model id. + + Parameters + ---------- + base_model_id: str + The base model id. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_BASE_MODEL_ID, base_model_id) + + @property + def member_models(self) -> List[Dict[str, str]]: + """The member models of model group.""" + return self.get_spec(self.CONST_MEMBER_MODELS) + + def with_member_models( + self, member_models: List[Dict[str, str]] + ) -> "DataScienceModelGroup": + """Sets member models to be grouped. + + Parameters + ---------- + member_models: List[Dict[str, str]] + The member models to be grouped. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_MEMBER_MODELS, member_models) + + @property + def model_group_version_history_id(self) -> str: + """The model group version history id.""" + return self.get_spec(self.CONST_MODEL_GROUP_VERSION_HISTORY_ID) + + def with_model_group_version_history_id( + self, model_group_version_history_id: str + ) -> "DataScienceModelGroup": + """Sets model group version history id. + + Parameters + ---------- + model_group_version_history_id: str + The model group version history id. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec( + self.CONST_MODEL_GROUP_VERSION_HISTORY_ID, model_group_version_history_id + ) + + @property + def version_label(self) -> str: + """The model group version label.""" + return self.get_spec(self.CONST_VERSION_LABEL) + + def with_version_label(self, version_label: str) -> "DataScienceModelGroup": + """Sets model group version label. + + Parameters + ---------- + version_label: str + The model group version label. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self) + """ + return self.set_spec(self.CONST_VERSION_LABEL, version_label) + + def create( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "DataScienceModelGroup": + """Creates the datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for model group to be created before proceeding. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + response = self.dsc_model_group.create( + create_model_group_details=CreateModelGroupDetails( + **batch_convert_case(self._build_model_group_details(), "snake") + ), + wait_for_completion=wait_for_completion, + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + + return self._update_from_oci_model(response) + + def _build_model_group_details(self) -> dict: + """Builds model group details dict for creating or updating oci model group.""" + model_group_details = HomogeneousModelGroupDetails( + custom_metadata_list=[ + CustomMetadata( + key=custom_metadata.key, + value=custom_metadata.value, + description=custom_metadata.description, + category=custom_metadata.category, + ) + for custom_metadata in self.custom_metadata_list._to_oci_metadata() + ] + ) + + member_model_entries = MemberModelEntries( + member_model_details=[ + MemberModelDetails(**member_model) + for member_model in self.member_models + ] + ) + + build_model_group_details = copy.deepcopy(self._spec) + build_model_group_details.pop(self.CONST_CUSTOM_METADATA_LIST) + build_model_group_details.pop(self.CONST_MEMBER_MODELS) + build_model_group_details.update( + { + self.CONST_COMPARTMENT_ID: self.compartment_id or COMPARTMENT_OCID, + self.CONST_PROJECT_ID: self.project_id or PROJECT_OCID, + self.CONST_MODEL_GROUP_DETAILS: model_group_details, + self.CONST_MEMBER_MODEL_ENTRIES: member_model_entries, + } + ) + + return build_model_group_details + + def _update_from_oci_model( + self, oci_model_group_instance: Union[ModelGroup, ModelGroupSummary] + ) -> "DataScienceModelGroup": + """Updates self spec from oci model group instance. + + Parameters + ---------- + oci_model_group_instance: Union[ModelGroup, ModelGroupSummary] + The oci model group instance, could be an instance of oci.data_science.models.ModelGroup + or oci.data_science.models.ModelGroupSummary. + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + self.dsc_model_group = oci_model_group_instance + for key, value in self.attribute_map.items(): + if hasattr(oci_model_group_instance, value): + self.set_spec(key, getattr(oci_model_group_instance, value)) + + model_group_details: ModelGroupDetails = ( + oci_model_group_instance.model_group_details + ) + custom_metadata_list: List[CustomMetadata] = ( + model_group_details.custom_metadata_list + ) + model_custom_metadata = ModelCustomMetadata() + for metadata in custom_metadata_list: + model_custom_metadata.add( + key=metadata.key, + value=metadata.value, + description=metadata.description, + category=metadata.category, + ) + self.set_spec(self.CONST_CUSTOM_METADATA_LIST, model_custom_metadata) + + # only updates member_models when oci_model_group_instance is an instance of + # oci.data_science.models.ModelGroup as oci.data_science.models.ModelGroupSummary + # doesn't have member_model_entries property. + if isinstance(oci_model_group_instance, ModelGroup): + member_model_entries: MemberModelEntries = ( + oci_model_group_instance.member_model_entries + ) + member_model_details: List[MemberModelDetails] = ( + member_model_entries.member_model_details + ) + + self.set_spec( + self.CONST_MEMBER_MODELS, + [ + { + "inference_key": member_model_detail.inference_key, + "model_id": member_model_detail.model_id, + } + for member_model_detail in member_model_details + ], + ) + + return self + + def update( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "DataScienceModelGroup": + """Updates a datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for model group to be updated before proceeding. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + update_model_group_details = OCIDataScienceModelGroup( + **self._build_model_group_details() + ).to_oci_model(UpdateModelGroupDetails) + + response = self.dsc_model_group.update( + update_model_group_details=update_model_group_details, + wait_for_completion=wait_for_completion, + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + + return self._update_from_oci_model(response) + + def activate( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "DataScienceModelGroup": + """Activates a datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for model group to be activated before proceeding. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + response = self.dsc_model_group.activate( + wait_for_completion=wait_for_completion, + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + + return self._update_from_oci_model(response) + + def deactivate( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "DataScienceModelGroup": + """Deactivates a datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for model group to be deactivated before proceeding. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + response = self.dsc_model_group.deactivate( + wait_for_completion=wait_for_completion, + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + + return self._update_from_oci_model(response) + + def delete( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "DataScienceModelGroup": + """Deletes a datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for model group to be deleted before proceeding. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + DataScienceModelGroup + The instance of DataScienceModelGroup. + """ + response = self.dsc_model_group.delete( + wait_for_completion=wait_for_completion, + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + return self._update_from_oci_model(response) + + def sync(self) -> "DataScienceModelGroup": + """Updates the model group instance from backend. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self). + """ + if not self.id: + raise ValueError( + "Model group needs to be created before it can be fetched." + ) + return self._update_from_oci_model(OCIDataScienceModelGroup.from_id(self.id)) + + @classmethod + def list( + cls, + status: str = None, + compartment_id: str = None, + **kwargs, + ) -> List["DataScienceModelGroup"]: + """Lists datascience model groups in a given compartment. + + Parameters + ---------- + status: (str, optional). Defaults to `None`. + The status of model group. Allowed values: `ACTIVE`, `CREATING`, `DELETED`, `DELETING`, `FAILED` and `INACTIVE`. + compartment_id: (str, optional). Defaults to `None`. + The compartment OCID. + kwargs + Additional keyword arguments for filtering model groups. + + Returns + ------- + List[DataScienceModelGroup] + The list of the datascience model groups. + """ + return [ + cls()._update_from_oci_model(model_group_summary) + for model_group_summary in OCIDataScienceModelGroup.list( + status=status, + compartment_id=compartment_id, + **kwargs, + ) + ] + + @classmethod + def from_id(cls, model_group_id: str) -> "DataScienceModelGroup": + """Loads the model group instance from ocid. + + Parameters + ---------- + model_group_id: str + The ocid of model group. + + Returns + ------- + DataScienceModelGroup + The DataScienceModelGroup instance (self). + """ + oci_model_group = OCIDataScienceModelGroup.from_id(model_group_id) + return cls()._update_from_oci_model(oci_model_group) + + def to_dict(self) -> Dict: + """Serializes model group to a dictionary. + + Returns + ------- + dict + The model group serialized as a dictionary. + """ + spec = copy.deepcopy(self._spec) + for key, value in spec.items(): + if hasattr(value, "to_dict"): + value = value.to_dict() + spec[key] = value + + return { + "kind": self.kind, + "type": self.type, + "spec": batch_convert_case(spec, "camel"), + } + + @classmethod + def from_dict(cls, config: Dict) -> "DataScienceModelGroup": + """Loads model group instance from a dictionary of configurations. + + Parameters + ---------- + config: Dict + A dictionary of configurations. + + Returns + ------- + DataScienceModelGroup + The model group instance. + """ + return cls(spec=batch_convert_case(copy.deepcopy(config["spec"]), "snake")) diff --git a/ads/model/deployment/model_deployment.py b/ads/model/deployment/model_deployment.py index 56a70c112..57a4c683b 100644 --- a/ads/model/deployment/model_deployment.py +++ b/ads/model/deployment/model_deployment.py @@ -1,22 +1,27 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- -# Copyright (c) 2021, 2023 Oracle and/or its affiliates. +# Copyright (c) 2021, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import collections import copy import datetime -import oci -import warnings import time -from typing import Dict, List, Union, Any +import warnings +from typing import Any, Dict, List, Union +import oci import oci.loggingsearch -from ads.common import auth as authutil import pandas as pd -from ads.model.serde.model_input import JsonModelInputSERDE +from oci.data_science.models import ( + CreateModelDeploymentDetails, + LogDetails, + UpdateModelDeploymentDetails, +) + +from ads.common import auth as authutil +from ads.common import utils as ads_utils from ads.common.oci_logging import ( LOG_INTERVAL, LOG_RECORDS_LIMIT, @@ -30,10 +35,10 @@ from ads.model.deployment.common.utils import send_request from ads.model.deployment.model_deployment_infrastructure import ( DEFAULT_BANDWIDTH_MBPS, + DEFAULT_MEMORY_IN_GBS, + DEFAULT_OCPUS, DEFAULT_REPLICA, DEFAULT_SHAPE_NAME, - DEFAULT_OCPUS, - DEFAULT_MEMORY_IN_GBS, MODEL_DEPLOYMENT_INFRASTRUCTURE_TYPE, ModelDeploymentInfrastructure, ) @@ -45,18 +50,14 @@ ModelDeploymentRuntimeType, OCIModelDeploymentRuntimeType, ) +from ads.model.serde.model_input import JsonModelInputSERDE from ads.model.service.oci_datascience_model_deployment import ( OCIDataScienceModelDeployment, ) -from ads.common import utils as ads_utils + from .common import utils from .common.utils import State from .model_deployment_properties import ModelDeploymentProperties -from oci.data_science.models import ( - LogDetails, - CreateModelDeploymentDetails, - UpdateModelDeploymentDetails, -) DEFAULT_WAIT_TIME = 1200 DEFAULT_POLL_INTERVAL = 10 @@ -80,6 +81,11 @@ class ModelDeploymentLogType: ACCESS = "access" +class ModelDeploymentType: + SINGLE_MODEL = "SINGLE_MODEL" + MODEL_GROUP = "MODEL_GROUP" + + class LogNotConfiguredError(Exception): # pragma: no cover pass @@ -964,7 +970,9 @@ def predict( except oci.exceptions.ServiceError as ex: # When bandwidth exceeds the allocated value, TooManyRequests error (429) will be raised by oci backend. if ex.status == 429: - bandwidth_mbps = self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS + bandwidth_mbps = ( + self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS + ) utils.get_logger().warning( f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps." "To estimate the actual bandwidth, use formula: (payload size in KB) * (estimated requests per second) * 8 / 1024." @@ -1644,22 +1652,22 @@ def _build_model_deployment_configuration_details(self) -> Dict: } if infrastructure.subnet_id: - instance_configuration[ - infrastructure.CONST_SUBNET_ID - ] = infrastructure.subnet_id + instance_configuration[infrastructure.CONST_SUBNET_ID] = ( + infrastructure.subnet_id + ) if infrastructure.private_endpoint_id: if not hasattr( oci.data_science.models.InstanceConfiguration, "private_endpoint_id" ): # TODO: add oci version with private endpoint support. - raise EnvironmentError( + raise OSError( "Private endpoint is not supported in the current OCI SDK installed." ) - instance_configuration[ - infrastructure.CONST_PRIVATE_ENDPOINT_ID - ] = infrastructure.private_endpoint_id + instance_configuration[infrastructure.CONST_PRIVATE_ENDPOINT_ID] = ( + infrastructure.private_endpoint_id + ) scaling_policy = { infrastructure.CONST_POLICY_TYPE: "FIXED_SIZE", @@ -1667,13 +1675,13 @@ def _build_model_deployment_configuration_details(self) -> Dict: or DEFAULT_REPLICA, } - if not runtime.model_uri: + if not (runtime.model_uri or runtime.model_group_id): raise ValueError( - "Missing parameter model uri. Try reruning it after model uri is configured." + "Missing parameter model uri and model group id. Try reruning it after model or model group is configured." ) model_id = runtime.model_uri - if not model_id.startswith("ocid"): + if model_id and not model_id.startswith("ocid"): from ads.model.datascience_model import DataScienceModel dsc_model = DataScienceModel( @@ -1704,7 +1712,7 @@ def _build_model_deployment_configuration_details(self) -> Dict: oci.data_science.models, "ModelDeploymentEnvironmentConfigurationDetails", ): - raise EnvironmentError( + raise OSError( "Environment variable hasn't been supported in the current OCI SDK installed." ) @@ -1720,9 +1728,9 @@ def _build_model_deployment_configuration_details(self) -> Dict: and runtime.inference_server.upper() == MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON ): - environment_variables[ - "CONTAINER_TYPE" - ] = MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON + environment_variables["CONTAINER_TYPE"] = ( + MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON + ) runtime.set_spec(runtime.CONST_ENV, environment_variables) environment_configuration_details = { runtime.CONST_ENVIRONMENT_CONFIG_TYPE: runtime.environment_config_type, @@ -1734,7 +1742,7 @@ def _build_model_deployment_configuration_details(self) -> Dict: oci.data_science.models, "OcirModelDeploymentEnvironmentConfigurationDetails", ): - raise EnvironmentError( + raise OSError( "Container runtime hasn't been supported in the current OCI SDK installed." ) environment_configuration_details["image"] = runtime.image @@ -1742,19 +1750,37 @@ def _build_model_deployment_configuration_details(self) -> Dict: environment_configuration_details["cmd"] = runtime.cmd environment_configuration_details["entrypoint"] = runtime.entrypoint environment_configuration_details["serverPort"] = runtime.server_port - environment_configuration_details[ - "healthCheckPort" - ] = runtime.health_check_port + environment_configuration_details["healthCheckPort"] = ( + runtime.health_check_port + ) model_deployment_configuration_details = { - infrastructure.CONST_DEPLOYMENT_TYPE: "SINGLE_MODEL", + infrastructure.CONST_DEPLOYMENT_TYPE: ModelDeploymentType.SINGLE_MODEL, infrastructure.CONST_MODEL_CONFIG_DETAILS: model_configuration_details, runtime.CONST_ENVIRONMENT_CONFIG_DETAILS: environment_configuration_details, } + if runtime.model_group_id: + model_deployment_configuration_details[ + infrastructure.CONST_DEPLOYMENT_TYPE + ] = ModelDeploymentType.MODEL_GROUP + model_deployment_configuration_details["modelGroupConfigurationDetails"] = { + runtime.CONST_MODEL_GROUP_ID: runtime.model_group_id + } + model_deployment_configuration_details[ + "infrastructureConfigurationDetails" + ] = { + "infrastructureType": "INSTANCE_POOL", + infrastructure.CONST_BANDWIDTH_MBPS: infrastructure.bandwidth_mbps + or DEFAULT_BANDWIDTH_MBPS, + infrastructure.CONST_INSTANCE_CONFIG: instance_configuration, + infrastructure.CONST_SCALING_POLICY: scaling_policy, + } + model_configuration_details.pop(runtime.CONST_MODEL_ID) + if runtime.deployment_mode == ModelDeploymentMode.STREAM: if not hasattr(oci.data_science.models, "StreamConfigurationDetails"): - raise EnvironmentError( + raise OSError( "Model deployment mode hasn't been supported in the current OCI SDK installed." ) model_deployment_configuration_details[ @@ -1786,9 +1812,13 @@ def _build_category_log_details(self) -> Dict: logs = {} if ( - self.infrastructure.access_log and - self.infrastructure.access_log.get(self.infrastructure.CONST_LOG_GROUP_ID, None) - and self.infrastructure.access_log.get(self.infrastructure.CONST_LOG_ID, None) + self.infrastructure.access_log + and self.infrastructure.access_log.get( + self.infrastructure.CONST_LOG_GROUP_ID, None + ) + and self.infrastructure.access_log.get( + self.infrastructure.CONST_LOG_ID, None + ) ): logs[self.infrastructure.CONST_ACCESS] = { self.infrastructure.CONST_LOG_GROUP_ID: self.infrastructure.access_log.get( @@ -1799,9 +1829,13 @@ def _build_category_log_details(self) -> Dict: ), } if ( - self.infrastructure.predict_log and - self.infrastructure.predict_log.get(self.infrastructure.CONST_LOG_GROUP_ID, None) - and self.infrastructure.predict_log.get(self.infrastructure.CONST_LOG_ID, None) + self.infrastructure.predict_log + and self.infrastructure.predict_log.get( + self.infrastructure.CONST_LOG_GROUP_ID, None + ) + and self.infrastructure.predict_log.get( + self.infrastructure.CONST_LOG_ID, None + ) ): logs[self.infrastructure.CONST_PREDICT] = { self.infrastructure.CONST_LOG_GROUP_ID: self.infrastructure.predict_log.get( diff --git a/ads/model/deployment/model_deployment_runtime.py b/ads/model/deployment/model_deployment_runtime.py index 26e31f9cd..adfa48d1d 100644 --- a/ads/model/deployment/model_deployment_runtime.py +++ b/ads/model/deployment/model_deployment_runtime.py @@ -1,11 +1,11 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- -# Copyright (c) 2023 Oracle and/or its affiliates. +# Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/from typing import Dict from typing import Dict, List + from ads.jobs.builders.base import Builder MODEL_DEPLOYMENT_RUNTIME_KIND = "runtime" @@ -41,6 +41,8 @@ class ModelDeploymentRuntime(Builder): The output stream ids of model deployment. model_uri: str The model uri of model deployment. + model_group_id: str + The model group id of model deployment. bucket_uri: str The OCI Object Storage URI where large size model artifacts will be copied to. auth: Dict @@ -66,6 +68,8 @@ class ModelDeploymentRuntime(Builder): Sets the output stream ids of model deployment with_model_uri(model_uri) Sets the model uri of model deployment + with_model_group_id(model_group_id) + Sets the model group id of model deployment with_bucket_uri(bucket_uri) Sets the bucket uri when uploading large size model. with_auth(auth) @@ -82,6 +86,7 @@ class ModelDeploymentRuntime(Builder): CONST_MODEL_ID = "modelId" CONST_MODEL_URI = "modelUri" + CONST_MODEL_GROUP_ID = "modelGroupId" CONST_ENV = "env" CONST_ENVIRONMENT_VARIABLES = "environmentVariables" CONST_ENVIRONMENT_CONFIG_TYPE = "environmentConfigurationType" @@ -103,6 +108,7 @@ class ModelDeploymentRuntime(Builder): CONST_OUTPUT_STREAM_IDS: "output_stream_ids", CONST_DEPLOYMENT_MODE: "deployment_mode", CONST_MODEL_URI: "model_uri", + CONST_MODEL_GROUP_ID: "model_group_id", CONST_BUCKET_URI: "bucket_uri", CONST_AUTH: "auth", CONST_REGION: "region", @@ -120,6 +126,9 @@ class ModelDeploymentRuntime(Builder): MODEL_CONFIG_DETAILS_PATH = ( "model_deployment_configuration_details.model_configuration_details" ) + MODEL_GROUP_CONFIG_DETAILS_PATH = ( + "model_deployment_configuration_details.model_group_configuration_details" + ) payload_attribute_map = { CONST_ENV: f"{ENVIRONMENT_CONFIG_DETAILS_PATH}.environment_variables", @@ -127,6 +136,7 @@ class ModelDeploymentRuntime(Builder): CONST_OUTPUT_STREAM_IDS: f"{STREAM_CONFIG_DETAILS_PATH}.output_stream_ids", CONST_DEPLOYMENT_MODE: "deployment_mode", CONST_MODEL_URI: f"{MODEL_CONFIG_DETAILS_PATH}.model_id", + CONST_MODEL_GROUP_ID: f"{MODEL_GROUP_CONFIG_DETAILS_PATH}.model_group_id", } def __init__(self, spec: Dict = None, **kwargs) -> None: @@ -278,6 +288,32 @@ def with_model_uri(self, model_uri: str) -> "ModelDeploymentRuntime": """ return self.set_spec(self.CONST_MODEL_URI, model_uri) + @property + def model_group_id(self) -> str: + """The model group id of model deployment. + + Returns + ------- + str + The model group id of model deployment. + """ + return self.get_spec(self.CONST_MODEL_GROUP_ID, None) + + def with_model_group_id(self, model_group_id: str) -> "ModelDeploymentRuntime": + """Sets the model group id of model deployment. + + Parameters + ---------- + model_group_id: str + The model group id of model deployment. + + Returns + ------- + ModelDeploymentRuntime + The ModelDeploymentRuntime instance (self). + """ + return self.set_spec(self.CONST_MODEL_GROUP_ID, model_group_id) + @property def bucket_uri(self) -> str: """The bucket uri of model. diff --git a/ads/model/model_metadata.py b/ads/model/model_metadata.py index 6b73b17f5..f0428ec9c 100644 --- a/ads/model/model_metadata.py +++ b/ads/model/model_metadata.py @@ -37,7 +37,7 @@ logger = logging.getLogger("ADS") METADATA_SIZE_LIMIT = 32000 -METADATA_VALUE_LENGTH_LIMIT = 255 +METADATA_VALUE_LENGTH_LIMIT = 16000 METADATA_DESCRIPTION_LENGTH_LIMIT = 255 _METADATA_EMPTY_VALUE = "NA" CURRENT_WORKING_DIR = "." diff --git a/ads/model/service/oci_datascience_model_group.py b/ads/model/service/oci_datascience_model_group.py new file mode 100644 index 000000000..b1fb72a3c --- /dev/null +++ b/ads/model/service/oci_datascience_model_group.py @@ -0,0 +1,488 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import logging +from functools import wraps +from typing import Callable + +import oci + +from ads.common.oci_datascience import OCIDataScienceMixin +from ads.common.work_request import DataScienceWorkRequest +from ads.model.deployment.common.utils import OCIClientManager, State + +try: + from oci.data_science.models import CreateModelGroupDetails, UpdateModelGroupDetails +except ModuleNotFoundError as err: + raise ModuleNotFoundError( + "The oci model group module was not found. Please run `pip install oci` " + "to install the latest oci sdk." + ) from err + +logger = logging.getLogger(__name__) + +DEFAULT_WAIT_TIME = 1200 +DEFAULT_POLL_INTERVAL = 10 +ALLOWED_STATUS = [ + State.ACTIVE.name, + State.CREATING.name, + State.DELETED.name, + State.DELETING.name, + State.FAILED.name, + State.INACTIVE.name, +] +MODEL_GROUP_NEEDS_TO_BE_CREATED = ( + "Missing model group id. Model group needs to be created before it can be accessed." +) + + +def check_for_model_group_id(msg: str = MODEL_GROUP_NEEDS_TO_BE_CREATED): + """The decorator helping to check if the ID attribute sepcified for a datascience model group. + + Parameters + ---------- + msg: str + The message that will be thrown. + + Raises + ------ + MissingModelGroupIdError + In case if the ID attribute not specified. + + Examples + -------- + >>> @check_for_id(msg="Some message.") + ... def test_function(self, name: str, last_name: str) + ... pass + """ + + def decorator(func: Callable): + @wraps(func) + def wrapper(self, *args, **kwargs): + if not self.id: + raise MissingModelGroupIdError(msg) + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + +class MissingModelGroupIdError(Exception): # pragma: no cover + pass + + +class OCIDataScienceModelGroup( + OCIDataScienceMixin, + oci.data_science.models.ModelGroup, +): + """Represents an OCI Data Science Model Group. + This class contains all attributes of the `oci.data_science.models.ModelGroup`. + The main purpose of this class is to link the `oci.data_science.models.ModelGroup` + and the related client methods. + Linking the `ModelGroup` (payload) to Create/Update/Delete/Activate/Deactivate methods. + + The `OCIDataScienceModelGroup` can be initialized by unpacking the properties stored in a dictionary: + + .. code-block:: python + + properties = { + "compartment_id": "", + "name": "", + "description": "", + } + ds_model_group = OCIDataScienceModelGroup(**properties) + + The properties can also be OCI REST API payload, in which the keys are in camel format. + + .. code-block:: python + + payload = { + "compartmentId": "", + "name": "", + "description": "", + } + ds_model_group = OCIDataScienceModelGroup(**payload) + + Methods + ------- + activate(self, ...) -> "OCIDataScienceModelGroup": + Activates datascience model group. + create(self, ...) -> "OCIDataScienceModelGroup" + Creates datascience model group. + deactivate(self, ...) -> "OCIDataScienceModelGroup": + Deactivates datascience model group. + delete(self, ...) -> "OCIDataScienceModelGroup": + Deletes datascience model group. + update(self, ...) -> "OCIDataScienceModelGroup": + Updates datascience model group. + list(self, ...) -> list[oci.data_science.models.ModelGroupSummary]: + List oci.data_science.models.ModelGroupSummary instances within given compartment. + from_id(cls, model_group: str) -> "OCIDataScienceModelGroup": + Gets model group by OCID. + + Examples + -------- + >>> oci_model_group = OCIDataScienceModelGroup.from_id() + >>> oci_model_group.deactivate() + >>> oci_model_group.activate(wait_for_completion=False) + >>> oci_model_group.description = "A brand new description" + ... oci_model_group.update() + >>> oci_model_group.sync() + >>> oci_model_group.list(status="ACTIVE") + >>> oci_model_group.delete(wait_for_completion=False) + """ + + def __init__(self, config=None, signer=None, client_kwargs=None, **kwargs): + super().__init__(config, signer, client_kwargs, **kwargs) + self.workflow_req_id = None + + def create( + self, + create_model_group_details: CreateModelGroupDetails, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "OCIDataScienceModelGroup": + """Creates datascience model group. + + Parameters + ---------- + create_model_group_details: CreateModelGroupDetails + An instance of CreateModelGroupDetails which consists of all + necessary parameters to create a data science model group. + wait_for_completion: bool + Flag set for whether to wait for process to be completed. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + OCIDataScienceModelGroup + The `OCIDataScienceModelGroup` instance (self). + """ + response = self.client.create_model_group(create_model_group_details) + self.update_from_oci_model(response.data) + logger.info(f"Creating model group `{self.id}`.") + print(f"Model Group OCID: {self.id}") + + if wait_for_completion: + self.workflow_req_id = response.headers.get("opc-work-request-id", None) + + try: + DataScienceWorkRequest(self.workflow_req_id).wait_work_request( + progress_bar_description="Creating model group", + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + except Exception as e: + logger.error("Error while trying to create model group: " + str(e)) + + return self.sync() + + @check_for_model_group_id( + msg="Model group needs to be created before it can be activated or deactivated.." + ) + def activate( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "OCIDataScienceModelGroup": + """Activates datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for process to be completed. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + OCIDataScienceModelGroup + The `OCIDataScienceModelGroup` instance (self). + """ + dsc_model_group = OCIDataScienceModelGroup.from_id(self.id) + if dsc_model_group.lifecycle_state == self.LIFECYCLE_STATE_ACTIVE: + raise Exception( + f"Model group {dsc_model_group.id} is already in active state." + ) + + if dsc_model_group.lifecycle_state == self.LIFECYCLE_STATE_INACTIVE: + logger.info(f"Activating model group `{self.id}`.") + response = self.client.activate_model_group( + self.id, + ) + + if wait_for_completion: + self.workflow_req_id = response.headers.get("opc-work-request-id", None) + + try: + DataScienceWorkRequest(self.workflow_req_id).wait_work_request( + progress_bar_description="Activating model group", + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + except Exception as e: + logger.error( + "Error while trying to activate model group: " + str(e) + ) + + return self.sync() + else: + raise Exception( + f"Can't activate model group {dsc_model_group.id} when it's in {dsc_model_group.lifecycle_state} state." + ) + + @check_for_model_group_id( + msg="Model group needs to be created before it can be activated or deactivated.." + ) + def deactivate( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "OCIDataScienceModelGroup": + """Deactivates datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for process to be completed. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + OCIDataScienceModelGroup + The `OCIDataScienceModelGroup` instance (self). + """ + dsc_model_group = self.from_id(self.id) + if dsc_model_group.lifecycle_state == self.LIFECYCLE_STATE_INACTIVE: + raise Exception( + f"Model group {dsc_model_group.id} is already in inactive state." + ) + + if dsc_model_group.lifecycle_state == self.LIFECYCLE_STATE_ACTIVE: + logger.info(f"Deactivating model group `{self.id}`.") + response = self.client.deactivate_model_group( + self.id, + ) + + if wait_for_completion: + self.workflow_req_id = response.headers.get("opc-work-request-id", None) + + try: + DataScienceWorkRequest(self.workflow_req_id).wait_work_request( + progress_bar_description="Deactivating model group", + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + except Exception as e: + logger.error( + "Error while trying to deactivate model group: " + str(e) + ) + + return self.sync() + else: + raise Exception( + f"Can't deactivate model group {dsc_model_group.id} when it's in {dsc_model_group.lifecycle_state} state." + ) + + @check_for_model_group_id( + msg="Model group needs to be created before it can be deleted." + ) + def delete( + self, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "OCIDataScienceModelGroup": + """Deletes datascience model group. + + Parameters + ---------- + wait_for_completion: bool + Flag set for whether to wait for process to be completed. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + OCIDataScienceModelGroup + The `OCIDataScienceModelGroup` instance (self). + """ + dsc_model_group = self.from_id(self.id) + if dsc_model_group.lifecycle_state in [ + self.LIFECYCLE_STATE_DELETED, + self.LIFECYCLE_STATE_DELETING, + ]: + raise Exception( + f"Model group {dsc_model_group.id} is either deleted or being deleted." + ) + if dsc_model_group.lifecycle_state not in [ + self.LIFECYCLE_STATE_ACTIVE, + self.LIFECYCLE_STATE_FAILED, + self.LIFECYCLE_STATE_INACTIVE, + ]: + raise Exception( + f"Can't delete model group {dsc_model_group.id} when it's in {dsc_model_group.lifecycle_state} state." + ) + logger.info(f"Deleting model group `{self.id}`.") + response = self.client.delete_model_group( + self.id, + ) + + if wait_for_completion: + self.workflow_req_id = response.headers.get("opc-work-request-id", None) + + try: + DataScienceWorkRequest(self.workflow_req_id).wait_work_request( + progress_bar_description="Deleting model group", + max_wait_time=max_wait_time, + poll_interval=poll_interval, + ) + except Exception as e: + logger.error("Error while trying to delete model group: " + str(e)) + + return self.sync() + + @check_for_model_group_id( + msg="Model group needs to be created before it can be updated." + ) + def update( + self, + update_model_group_details: UpdateModelGroupDetails, + wait_for_completion: bool = True, + max_wait_time: int = DEFAULT_WAIT_TIME, + poll_interval: int = DEFAULT_POLL_INTERVAL, + ) -> "OCIDataScienceModelGroup": + """Updates datascience model group. + + Parameters + ---------- + update_model_group_details: UpdateModelGroupDetails + Details to update model group. + wait_for_completion: bool + Flag set for whether to wait for process to be completed. + Defaults to True. + max_wait_time: int + Maximum amount of time to wait in seconds (Defaults to 1200). + Negative implies infinite wait time. + poll_interval: int + Poll interval in seconds (Defaults to 10). + + Returns + ------- + OCIDataScienceModelGroup + The `OCIDataScienceModelGroup` instance (self). + """ + if wait_for_completion: + wait_for_states = [ + self.LIFECYCLE_STATE_ACTIVE, + self.LIFECYCLE_STATE_FAILED, + ] + else: + wait_for_states = [] + + try: + response = self.client_composite.update_model_group_and_wait_for_state( + self.id, + update_model_group_details, + wait_for_states=wait_for_states, + waiter_kwargs={ + "max_interval_seconds": poll_interval, + "max_wait_seconds": max_wait_time, + }, + ) + self.workflow_req_id = response.headers.get("opc-work-request-id", None) + except Exception as e: + logger.error("Error while trying to update model group: " + str(e)) + + return self.sync() + + @classmethod + def list( + cls, + status: str = None, + compartment_id: str = None, + **kwargs, + ) -> list: + """Lists the model group associated with current compartment id and status + + Parameters + ---------- + status : str + Status of model group. Defaults to None. + Allowed values: `ACTIVE`, `CREATING`, `DELETED`, `DELETING`, `FAILED` and `INACTIVE`. + compartment_id : str + Target compartment to list model groups from. + Defaults to the compartment set in the environment variable "NB_SESSION_COMPARTMENT_OCID". + If "NB_SESSION_COMPARTMENT_OCID" is not set, the root compartment ID will be used. + An ValueError will be raised if root compartment ID cannot be determined. + kwargs : + The values are passed to oci.data_science.DataScienceClient.list_model_groups. + + Returns + ------- + list + A list of oci.data_science.models.ModelGroupSummary objects. + + Raises + ------ + ValueError + If compartment_id is not specified and cannot be determined from the environment. + """ + compartment_id = compartment_id or OCIClientManager().default_compartment_id() + + if not compartment_id: + raise ValueError( + "Unable to determine compartment ID from environment. Specify `compartment_id`." + ) + + if status is not None: + if status not in ALLOWED_STATUS: + raise ValueError( + f"Allowed `status` values are: {', '.join(ALLOWED_STATUS)}." + ) + kwargs["lifecycle_state"] = status + + # https://oracle-cloud-infrastructure-python-sdk.readthedocs.io/en/latest/api/pagination.html#module-oci.pagination + return oci.pagination.list_call_get_all_results( + cls().client.list_model_groups, compartment_id, **kwargs + ).data + + @classmethod + def from_id(cls, model_group_id: str) -> "OCIDataScienceModelGroup": + """Gets datascience model group by OCID. + + Parameters + ---------- + model_group_id: str + The OCID of the datascience model group. + + Returns + ------- + OCIDataScienceModelGroup + An instance of `OCIDataScienceModelGroup`. + """ + return super().from_ocid(model_group_id) diff --git a/tests/unitary/default_setup/model/test_model_group.py b/tests/unitary/default_setup/model/test_model_group.py new file mode 100644 index 000000000..6b7ac2358 --- /dev/null +++ b/tests/unitary/default_setup/model/test_model_group.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import copy +import unittest +from unittest.mock import patch +from ads.model.datascience_model_group import DataScienceModelGroup +from ads.model.model_metadata import ModelCustomMetadata + +try: + from oci.data_science.models import ( + ModelGroup, + HomogeneousModelGroupDetails, + MemberModelEntries, + CustomMetadata, + MemberModelDetails, + ModelGroupSummary, + ) +except (ImportError, AttributeError) as e: + raise unittest.SkipTest( + "Support for OCI Model Group is not available. Skipping the Model Group tests." + ) + +MODEL_GROUP_DICT = { + "kind": "datascienceModelGroup", + "type": "dataScienceModelGroup", + "spec": { + "displayName": "test_create_model_group", + "description": "test create model group description", + "freeformTags": {"test_key": "test_value"}, + "customMetadataList": { + "data": [ + { + "key": "test_key", + "value": "test_value", + "description": "test_description", + "category": "other", + "has_artifact": False, + } + ] + }, + "memberModels": [ + {"inference_key": "model_one", "model_id": "model_id_one"}, + {"inference_key": "model_two", "model_id": "model_id_two"}, + ], + }, +} + +MODEL_GROUP_SPEC = { + "display_name": "test_create_model_group", + "description": "test create model group description", + "freeform_tags": {"test_key": "test_value"}, + "custom_metadata_list": { + "data": [ + { + "key": "test_key", + "value": "test_value", + "description": "test_description", + "category": "other", + "has_artifact": False, + } + ] + }, + "member_models": [ + {"inference_key": "model_one", "model_id": "model_id_one"}, + {"inference_key": "model_two", "model_id": "model_id_two"}, + ], +} + +OCI_MODEL_GROUP_RESPONSE = { + "id": "test_model_group_id", + "compartment_id": "test_model_group_compartment_id", + "project_id": "test_model_group_project_id", + "display_name": "test_create_model_group", + "description": "test create model group description", + "created_by": "test_create_by", + "time_created": "2025-06-10T18:21:17.613000Z", + "time_updated": "2025-06-10T18:21:17.613000Z", + "lifecycle_state": "ACTIVE", + "lifecycle_details": "test lifecycle details", + "model_group_version_history_id": "test_model_group_version_history_id", + "model_group_version_history_name": "test_model_group_version_history_name", + "version_label": "test_version_label", + "version_id": 1, + "model_group_details": HomogeneousModelGroupDetails( + custom_metadata_list=[ + CustomMetadata( + key="test_key", + value="test_value", + description="test_description", + category="other", + ) + ] + ), + "member_model_entries": MemberModelEntries( + member_model_details=[ + MemberModelDetails(inference_key="model_one", model_id="model_id_one"), + MemberModelDetails(inference_key="model_two", model_id="model_id_two"), + ] + ), + "freeform_tags": {"test_key": "test_value"}, +} + + +class TestModelGroup: + def initialize_model_group(self): + custom_metadata = ModelCustomMetadata() + custom_metadata.add( + key="test_key", + value="test_value", + description="test_description", + category="other", + ) + + model_group = ( + DataScienceModelGroup() + .with_display_name("test_create_model_group") + .with_description("test create model group description") + .with_freeform_tags(**{"test_key": "test_value"}) + .with_custom_metadata_list(custom_metadata) + .with_member_models( + [ + {"inference_key": "model_one", "model_id": "model_id_one"}, + {"inference_key": "model_two", "model_id": "model_id_two"}, + ] + ) + ) + + return model_group + + def test_initialize_model_group(self): + model_group_one = self.initialize_model_group() + assert model_group_one.to_dict() == MODEL_GROUP_DICT + + model_group_two = DataScienceModelGroup.from_dict(MODEL_GROUP_DICT) + assert model_group_two.to_dict() == MODEL_GROUP_DICT + + model_group_three = DataScienceModelGroup(spec=MODEL_GROUP_SPEC) + assert model_group_three.to_dict() == MODEL_GROUP_DICT + + model_group_four = DataScienceModelGroup(**MODEL_GROUP_SPEC) + assert model_group_four.to_dict() == MODEL_GROUP_DICT + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.create" + ) + def test_create(self, mock_dsc_model_group_create): + mock_dsc_model_group_create.return_value = ModelGroup( + **OCI_MODEL_GROUP_RESPONSE + ) + model_group = self.initialize_model_group() + model_group.create() + + mock_dsc_model_group_create.assert_called() + + assert model_group.id == OCI_MODEL_GROUP_RESPONSE["id"] + assert model_group.display_name == OCI_MODEL_GROUP_RESPONSE["display_name"] + assert model_group.description == OCI_MODEL_GROUP_RESPONSE["description"] + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.activate" + ) + def test_activate(self, mock_dsc_model_group_activate): + mock_dsc_model_group_activate.return_value = ModelGroup( + **OCI_MODEL_GROUP_RESPONSE + ) + model_group = self.initialize_model_group() + model_group.activate( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + mock_dsc_model_group_activate.assert_called_with( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + assert model_group.lifecycle_state == "ACTIVE" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.deactivate" + ) + def test_deactivate(self, mock_dsc_model_group_deactivate): + mock_dsc_model_group_deactivate_response = copy.deepcopy( + OCI_MODEL_GROUP_RESPONSE + ) + mock_dsc_model_group_deactivate_response["lifecycle_state"] = "INACTIVE" + + mock_dsc_model_group_deactivate.return_value = ModelGroup( + **mock_dsc_model_group_deactivate_response + ) + model_group = self.initialize_model_group() + model_group.deactivate( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + mock_dsc_model_group_deactivate.assert_called_with( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + assert model_group.lifecycle_state == "INACTIVE" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.delete" + ) + def test_delete(self, mock_dsc_model_group_delete): + mock_dsc_model_group_delete_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_dsc_model_group_delete_response["lifecycle_state"] = "DELETED" + + mock_dsc_model_group_delete.return_value = ModelGroup( + **mock_dsc_model_group_delete_response + ) + model_group = self.initialize_model_group() + model_group.delete( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + mock_dsc_model_group_delete.assert_called_with( + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + assert model_group.lifecycle_state == "DELETED" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.update" + ) + def test_update(self, mock_dsc_model_group_update): + mock_dsc_model_group_update_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_dsc_model_group_update_response["display_name"] = "updated display name" + mock_dsc_model_group_update_response["description"] = "updated description" + + mock_dsc_model_group_update.return_value = ModelGroup( + **mock_dsc_model_group_update_response + ) + model_group = self.initialize_model_group() + model_group.update() + + mock_dsc_model_group_update.assert_called() + + assert ( + model_group.display_name + == mock_dsc_model_group_update_response["display_name"] + ) + assert ( + model_group.description + == mock_dsc_model_group_update_response["description"] + ) + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.list" + ) + def test_list(self, mock_dsc_model_group_list): + mock_dsc_model_group_list_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_dsc_model_group_list_response.pop("member_model_entries") + mock_dsc_model_group_list_response.pop("description") + mock_dsc_model_group_list.return_value = [ + ModelGroupSummary(**mock_dsc_model_group_list_response) + ] + + model_groups = DataScienceModelGroup.list( + status="ACTIVE", compartment_id="test_model_group_compartment_id" + ) + + mock_dsc_model_group_list.assert_called_with( + status="ACTIVE", compartment_id="test_model_group_compartment_id" + ) + + assert len(model_groups) == 1 diff --git a/tests/unitary/default_setup/model/test_oci_model_group.py b/tests/unitary/default_setup/model/test_oci_model_group.py new file mode 100644 index 000000000..90c370609 --- /dev/null +++ b/tests/unitary/default_setup/model/test_oci_model_group.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python + +# Copyright (c) 2025 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import copy +import unittest +from unittest.mock import MagicMock, patch +from ads.model.service.oci_datascience_model_group import OCIDataScienceModelGroup + +try: + from oci.data_science.models import ( + ModelGroup, + HomogeneousModelGroupDetails, + MemberModelEntries, + CustomMetadata, + MemberModelDetails, + CreateModelGroupDetails, + UpdateModelGroupDetails, + ) +except (ImportError, AttributeError) as e: + raise unittest.SkipTest( + "Support for OCI Model Group is not available. Skipping the Model Group tests." + ) + +CREATE_MODEL_GROUP_DETAILS = { + "create_type": "CREATE", + "compartment_id": "test_model_group_compartment_id", + "project_id": "test_model_group_project_id", + "display_name": "test_create_model_group", + "description": "test create model group description", + "model_group_details": HomogeneousModelGroupDetails( + custom_metadata_list=[ + CustomMetadata( + key="test_key", + value="test_value", + description="test_description", + category="other", + ) + ] + ), + "member_model_entries": MemberModelEntries( + member_model_details=[ + MemberModelDetails(inference_key="model_one", model_id="model_id_one"), + MemberModelDetails(inference_key="model_two", model_id="model_id_two"), + ] + ), + "freeform_tags": {"test_key": "test_value"}, + "model_group_version_history_id": "test_model_group_version_history_id", + "version_label": "test_version_label", +} + +UPDATE_MODEL_GROUP_DETAILS = { + "display_name": "test_update_model_group", + "description": "test update model group description", + "model_group_version_history_id": "test_model_group_version_history_id", + "version_label": "test_version_label", + "freeform_tags": {"test_updated_key": "test_updated_value"}, +} + +OCI_MODEL_GROUP_RESPONSE = { + "id": "test_model_group_id", + "compartment_id": "test_model_group_compartment_id", + "project_id": "test_model_group_project_id", + "display_name": "test_create_model_group", + "description": "test create model group description", + "created_by": "test_create_by", + "time_created": "2025-06-10T18:21:17.613000Z", + "time_updated": "2025-06-10T18:21:17.613000Z", + "lifecycle_state": "ACTIVE", + "lifecycle_details": "test lifecycle details", + "model_group_version_history_id": "test_model_group_version_history_id", + "model_group_version_history_name": "test_model_group_version_history_name", + "version_label": "test_version_label", + "version_id": 1, + "model_group_details": HomogeneousModelGroupDetails( + custom_metadata_list=[ + CustomMetadata( + key="test_key", + value="test_value", + description="test_description", + category="other", + ) + ] + ), + "member_model_entries": MemberModelEntries( + member_model_details=[ + MemberModelDetails(inference_key="model_one", model_id="model_id_one"), + MemberModelDetails(inference_key="model_two", model_id="model_id_two"), + ] + ), + "freeform_tags": {"test_key": "test_value"}, +} + + +class TestOCIModelGroup: + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.sync" + ) + @patch("oci.data_science.DataScienceClient.create_model_group") + def test_create(self, mock_create_model_group, mock_sync): + mock_sync.return_value = ModelGroup(**OCI_MODEL_GROUP_RESPONSE) + create_model_group_details = CreateModelGroupDetails( + **CREATE_MODEL_GROUP_DETAILS + ) + oci_model_group = OCIDataScienceModelGroup().create( + create_model_group_details=create_model_group_details, + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + mock_create_model_group.assert_called_with(create_model_group_details) + + assert oci_model_group.id == OCI_MODEL_GROUP_RESPONSE["id"] + assert oci_model_group.display_name == OCI_MODEL_GROUP_RESPONSE["display_name"] + assert oci_model_group.description == OCI_MODEL_GROUP_RESPONSE["description"] + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.sync" + ) + @patch("oci.data_science.DataScienceClient.activate_model_group") + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.from_id" + ) + def test_activate(self, mock_from_id, mock_activate_model_group, mock_sync): + mock_oci_model_group_activate_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_oci_model_group_activate_response["lifecycle_state"] = "INACTIVE" + mock_from_id.return_value = ModelGroup(**mock_oci_model_group_activate_response) + mock_sync.return_value = ModelGroup(**OCI_MODEL_GROUP_RESPONSE) + oci_model_group = OCIDataScienceModelGroup(**OCI_MODEL_GROUP_RESPONSE).activate( + wait_for_completion=False, max_wait_time=1, poll_interval=2 + ) + + mock_activate_model_group.assert_called_with(oci_model_group.id) + assert oci_model_group.lifecycle_state == "ACTIVE" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.sync" + ) + @patch("oci.data_science.DataScienceClient.deactivate_model_group") + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.from_id" + ) + def test_deactivate(self, mock_from_id, mock_deactivate_model_group, mock_sync): + mock_oci_model_group_deactivate_response = copy.deepcopy( + OCI_MODEL_GROUP_RESPONSE + ) + mock_oci_model_group_deactivate_response["lifecycle_state"] = "INACTIVE" + mock_from_id.return_value = ModelGroup(**OCI_MODEL_GROUP_RESPONSE) + mock_sync.return_value = ModelGroup(**mock_oci_model_group_deactivate_response) + oci_model_group = OCIDataScienceModelGroup( + **mock_oci_model_group_deactivate_response + ).deactivate(wait_for_completion=False, max_wait_time=1, poll_interval=2) + + mock_deactivate_model_group.assert_called_with(oci_model_group.id) + assert oci_model_group.lifecycle_state == "INACTIVE" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.sync" + ) + @patch("oci.data_science.DataScienceClient.delete_model_group") + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.from_id" + ) + def test_delete(self, mock_from_id, mock_delete_model_group, mock_sync): + mock_oci_model_group_delete_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_oci_model_group_delete_response["lifecycle_state"] = "DELETED" + mock_from_id.return_value = ModelGroup(**OCI_MODEL_GROUP_RESPONSE) + mock_sync.return_value = ModelGroup(**mock_oci_model_group_delete_response) + + oci_model_group = OCIDataScienceModelGroup(**OCI_MODEL_GROUP_RESPONSE).delete( + wait_for_completion=False, max_wait_time=1, poll_interval=2 + ) + + mock_delete_model_group.assert_called_with(oci_model_group.id) + assert oci_model_group.lifecycle_state == "DELETED" + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.sync" + ) + @patch( + "oci.data_science.DataScienceClientCompositeOperations.update_model_group_and_wait_for_state" + ) + def test_update(self, mock_update_model_group, mock_sync): + mock_oci_model_group_update_response = copy.deepcopy(OCI_MODEL_GROUP_RESPONSE) + mock_oci_model_group_update_response.update(**UPDATE_MODEL_GROUP_DETAILS) + mock_sync.return_value = ModelGroup(**mock_oci_model_group_update_response) + update_model_group_details = UpdateModelGroupDetails( + **UPDATE_MODEL_GROUP_DETAILS + ) + oci_model_group = OCIDataScienceModelGroup(**OCI_MODEL_GROUP_RESPONSE).update( + update_model_group_details=update_model_group_details, + wait_for_completion=False, + max_wait_time=1, + poll_interval=2, + ) + + mock_update_model_group.assert_called_with( + oci_model_group.id, + update_model_group_details, + wait_for_states=[], + waiter_kwargs={ + "max_interval_seconds": 2, + "max_wait_seconds": 1, + }, + ) + + assert oci_model_group.id == mock_oci_model_group_update_response["id"] + assert ( + oci_model_group.display_name + == mock_oci_model_group_update_response["display_name"] + ) + assert ( + oci_model_group.description + == mock_oci_model_group_update_response["description"] + ) + assert ( + oci_model_group.freeform_tags + == mock_oci_model_group_update_response["freeform_tags"] + ) + + @patch( + "ads.model.service.oci_datascience_model_group.OCIDataScienceModelGroup.from_id" + ) + def test_from_id(self, mock_from_id): + OCIDataScienceModelGroup.from_id(OCI_MODEL_GROUP_RESPONSE["id"]) + mock_from_id.assert_called_with(OCI_MODEL_GROUP_RESPONSE["id"]) + + @patch("oci.pagination.list_call_get_all_results") + def test_list(self, mock_list_call_get_all_results): + response = MagicMock() + response.data = [MagicMock()] + mock_list_call_get_all_results.return_value = response + model_groups = OCIDataScienceModelGroup.list( + status="ACTIVE", + compartment_id="test_compartment_id", + ) + mock_list_call_get_all_results.assert_called() + assert isinstance(model_groups, list) diff --git a/tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py b/tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py index 589c58d70..2a6b19c66 100644 --- a/tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py +++ b/tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*-- -# Copyright (c) 2023 Oracle and/or its affiliates. +# Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import copy @@ -370,6 +370,7 @@ def test__load_default_properties(self, mock_from_ocid): ModelDeploymentInfrastructure.CONST_SHAPE_NAME: infrastructure.shape_name, ModelDeploymentInfrastructure.CONST_BANDWIDTH_MBPS: 10, ModelDeploymentInfrastructure.CONST_SHAPE_CONFIG_DETAILS: { + "cpu_baseline": None, "ocpus": 10.0, "memory_in_gbs": 36.0, }, @@ -589,151 +590,6 @@ def test_build_category_log_details(self): }, } - @patch.object(DataScienceModel, "create") - def test_build_model_deployment_details(self, mock_create): - dsc_model = MagicMock() - dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx" - mock_create.return_value = dsc_model - model_deployment = self.initialize_model_deployment() - create_model_deployment_details = ( - model_deployment._build_model_deployment_details() - ) - - mock_create.assert_called() - - assert isinstance( - create_model_deployment_details, - CreateModelDeploymentDetails, - ) - assert ( - create_model_deployment_details.display_name - == model_deployment.display_name - ) - assert ( - create_model_deployment_details.description == model_deployment.description - ) - assert ( - create_model_deployment_details.freeform_tags - == model_deployment.freeform_tags - ) - assert ( - create_model_deployment_details.defined_tags - == model_deployment.defined_tags - ) - - category_log_details = create_model_deployment_details.category_log_details - assert isinstance(category_log_details, CategoryLogDetails) - assert ( - category_log_details.access.log_id - == model_deployment.infrastructure.access_log["logId"] - ) - assert ( - category_log_details.access.log_group_id - == model_deployment.infrastructure.access_log["logGroupId"] - ) - assert ( - category_log_details.predict.log_id - == model_deployment.infrastructure.predict_log["logId"] - ) - assert ( - category_log_details.predict.log_group_id - == model_deployment.infrastructure.predict_log["logGroupId"] - ) - - model_deployment_configuration_details = ( - create_model_deployment_details.model_deployment_configuration_details - ) - assert isinstance( - model_deployment_configuration_details, - SingleModelDeploymentConfigurationDetails, - ) - assert model_deployment_configuration_details.deployment_type == "SINGLE_MODEL" - - environment_configuration_details = ( - model_deployment_configuration_details.environment_configuration_details - ) - assert isinstance( - environment_configuration_details, - OcirModelDeploymentEnvironmentConfigurationDetails, - ) - assert ( - environment_configuration_details.environment_configuration_type - == "OCIR_CONTAINER" - ) - assert ( - environment_configuration_details.environment_variables - == model_deployment.runtime.env - ) - assert environment_configuration_details.cmd == model_deployment.runtime.cmd - assert environment_configuration_details.image == model_deployment.runtime.image - assert ( - environment_configuration_details.image_digest - == model_deployment.runtime.image_digest - ) - assert ( - environment_configuration_details.entrypoint - == model_deployment.runtime.entrypoint - ) - assert ( - environment_configuration_details.server_port - == model_deployment.runtime.server_port - ) - assert ( - environment_configuration_details.health_check_port - == model_deployment.runtime.health_check_port - ) - - model_configuration_details = ( - model_deployment_configuration_details.model_configuration_details - ) - assert isinstance( - model_configuration_details, - ModelConfigurationDetails, - ) - assert ( - model_configuration_details.bandwidth_mbps - == model_deployment.infrastructure.bandwidth_mbps - ) - assert ( - model_configuration_details.model_id == model_deployment.runtime.model_uri - ) - - instance_configuration = model_configuration_details.instance_configuration - assert isinstance(instance_configuration, InstanceConfiguration) - assert ( - instance_configuration.instance_shape_name - == model_deployment.infrastructure.shape_name - ) - assert ( - instance_configuration.model_deployment_instance_shape_config_details.ocpus - == model_deployment.infrastructure.shape_config_details["ocpus"] - ) - assert ( - instance_configuration.model_deployment_instance_shape_config_details.memory_in_gbs - == model_deployment.infrastructure.shape_config_details["memoryInGBs"] - ) - - scaling_policy = model_configuration_details.scaling_policy - assert isinstance(scaling_policy, FixedSizeScalingPolicy) - assert scaling_policy.policy_type == "FIXED_SIZE" - assert scaling_policy.instance_count == model_deployment.infrastructure.replica - - # stream_configuration_details = ( - # model_deployment_configuration_details.stream_configuration_details - # ) - # assert isinstance( - # stream_configuration_details, - # StreamConfigurationDetails, - # ) - # assert ( - # stream_configuration_details.input_stream_ids - # == model_deployment.runtime.input_stream_ids - # ) - # assert ( - # stream_configuration_details.output_stream_ids - # == model_deployment.runtime.output_stream_ids - # ) - def test_update_from_oci_model(self): model_deployment = self.initialize_model_deployment() model_deployment_from_oci = model_deployment._update_from_oci_model( @@ -882,151 +738,6 @@ def test_model_deployment_from_dict(self): assert new_model_deployment.to_dict() == model_deployment.to_dict() - @patch.object(DataScienceModel, "create") - def test_update_model_deployment_details(self, mock_create): - dsc_model = MagicMock() - dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx" - mock_create.return_value = dsc_model - model_deployment = self.initialize_model_deployment() - update_model_deployment_details = ( - model_deployment._update_model_deployment_details() - ) - - mock_create.assert_called() - - assert isinstance( - update_model_deployment_details, - UpdateModelDeploymentDetails, - ) - assert ( - update_model_deployment_details.display_name - == model_deployment.display_name - ) - assert ( - update_model_deployment_details.description == model_deployment.description - ) - assert ( - update_model_deployment_details.freeform_tags - == model_deployment.freeform_tags - ) - assert ( - update_model_deployment_details.defined_tags - == model_deployment.defined_tags - ) - - category_log_details = update_model_deployment_details.category_log_details - assert isinstance(category_log_details, UpdateCategoryLogDetails) - assert ( - category_log_details.access.log_id - == model_deployment.infrastructure.access_log["logId"] - ) - assert ( - category_log_details.access.log_group_id - == model_deployment.infrastructure.access_log["logGroupId"] - ) - assert ( - category_log_details.predict.log_id - == model_deployment.infrastructure.predict_log["logId"] - ) - assert ( - category_log_details.predict.log_group_id - == model_deployment.infrastructure.predict_log["logGroupId"] - ) - - model_deployment_configuration_details = ( - update_model_deployment_details.model_deployment_configuration_details - ) - assert isinstance( - model_deployment_configuration_details, - UpdateSingleModelDeploymentConfigurationDetails, - ) - assert model_deployment_configuration_details.deployment_type == "SINGLE_MODEL" - - environment_configuration_details = ( - model_deployment_configuration_details.environment_configuration_details - ) - assert isinstance( - environment_configuration_details, - UpdateOcirModelDeploymentEnvironmentConfigurationDetails, - ) - assert ( - environment_configuration_details.environment_configuration_type - == "OCIR_CONTAINER" - ) - assert ( - environment_configuration_details.environment_variables - == model_deployment.runtime.env - ) - assert environment_configuration_details.cmd == model_deployment.runtime.cmd - assert environment_configuration_details.image == model_deployment.runtime.image - assert ( - environment_configuration_details.image_digest - == model_deployment.runtime.image_digest - ) - assert ( - environment_configuration_details.entrypoint - == model_deployment.runtime.entrypoint - ) - assert ( - environment_configuration_details.server_port - == model_deployment.runtime.server_port - ) - assert ( - environment_configuration_details.health_check_port - == model_deployment.runtime.health_check_port - ) - - model_configuration_details = ( - model_deployment_configuration_details.model_configuration_details - ) - assert isinstance( - model_configuration_details, - UpdateModelConfigurationDetails, - ) - assert ( - model_configuration_details.bandwidth_mbps - == model_deployment.infrastructure.bandwidth_mbps - ) - assert ( - model_configuration_details.model_id == model_deployment.runtime.model_uri - ) - - instance_configuration = model_configuration_details.instance_configuration - assert isinstance(instance_configuration, InstanceConfiguration) - assert ( - instance_configuration.instance_shape_name - == model_deployment.infrastructure.shape_name - ) - assert ( - instance_configuration.model_deployment_instance_shape_config_details.ocpus - == model_deployment.infrastructure.shape_config_details["ocpus"] - ) - assert ( - instance_configuration.model_deployment_instance_shape_config_details.memory_in_gbs - == model_deployment.infrastructure.shape_config_details["memoryInGBs"] - ) - - scaling_policy = model_configuration_details.scaling_policy - assert isinstance(scaling_policy, FixedSizeScalingPolicy) - assert scaling_policy.policy_type == "FIXED_SIZE" - assert scaling_policy.instance_count == model_deployment.infrastructure.replica - - # stream_configuration_details = ( - # model_deployment_configuration_details.stream_configuration_details - # ) - # assert isinstance( - # stream_configuration_details, - # UpdateStreamConfigurationDetails, - # ) - # assert ( - # stream_configuration_details.input_stream_ids - # == model_deployment.runtime.input_stream_ids - # ) - # assert ( - # stream_configuration_details.output_stream_ids - # == model_deployment.runtime.output_stream_ids - # ) - @patch.object( ModelDeploymentInfrastructure, "_load_default_properties", return_value={} ) @@ -1127,9 +838,7 @@ def test_from_ocid(self, mock_from_ocid): "create_model_deployment", ) @patch.object(DataScienceModel, "create") - def test_deploy( - self, mock_create, mock_create_model_deployment, mock_sync - ): + def test_deploy(self, mock_create, mock_create_model_deployment, mock_sync): dsc_model = MagicMock() dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx" mock_create.return_value = dsc_model @@ -1346,44 +1055,35 @@ def test_update_spec(self): model_deployment = self.initialize_model_deployment() model_deployment._update_spec( display_name="test_updated_name", - freeform_tags={"test_updated_key":"test_updated_value"}, - access_log={ - "log_id": "test_updated_access_log_id" - }, - predict_log={ - "log_group_id": "test_updated_predict_log_group_id" - }, - shape_config_details={ - "ocpus": 100, - "memoryInGBs": 200 - }, + freeform_tags={"test_updated_key": "test_updated_value"}, + access_log={"log_id": "test_updated_access_log_id"}, + predict_log={"log_group_id": "test_updated_predict_log_group_id"}, + shape_config_details={"ocpus": 100, "memoryInGBs": 200}, replica=20, image="test_updated_image", - env={ - "test_updated_env_key":"test_updated_env_value" - } + env={"test_updated_env_key": "test_updated_env_value"}, ) assert model_deployment.display_name == "test_updated_name" assert model_deployment.freeform_tags == { - "test_updated_key":"test_updated_value" + "test_updated_key": "test_updated_value" } assert model_deployment.infrastructure.access_log == { "logId": "test_updated_access_log_id", - "logGroupId": "fakeid.loggroup.oc1.iad.xxx" + "logGroupId": "fakeid.loggroup.oc1.iad.xxx", } assert model_deployment.infrastructure.predict_log == { "logId": "fakeid.log.oc1.iad.xxx", - "logGroupId": "test_updated_predict_log_group_id" + "logGroupId": "test_updated_predict_log_group_id", } assert model_deployment.infrastructure.shape_config_details == { "ocpus": 100, - "memoryInGBs": 200 + "memoryInGBs": 200, } assert model_deployment.infrastructure.replica == 20 assert model_deployment.runtime.image == "test_updated_image" assert model_deployment.runtime.env == { - "test_updated_env_key":"test_updated_env_value" + "test_updated_env_key": "test_updated_env_value" } @patch.object(OCIDataScienceMixin, "sync") @@ -1393,18 +1093,14 @@ def test_update_spec(self): ) @patch.object(DataScienceModel, "create") def test_model_deployment_with_large_size_artifact( - self, - mock_create, - mock_create_model_deployment, - mock_sync + self, mock_create, mock_create_model_deployment, mock_sync ): dsc_model = MagicMock() dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx" mock_create.return_value = dsc_model model_deployment = self.initialize_model_deployment() ( - model_deployment.runtime - .with_auth({"test_key":"test_value"}) + model_deployment.runtime.with_auth({"test_key": "test_value"}) .with_region("test_region") .with_overwrite_existing_artifact(True) .with_remove_existing_artifact(True) @@ -1425,18 +1121,18 @@ def test_model_deployment_with_large_size_artifact( mock_create_model_deployment.return_value = response model_deployment = self.initialize_model_deployment() model_deployment.set_spec(model_deployment.CONST_ID, "test_model_deployment_id") - + create_model_deployment_details = ( model_deployment._build_model_deployment_details() ) model_deployment.deploy(wait_for_completion=False) mock_create.assert_called_with( bucket_uri="test_bucket_uri", - auth={"test_key":"test_value"}, + auth={"test_key": "test_value"}, region="test_region", overwrite_existing_artifact=True, remove_existing_artifact=True, - timeout=100 + timeout=100, ) mock_create_model_deployment.assert_called_with(create_model_deployment_details) mock_sync.assert_called() diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index 2e6329fa8..5d71cbc27 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -23,7 +23,6 @@ ModelDeployWorkloadConfigurationDetails, ) from parameterized import parameterized -from pydantic import ValidationError import ads.aqua.modeldeployment.deployment import ads.config @@ -1014,14 +1013,14 @@ class TestDataset: { "model_id": "model_a", "fine_tune_weights": [], - "model_path": "", + "model_path": "model_a/", "model_task": "text_embedding", "params": "--example-container-params test --served-model-name test_model_1 --tensor-parallel-size 1 --trust-remote-code --max-model-len 60000", }, { "model_id": "model_b", "fine_tune_weights": [], - "model_path": "", + "model_path": "model_b/", "model_task": "image_text_to_text", "params": "--example-container-params test --served-model-name test_model_2 --tensor-parallel-size 2 --trust-remote-code --max-model-len 32000", }, @@ -1034,7 +1033,7 @@ class TestDataset: "model_path": "oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.", }, ], - "model_path": "", + "model_path": "model_c/", "model_task": "code_synthesis", "params": "--example-container-params test --served-model-name test_model_3 --tensor-parallel-size 4", }, @@ -1046,14 +1045,14 @@ class TestDataset: { "model_id": "model_a", "fine_tune_weights": [], - "model_path": "", + "model_path": "model_a/", "model_task": "text_embedding", "params": "--example-container-params test --served-model-name test_model_1 --tensor-parallel-size 1 --trust-remote-code --max-model-len 60000", }, { "model_id": "model_b", "fine_tune_weights": [], - "model_path": "", + "model_path": "model_b/", "model_task": "image_text_to_text", "params": "--example-container-params test --served-model-name test_model_2 --tensor-parallel-size 2 --trust-remote-code --max-model-len 32000", }, @@ -1794,6 +1793,7 @@ def test_create_deployment_for_tei_byoc_embedding_model( @patch.object(AquaApp, "get_container_image") @patch("ads.model.deployment.model_deployment.ModelDeployment.deploy") @patch("ads.aqua.modeldeployment.AquaDeploymentApp.get_deployment_config") + @patch("ads.aqua.modeldeployment.AquaDeploymentApp._build_model_group_configs") @patch( "ads.aqua.modeldeployment.entities.CreateModelDeploymentDetails.validate_multimodel_deployment_feasibility" ) @@ -1806,6 +1806,7 @@ def test_create_deployment_for_multi_model( mock_get_multi_source, mock_validate_input_models, mock_validate_multimodel_deployment_feasibility, + mock_build_model_group_configs, mock_get_deployment_config, mock_deploy, mock_get_container_image, @@ -1813,6 +1814,13 @@ def test_create_deployment_for_multi_model( mock_get_container_config, ): """Test to create a deployment for multi models.""" + mock_build_model_group_configs.return_value = ( + "mock_group_name", + "mock_group_description", + {}, + MagicMock(), + "mock_combined_models", + ) mock_get_container_config.return_value = ( AquaContainerConfig.from_service_config( service_containers=TestDataset.CONTAINER_LIST @@ -2434,16 +2442,11 @@ def test_invalid_from_aqua_multi_model_ref( model_params = "--dummy-param" if expect_error: - with pytest.raises(ValidationError) as excinfo: + with pytest.raises( + AquaValueError, + match="The base model path is not available in the model artifact.", + ): BaseModelSpec.from_aqua_multi_model_ref(model_ref, model_params) - errs = excinfo.value.errors() - if not model_path.startswith("oci://"): - model_path_errors = [e for e in errs if e["loc"] == ("model_path",)] - assert model_path_errors, f"expected a model_path error, got: {errs!r}" - assert ( - "the base model path is not available in the model artifact." - in model_path_errors[0]["msg"].lower() - ) else: BaseModelSpec.from_aqua_multi_model_ref(model_ref, model_params) diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index 61d9b849d..878f75a9c 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -42,6 +42,7 @@ from ads.aqua.model.enums import MultiModelSupportedTaskType from ads.common.object_storage_details import ObjectStorageDetails from ads.model.datascience_model import DataScienceModel +from ads.model.datascience_model_group import DataScienceModelGroup from ads.model.model_metadata import ( ModelCustomMetadata, ModelProvenanceMetadata, @@ -457,14 +458,12 @@ def test_create_model(self, mock_from_id, mock_validate, mock_create): ) assert model.provenance_metadata.training_id == "test_training_id" - @patch.object(DataScienceModel, "create_custom_metadata_artifact") - @patch.object(DataScienceModel, "create") + @patch.object(DataScienceModelGroup, "create") @patch.object(AquaApp, "get_container_config") def test_create_multimodel( self, mock_get_container_config, - mock_create, - mock_create_custom_metadata_artifact, + mock_create_group, ): mock_get_container_config.return_value = get_container_config() mock_model = MagicMock() @@ -476,12 +475,8 @@ def test_create_multimodel( } mock_model.id = "mock_model_id" mock_model.artifact = "mock_artifact_path" - custom_metadata_list = ModelCustomMetadata() - custom_metadata_list.add( - **{"key": "deployment-container", "value": "odsc-tgi-serving"} - ) - mock_model.custom_metadata_list = custom_metadata_list + model_custom_metadata = MagicMock() model_info_1 = AquaMultiModelRef( model_id="test_model_id_1", @@ -497,39 +492,24 @@ def test_create_multimodel( env_var={"params": "--trust-remote-code --max-model-len 32000"}, ) + # testing fine tuned model in model group + model_info_3 = AquaMultiModelRef( + model_id="test_model_id_3", + gpu_count=2, + model_task="image_text_to_text", + env_var={"params": "--trust-remote-code --max-model-len 32000"}, + artifact_location="oci://test_bucket@test_namespace/models/meta-llama/Llama-3.2-3B-Instruct", + fine_tune_artifact="oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.", + ) + model_details = { model_info_1.model_id: mock_model, model_info_2.model_id: mock_model, + model_info_3.model_id: mock_model, } - with pytest.raises(AquaValueError): - model = self.app.create_multi( - models=[model_info_1, model_info_2], - project_id="test_project_id", - compartment_id="test_compartment_id", - source_models=model_details, - ) - mock_model.freeform_tags["aqua_service_model"] = TestDataset.SERVICE_MODEL_ID - - with pytest.raises(AquaValueError): - model = self.app.create_multi( - models=[model_info_1, model_info_2], - project_id="test_project_id", - compartment_id="test_compartment_id", - source_models=model_details, - ) - mock_model.freeform_tags["task"] = "text-generation" - - with pytest.raises(AquaValueError): - model = self.app.create_multi( - models=[model_info_1, model_info_2], - project_id="test_project_id", - compartment_id="test_compartment_id", - source_models=model_details, - ) - custom_metadata_list = ModelCustomMetadata() custom_metadata_list.add( **{"key": "deployment-container", "value": "odsc-vllm-serving"} @@ -537,67 +517,27 @@ def test_create_multimodel( mock_model.custom_metadata_list = custom_metadata_list - # testing _extract_model_task when a user passes an invalid task to AquaMultiModelRef - model_info_1.model_task = "invalid_task" - - with pytest.raises(AquaValueError): - model = self.app.create_multi( - models=[model_info_1, model_info_2], - project_id="test_project_id", - compartment_id="test_compartment_id", - source_models=model_details, - ) - - # testing if a user tries to invoke a model with a task mode that is not yet supported - model_info_1.model_task = None - mock_model.freeform_tags["task"] = "unsupported_task" - with pytest.raises(AquaValueError): - model = self.app.create_multi( - models=[model_info_1, model_info_2], - project_id="test_project_id", - compartment_id="test_compartment_id", - source_models=model_details, - ) - - mock_model.freeform_tags["task"] = "text-generation" - model_info_1.model_task = "text_embedding" - # testing requesting metadata from fine tuned model to add to model group mock_model.model_file_description = ( TestDataset.fine_tuned_model_file_description ) - # testing fine tuned model in model group - model_info_3 = AquaMultiModelRef( - model_id="test_model_id_3", - gpu_count=2, - model_task="image_text_to_text", - env_var={"params": "--trust-remote-code --max-model-len 32000"}, - artifact_location="oci://test_bucket@test_namespace/models/meta-llama/Llama-3.2-3B-Instruct", - fine_tune_artifact="oci://test_bucket@test_namespace/models/ft-models/meta-llama-3b/ocid1.datasciencejob.oc1.iad.", - ) - - model_details[model_info_3.model_id] = mock_model - # will create a multi-model group - model = self.app.create_multi( + model_group = self.app.create_multi( models=[model_info_1, model_info_2, model_info_3], + model_custom_metadata=model_custom_metadata, + model_group_display_name="test_model_group_name", + model_group_description="test_model_group_description", + tags={"aqua_multimodel": "true"}, + combined_model_names="test_combined_models", project_id="test_project_id", compartment_id="test_compartment_id", source_models=model_details, ) - mock_create.assert_called_with(model_by_reference=True) - - mock_model.compartment_id = TestDataset.SERVICE_COMPARTMENT_ID - mock_create.return_value = mock_model + mock_create_group.assert_called() - assert model.freeform_tags == {"aqua_multimodel": "true"} - assert model.custom_metadata_list.get("model_group_count").value == "3" - assert ( - model.custom_metadata_list.get("deployment-container").value - == "odsc-vllm-serving" - ) + assert model_group.freeform_tags == {"aqua_multimodel": "true"} @pytest.mark.parametrize( "foundation_model_type",