Skip to content

Commit a84b58b

Browse files
authored
[AQUA] Integrate aqua to use model group (#1214)
1 parent: 395e5ce · commit: a84b58b

File tree

10 files changed: 621 additions (+621) and 809 deletions (-809).

ads/aqua/model/model.py

Lines changed: 30 additions & 294 deletions
Large diffs are not rendered by default.

ads/aqua/modeldeployment/deployment.py

Lines changed: 371 additions & 32 deletions
Large diffs are not rendered by default.

ads/aqua/modeldeployment/entities.py

Lines changed: 32 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
from ads.aqua import logger
1111
from ads.aqua.common.entities import AquaMultiModelRef
1212
from ads.aqua.common.enums import Tags
13+
from ads.aqua.common.errors import AquaValueError
1314
from ads.aqua.config.utils.serializer import Serializable
1415
from ads.aqua.constants import UNKNOWN_DICT
1516
from ads.aqua.data import AquaResourceIdentifier
@@ -21,6 +22,7 @@
2122
from ads.common.serializer import DataClassSerializable
2223
from ads.common.utils import UNKNOWN, get_console_link
2324
from ads.model.datascience_model import DataScienceModel
25+
from ads.model.deployment.model_deployment import ModelDeploymentType
2426
from ads.model.model_metadata import ModelCustomMetadataItem
2527

2628

@@ -147,13 +149,39 @@ def from_oci_model_deployment(
147149
AquaDeployment:
148150
The instance of the Aqua model deployment.
149151
"""
150-
instance_configuration = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.instance_configuration
152+
model_deployment_configuration_details = (
153+
oci_model_deployment.model_deployment_configuration_details
154+
)
155+
if (
156+
model_deployment_configuration_details.deployment_type
157+
== ModelDeploymentType.SINGLE_MODEL
158+
):
159+
instance_configuration = model_deployment_configuration_details.model_configuration_details.instance_configuration
160+
instance_count = model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count
161+
model_id = model_deployment_configuration_details.model_configuration_details.model_id
162+
elif (
163+
model_deployment_configuration_details.deployment_type
164+
== ModelDeploymentType.MODEL_GROUP
165+
):
166+
instance_configuration = model_deployment_configuration_details.infrastructure_configuration_details.instance_configuration
167+
instance_count = model_deployment_configuration_details.infrastructure_configuration_details.scaling_policy.instance_count
168+
model_id = model_deployment_configuration_details.model_group_configuration_details.model_group_id
169+
else:
170+
allowed_deployment_types = ", ".join(
171+
[key for key in dir(ModelDeploymentType) if not key.startswith("__")]
172+
)
173+
raise AquaValueError(
174+
f"Invalid AQUA deployment with type {model_deployment_configuration_details.deployment_type}."
175+
f"Only {allowed_deployment_types} are supported at this moment. Specify a different AQUA model deployment."
176+
)
177+
151178
instance_shape_config_details = (
152179
instance_configuration.model_deployment_instance_shape_config_details
153180
)
154-
instance_count = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count
155-
environment_variables = oci_model_deployment.model_deployment_configuration_details.environment_configuration_details.environment_variables
156-
cmd = oci_model_deployment.model_deployment_configuration_details.environment_configuration_details.cmd
181+
environment_variables = model_deployment_configuration_details.environment_configuration_details.environment_variables
182+
cmd = (
183+
model_deployment_configuration_details.environment_configuration_details.cmd
184+
)
157185
shape_info = ShapeInfo(
158186
instance_shape=instance_configuration.instance_shape_name,
159187
instance_count=instance_count,
@@ -168,7 +196,6 @@ def from_oci_model_deployment(
168196
else None
169197
),
170198
)
171-
model_id = oci_model_deployment._model_deployment_configuration_details.model_configuration_details.model_id
172199
tags = {}
173200
tags.update(oci_model_deployment.freeform_tags or UNKNOWN_DICT)
174201
tags.update(oci_model_deployment.defined_tags or UNKNOWN_DICT)

ads/aqua/modeldeployment/model_group_config.py

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
from typing import List, Optional, Tuple, Union
66

7-
from pydantic import BaseModel, Field, field_validator
7+
from pydantic import BaseModel, Field
88
from typing_extensions import Self
99

1010
from ads.aqua import logger
@@ -61,18 +61,19 @@ class BaseModelSpec(BaseModel):
6161
description="Optional list of fine-tuned model variants associated with this base model.",
6262
)
6363

64-
@field_validator("model_path")
6564
@classmethod
66-
def clean_model_path(cls, artifact_path_prefix: str) -> str:
67-
"""Validates and cleans the file path for model_path parameter."""
68-
if ObjectStorageDetails.is_oci_path(artifact_path_prefix):
69-
os_path = ObjectStorageDetails.from_path(artifact_path_prefix)
70-
artifact_path_prefix = os_path.filepath.rstrip("/")
71-
return artifact_path_prefix
72-
73-
raise AquaValueError(
74-
"The base model path is not available in the model artifact."
75-
)
65+
def build_model_path(cls, model_id: str, artifact_path_prefix: str) -> str:
66+
"""Cleans and builds the file path for model_path parameter
67+
to format: <model_id>/<artifact_path_prefix>
68+
"""
69+
if not ObjectStorageDetails.is_oci_path(artifact_path_prefix):
70+
raise AquaValueError(
71+
"The base model path is not available in the model artifact."
72+
)
73+
74+
os_path = ObjectStorageDetails.from_path(artifact_path_prefix)
75+
artifact_path_prefix = os_path.filepath.rstrip("/")
76+
return model_id + "/" + artifact_path_prefix.lstrip("/")
7677

7778
@classmethod
7879
def dedup_lora_modules(cls, fine_tune_weights: List[LoraModuleSpec]):
@@ -99,7 +100,7 @@ def from_aqua_multi_model_ref(
99100

100101
return cls(
101102
model_id=model.model_id,
102-
model_path=model.artifact_location,
103+
model_path=cls.build_model_path(model.model_id, model.artifact_location),
103104
params=model_params,
104105
model_task=model.model_task,
105106
fine_tune_weights=cls.dedup_lora_modules(model.fine_tune_weights),

ads/model/deployment/model_deployment.py

Lines changed: 76 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,27 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8; -*-
32

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76

87
import collections
98
import copy
109
import datetime
11-
import oci
12-
import warnings
1310
import time
14-
from typing import Dict, List, Union, Any
11+
import warnings
12+
from typing import Any, Dict, List, Union
1513

14+
import oci
1615
import oci.loggingsearch
17-
from ads.common import auth as authutil
1816
import pandas as pd
19-
from ads.model.serde.model_input import JsonModelInputSERDE
17+
from oci.data_science.models import (
18+
CreateModelDeploymentDetails,
19+
LogDetails,
20+
UpdateModelDeploymentDetails,
21+
)
22+
23+
from ads.common import auth as authutil
24+
from ads.common import utils as ads_utils
2025
from ads.common.oci_logging import (
2126
LOG_INTERVAL,
2227
LOG_RECORDS_LIMIT,
@@ -30,10 +35,10 @@
3035
from ads.model.deployment.common.utils import send_request
3136
from ads.model.deployment.model_deployment_infrastructure import (
3237
DEFAULT_BANDWIDTH_MBPS,
38+
DEFAULT_MEMORY_IN_GBS,
39+
DEFAULT_OCPUS,
3340
DEFAULT_REPLICA,
3441
DEFAULT_SHAPE_NAME,
35-
DEFAULT_OCPUS,
36-
DEFAULT_MEMORY_IN_GBS,
3742
MODEL_DEPLOYMENT_INFRASTRUCTURE_TYPE,
3843
ModelDeploymentInfrastructure,
3944
)
@@ -45,18 +50,14 @@
4550
ModelDeploymentRuntimeType,
4651
OCIModelDeploymentRuntimeType,
4752
)
53+
from ads.model.serde.model_input import JsonModelInputSERDE
4854
from ads.model.service.oci_datascience_model_deployment import (
4955
OCIDataScienceModelDeployment,
5056
)
51-
from ads.common import utils as ads_utils
57+
5258
from .common import utils
5359
from .common.utils import State
5460
from .model_deployment_properties import ModelDeploymentProperties
55-
from oci.data_science.models import (
56-
LogDetails,
57-
CreateModelDeploymentDetails,
58-
UpdateModelDeploymentDetails,
59-
)
6061

6162
DEFAULT_WAIT_TIME = 1200
6263
DEFAULT_POLL_INTERVAL = 10
@@ -80,6 +81,11 @@ class ModelDeploymentLogType:
8081
ACCESS = "access"
8182

8283

84+
class ModelDeploymentType:
85+
SINGLE_MODEL = "SINGLE_MODEL"
86+
MODEL_GROUP = "MODEL_GROUP"
87+
88+
8389
class LogNotConfiguredError(Exception): # pragma: no cover
8490
pass
8591

@@ -964,7 +970,9 @@ def predict(
964970
except oci.exceptions.ServiceError as ex:
965971
# When bandwidth exceeds the allocated value, TooManyRequests error (429) will be raised by oci backend.
966972
if ex.status == 429:
967-
bandwidth_mbps = self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS
973+
bandwidth_mbps = (
974+
self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS
975+
)
968976
utils.get_logger().warning(
969977
f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps."
970978
"To estimate the actual bandwidth, use formula: (payload size in KB) * (estimated requests per second) * 8 / 1024."
@@ -1644,36 +1652,36 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16441652
}
16451653

16461654
if infrastructure.subnet_id:
1647-
instance_configuration[
1648-
infrastructure.CONST_SUBNET_ID
1649-
] = infrastructure.subnet_id
1655+
instance_configuration[infrastructure.CONST_SUBNET_ID] = (
1656+
infrastructure.subnet_id
1657+
)
16501658

16511659
if infrastructure.private_endpoint_id:
16521660
if not hasattr(
16531661
oci.data_science.models.InstanceConfiguration, "private_endpoint_id"
16541662
):
16551663
# TODO: add oci version with private endpoint support.
1656-
raise EnvironmentError(
1664+
raise OSError(
16571665
"Private endpoint is not supported in the current OCI SDK installed."
16581666
)
16591667

1660-
instance_configuration[
1661-
infrastructure.CONST_PRIVATE_ENDPOINT_ID
1662-
] = infrastructure.private_endpoint_id
1668+
instance_configuration[infrastructure.CONST_PRIVATE_ENDPOINT_ID] = (
1669+
infrastructure.private_endpoint_id
1670+
)
16631671

16641672
scaling_policy = {
16651673
infrastructure.CONST_POLICY_TYPE: "FIXED_SIZE",
16661674
infrastructure.CONST_INSTANCE_COUNT: infrastructure.replica
16671675
or DEFAULT_REPLICA,
16681676
}
16691677

1670-
if not runtime.model_uri:
1678+
if not (runtime.model_uri or runtime.model_group_id):
16711679
raise ValueError(
1672-
"Missing parameter model uri. Try reruning it after model uri is configured."
1680+
"Missing parameter model uri and model group id. Try reruning it after model or model group is configured."
16731681
)
16741682

16751683
model_id = runtime.model_uri
1676-
if not model_id.startswith("ocid"):
1684+
if model_id and not model_id.startswith("ocid"):
16771685
from ads.model.datascience_model import DataScienceModel
16781686

16791687
dsc_model = DataScienceModel(
@@ -1704,7 +1712,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
17041712
oci.data_science.models,
17051713
"ModelDeploymentEnvironmentConfigurationDetails",
17061714
):
1707-
raise EnvironmentError(
1715+
raise OSError(
17081716
"Environment variable hasn't been supported in the current OCI SDK installed."
17091717
)
17101718

@@ -1720,9 +1728,9 @@ def _build_model_deployment_configuration_details(self) -> Dict:
17201728
and runtime.inference_server.upper()
17211729
== MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON
17221730
):
1723-
environment_variables[
1724-
"CONTAINER_TYPE"
1725-
] = MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON
1731+
environment_variables["CONTAINER_TYPE"] = (
1732+
MODEL_DEPLOYMENT_INFERENCE_SERVER_TRITON
1733+
)
17261734
runtime.set_spec(runtime.CONST_ENV, environment_variables)
17271735
environment_configuration_details = {
17281736
runtime.CONST_ENVIRONMENT_CONFIG_TYPE: runtime.environment_config_type,
@@ -1734,27 +1742,45 @@ def _build_model_deployment_configuration_details(self) -> Dict:
17341742
oci.data_science.models,
17351743
"OcirModelDeploymentEnvironmentConfigurationDetails",
17361744
):
1737-
raise EnvironmentError(
1745+
raise OSError(
17381746
"Container runtime hasn't been supported in the current OCI SDK installed."
17391747
)
17401748
environment_configuration_details["image"] = runtime.image
17411749
environment_configuration_details["imageDigest"] = runtime.image_digest
17421750
environment_configuration_details["cmd"] = runtime.cmd
17431751
environment_configuration_details["entrypoint"] = runtime.entrypoint
17441752
environment_configuration_details["serverPort"] = runtime.server_port
1745-
environment_configuration_details[
1746-
"healthCheckPort"
1747-
] = runtime.health_check_port
1753+
environment_configuration_details["healthCheckPort"] = (
1754+
runtime.health_check_port
1755+
)
17481756

17491757
model_deployment_configuration_details = {
1750-
infrastructure.CONST_DEPLOYMENT_TYPE: "SINGLE_MODEL",
1758+
infrastructure.CONST_DEPLOYMENT_TYPE: ModelDeploymentType.SINGLE_MODEL,
17511759
infrastructure.CONST_MODEL_CONFIG_DETAILS: model_configuration_details,
17521760
runtime.CONST_ENVIRONMENT_CONFIG_DETAILS: environment_configuration_details,
17531761
}
17541762

1763+
if runtime.model_group_id:
1764+
model_deployment_configuration_details[
1765+
infrastructure.CONST_DEPLOYMENT_TYPE
1766+
] = ModelDeploymentType.MODEL_GROUP
1767+
model_deployment_configuration_details["modelGroupConfigurationDetails"] = {
1768+
runtime.CONST_MODEL_GROUP_ID: runtime.model_group_id
1769+
}
1770+
model_deployment_configuration_details[
1771+
"infrastructureConfigurationDetails"
1772+
] = {
1773+
"infrastructureType": "INSTANCE_POOL",
1774+
infrastructure.CONST_BANDWIDTH_MBPS: infrastructure.bandwidth_mbps
1775+
or DEFAULT_BANDWIDTH_MBPS,
1776+
infrastructure.CONST_INSTANCE_CONFIG: instance_configuration,
1777+
infrastructure.CONST_SCALING_POLICY: scaling_policy,
1778+
}
1779+
model_configuration_details.pop(runtime.CONST_MODEL_ID)
1780+
17551781
if runtime.deployment_mode == ModelDeploymentMode.STREAM:
17561782
if not hasattr(oci.data_science.models, "StreamConfigurationDetails"):
1757-
raise EnvironmentError(
1783+
raise OSError(
17581784
"Model deployment mode hasn't been supported in the current OCI SDK installed."
17591785
)
17601786
model_deployment_configuration_details[
@@ -1786,9 +1812,13 @@ def _build_category_log_details(self) -> Dict:
17861812

17871813
logs = {}
17881814
if (
1789-
self.infrastructure.access_log and
1790-
self.infrastructure.access_log.get(self.infrastructure.CONST_LOG_GROUP_ID, None)
1791-
and self.infrastructure.access_log.get(self.infrastructure.CONST_LOG_ID, None)
1815+
self.infrastructure.access_log
1816+
and self.infrastructure.access_log.get(
1817+
self.infrastructure.CONST_LOG_GROUP_ID, None
1818+
)
1819+
and self.infrastructure.access_log.get(
1820+
self.infrastructure.CONST_LOG_ID, None
1821+
)
17921822
):
17931823
logs[self.infrastructure.CONST_ACCESS] = {
17941824
self.infrastructure.CONST_LOG_GROUP_ID: self.infrastructure.access_log.get(
@@ -1799,9 +1829,13 @@ def _build_category_log_details(self) -> Dict:
17991829
),
18001830
}
18011831
if (
1802-
self.infrastructure.predict_log and
1803-
self.infrastructure.predict_log.get(self.infrastructure.CONST_LOG_GROUP_ID, None)
1804-
and self.infrastructure.predict_log.get(self.infrastructure.CONST_LOG_ID, None)
1832+
self.infrastructure.predict_log
1833+
and self.infrastructure.predict_log.get(
1834+
self.infrastructure.CONST_LOG_GROUP_ID, None
1835+
)
1836+
and self.infrastructure.predict_log.get(
1837+
self.infrastructure.CONST_LOG_ID, None
1838+
)
18051839
):
18061840
logs[self.infrastructure.CONST_PREDICT] = {
18071841
self.infrastructure.CONST_LOG_GROUP_ID: self.infrastructure.predict_log.get(

0 commit comments

Comments
 (0)