diff --git a/ads/aqua/extension/deployment_handler.py b/ads/aqua/extension/deployment_handler.py index ed0bfdd74..88ad84272 100644 --- a/ads/aqua/extension/deployment_handler.py +++ b/ads/aqua/extension/deployment_handler.py @@ -54,6 +54,33 @@ def get(self, id=""): else: raise HTTPError(400, f"The request {self.request.path} is invalid.") + @handle_exceptions + def delete(self, model_deployment_id): + return self.finish(AquaDeploymentApp().delete(model_deployment_id)) + + @handle_exceptions + def put(self, *args, **kwargs): + """ + Handles put request for the activating and deactivating OCI datascience model deployments + Raises + ------ + HTTPError + Raises HTTPError if inputs are missing or are invalid + """ + url_parse = urlparse(self.request.path) + paths = url_parse.path.strip("/").split("/") + if len(paths) != 4 or paths[0] != "aqua" or paths[1] != "deployments": + raise HTTPError(400, f"The request {self.request.path} is invalid.") + + model_deployment_id = paths[2] + action = paths[3] + if action == "activate": + return self.finish(AquaDeploymentApp().activate(model_deployment_id)) + elif action == "deactivate": + return self.finish(AquaDeploymentApp().deactivate(model_deployment_id)) + else: + raise HTTPError(400, f"The request {self.request.path} is invalid.") + @handle_exceptions def post(self, *args, **kwargs): """ @@ -270,5 +297,7 @@ def post(self, *args, **kwargs): ("deployments/?([^/]*)/params", AquaDeploymentParamsHandler), ("deployments/config/?([^/]*)", AquaDeploymentHandler), ("deployments/?([^/]*)", AquaDeploymentHandler), + ("deployments/?([^/]*)/activate", AquaDeploymentHandler), + ("deployments/?([^/]*)/deactivate", AquaDeploymentHandler), ("inference", AquaDeploymentInferenceHandler), ] diff --git a/ads/aqua/extension/errors.py b/ads/aqua/extension/errors.py index d5e44944c..9829ff9e4 100644 --- a/ads/aqua/extension/errors.py +++ b/ads/aqua/extension/errors.py @@ -8,3 +8,4 @@ class Errors(str): NO_INPUT_DATA = "No input data provided." MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'" MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required." + INVALID_VALUE_OF_PARAMETER = "Invalid value of parameter: '{}'" diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 96f2826d0..8a5e490ea 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -9,7 +9,10 @@ from ads.aqua.common.decorator import handle_exceptions from ads.aqua.common.errors import AquaRuntimeError, AquaValueError -from ads.aqua.common.utils import get_hf_model_info, list_hf_models +from ads.aqua.common.utils import ( + get_hf_model_info, + list_hf_models, +) from ads.aqua.extension.base_handler import AquaAPIhandler from ads.aqua.extension.errors import Errors from ads.aqua.model import AquaModelApp @@ -73,6 +76,8 @@ def delete(self, id=""): paths = url_parse.path.strip("/") if paths.startswith("aqua/model/cache"): return self.finish(AquaModelApp().clear_model_list_cache()) + elif id: + return self.finish(AquaModelApp().delete_model(id)) else: raise HTTPError(400, f"The request {self.request.path} is invalid.") @@ -139,6 +144,36 @@ def post(self, *args, **kwargs): ) ) + @handle_exceptions + def put(self, id): + try: + input_data = self.get_json_body() + except Exception as ex: + raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex + + if not input_data: + raise HTTPError(400, Errors.NO_INPUT_DATA) + + inference_container = input_data.get("inference_container") + inference_containers = AquaModelApp.list_valid_inference_containers() + if ( + inference_container is not None + and inference_container not in inference_containers + ): + raise HTTPError( + 400, Errors.INVALID_VALUE_OF_PARAMETER.format("inference_container") + ) + + enable_finetuning = input_data.get("enable_finetuning") + task = input_data.get("task") + app=AquaModelApp() + self.finish( + app.edit_registered_model( + id, inference_container, enable_finetuning, task + ) + ) + app.clear_model_details_cache(model_id=id) + class AquaModelLicenseHandler(AquaAPIhandler): """Handler for Aqua Model license REST APIs.""" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 374e20ada..ce8d523d2 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -10,11 +10,15 @@ import oci from cachetools import TTLCache from huggingface_hub import snapshot_download -from oci.data_science.models import JobRun, Model +from oci.data_science.models import JobRun, Metadata, Model, UpdateModelDetails from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger from ads.aqua.app import AquaApp -from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags +from ads.aqua.common.enums import ( + FineTuningContainerTypeFamily, + InferenceContainerTypeFamily, + Tags, +) from ads.aqua.common.errors import AquaRuntimeError, AquaValueError from ads.aqua.common.utils import ( LifecycleStatus, @@ -23,6 +27,7 @@ create_word_icon, generate_tei_cmd_var, get_artifact_path, + get_container_config, get_hf_model_info, list_os_files_with_extension, load_config, @@ -78,7 +83,11 @@ TENANCY_OCID, ) from ads.model import DataScienceModel -from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem +from ads.model.model_metadata import ( + MetadataCustomCategory, + ModelCustomMetadata, + ModelCustomMetadataItem, +) from ads.telemetry import telemetry @@ -333,6 +342,96 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod return model_details + @telemetry(entry_point="plugin=model&action=delete", name="aqua") + def delete_model(self, model_id): + ds_model = DataScienceModel.from_id(model_id) + is_registered_model = ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None) + is_fine_tuned_model = ds_model.freeform_tags.get( + Tags.AQUA_FINE_TUNED_MODEL_TAG, None + ) + if is_registered_model or is_fine_tuned_model: + return ds_model.delete() + else: + raise AquaRuntimeError( + f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted." + ) + + @telemetry(entry_point="plugin=model&action=delete", name="aqua") + def edit_registered_model(self, id, inference_container, enable_finetuning, task): + """Edits the default config of unverified registered model. + + Parameters + ---------- + id: str + The model OCID. + inference_container: str. + The inference container family name + enable_finetuning: str + Flag to enable or disable finetuning over the model. Defaults to None + task: + The usecase type of the model. e.g , text-generation , text_embedding etc. + + Returns + ------- + Model: + The instance of oci.data_science.models.Model. + + """ + ds_model = DataScienceModel.from_id(id) + if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None): + if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None): + raise AquaRuntimeError( + f"Failed to edit model:{id}. Only registered unverified models can be edited." + ) + else: + custom_metadata_list = ds_model.custom_metadata_list + freeform_tags = ds_model.freeform_tags + if inference_container: + custom_metadata_list.add( + key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, + value=inference_container, + category=MetadataCustomCategory.OTHER, + description="Deployment container mapping for SMC", + replace=True, + ) + if enable_finetuning is not None: + if enable_finetuning.lower() == "true": + custom_metadata_list.add( + key=ModelCustomMetadataFields.FINETUNE_CONTAINER, + value=FineTuningContainerTypeFamily.AQUA_FINETUNING_CONTAINER_FAMILY, + category=MetadataCustomCategory.OTHER, + description="Fine-tuning container mapping for SMC", + replace=True, + ) + freeform_tags.update({Tags.READY_TO_FINE_TUNE: "true"}) + elif enable_finetuning.lower() == "false": + try: + custom_metadata_list.remove( + ModelCustomMetadataFields.FINETUNE_CONTAINER + ) + freeform_tags.pop(Tags.READY_TO_FINE_TUNE) + except Exception as ex: + raise AquaRuntimeError( + f"The given model already doesn't support finetuning: {ex}" + ) + + custom_metadata_list.remove("modelDescription") + if task: + freeform_tags.update({Tags.TASK: task}) + updated_custom_metadata_list = [ + Metadata(**metadata) + for metadata in custom_metadata_list.to_dict()["data"] + ] + update_model_details = UpdateModelDetails( + custom_metadata_list=updated_custom_metadata_list, + freeform_tags=freeform_tags, + ) + AquaApp().update_model(id, update_model_details) + else: + raise AquaRuntimeError( + f"Failed to edit model:{id}. Only registered unverified models can be edited." + ) + def _fetch_metric_from_metadata( self, custom_metadata_list: ModelCustomMetadata, @@ -629,6 +728,32 @@ def clear_model_list_cache( } return res + def clear_model_details_cache(self, model_id): + """ + Allows user to clear model details cache item + Returns + ------- + dict with the key used, and True if cache has the key that needs to be deleted. + """ + res = {} + logger.info(f"Clearing _service_model_details_cache for {model_id}") + with self._cache_lock: + if model_id in self._service_model_details_cache: + self._service_model_details_cache.pop(key=model_id) + res = {"key": {"model_id": model_id}, "cache_deleted": True} + + return res + + @staticmethod + def list_valid_inference_containers(): + containers = list( + AquaContainerConfig.from_container_index_json( + config=get_container_config(), enable_spec=True + ).inference.values() + ) + family_values = [item.family for item in containers] + return family_values + def _create_model_catalog_entry( self, os_path: str, diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index d7ba06abc..e4de392df 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -532,6 +532,18 @@ def list(self, **kwargs) -> List["AquaDeployment"]: return results + @telemetry(entry_point="plugin=deployment&action=delete", name="aqua") + def delete(self,model_deployment_id:str): + return self.ds_client.delete_model_deployment(model_deployment_id=model_deployment_id).data + + @telemetry(entry_point="plugin=deployment&action=deactivate",name="aqua") + def deactivate(self,model_deployment_id:str): + return self.ds_client.deactivate_model_deployment(model_deployment_id=model_deployment_id).data + + @telemetry(entry_point="plugin=deployment&action=activate",name="aqua") + def activate(self,model_deployment_id:str): + return self.ds_client.activate_model_deployment(model_deployment_id=model_deployment_id).data + @telemetry(entry_point="plugin=deployment&action=get", name="aqua") def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": """Gets the information of Aqua model deployment. diff --git a/tests/unitary/with_extras/aqua/test_deployment_handler.py b/tests/unitary/with_extras/aqua/test_deployment_handler.py index a91955160..e6f5acc45 100644 --- a/tests/unitary/with_extras/aqua/test_deployment_handler.py +++ b/tests/unitary/with_extras/aqua/test_deployment_handler.py @@ -92,6 +92,30 @@ def test_get_deployment(self, mock_get): self.deployment_handler.get(id="mock-model-id") mock_get.assert_called() + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.delete") + def test_delete_deployment(self, mock_delete): + self.deployment_handler.request.path = "aqua/deployments" + self.deployment_handler.delete("mock-model-id") + mock_delete.assert_called() + + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.activate") + def test_activate_deployment(self, mock_activate): + self.deployment_handler.request.path = ( + "aqua/deployments/ocid1.datasciencemodeldeployment.oc1.iad.xxx/activate" + ) + mock_activate.return_value = {"lifecycle_state": "UPDATING"} + self.deployment_handler.put() + mock_activate.assert_called() + + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.deactivate") + def test_deactivate_deployment(self, mock_deactivate): + self.deployment_handler.request.path = ( + "aqua/deployments/ocid1.datasciencemodeldeployment.oc1.iad.xxx/deactivate" + ) + mock_deactivate.return_value = {"lifecycle_state": "UPDATING"} + self.deployment_handler.put() + mock_deactivate.assert_called() + @patch("ads.aqua.modeldeployment.AquaDeploymentApp.list") def test_list_deployment(self, mock_list): """Test get method to return a list of model deployments.""" diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index d217684fb..d4b741463 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -20,7 +20,6 @@ AquaModelLicenseHandler, ) from ads.aqua.model import AquaModelApp -from ads.aqua.model.constants import ModelTask from ads.aqua.model.entities import AquaModel, AquaModelSummary, HFModelSummary @@ -80,6 +79,46 @@ def test_delete(self, mock_urlparse, mock_clear_model_list_cache): mock_urlparse.assert_called() mock_clear_model_list_cache.assert_called() + @patch("ads.aqua.extension.model_handler.urlparse") + @patch.object(AquaModelApp, "delete_model") + def test_delete_with_id(self, mock_delete, mock_urlparse): + request_path = MagicMock(path="aqua/model/ocid1.datasciencemodel.oc1.iad.xxx") + mock_urlparse.return_value = request_path + mock_delete.return_value = {"state": "DELETED"} + with patch( + "ads.aqua.extension.base_handler.AquaAPIhandler.finish" + ) as mock_finish: + mock_finish.side_effect = lambda x: x + result = self.model_handler.delete(id="ocid1.datasciencemodel.oc1.iad.xxx") + assert result["state"] is "DELETED" + mock_urlparse.assert_called() + mock_delete.assert_called() + + @patch.object(AquaModelApp, "list_valid_inference_containers") + @patch.object(AquaModelApp, "edit_registered_model") + def test_put(self, mock_edit, mock_inference_container_list): + mock_edit.return_value = None + mock_inference_container_list.return_value = [ + "odsc-vllm-serving", + "odsc-tgi-serving", + "odsc-llama-cpp-serving", + ] + self.model_handler.get_json_body = MagicMock( + return_value=dict( + task="text_generation", + enable_finetuning="true", + inference_container="odsc-tgi-serving", + ) + ) + with patch( + "ads.aqua.extension.base_handler.AquaAPIhandler.finish" + ) as mock_finish: + mock_finish.side_effect = lambda x: x + result = self.model_handler.put(id="ocid1.datasciencemodel.oc1.iad.xxx") + assert result is None + mock_edit.assert_called_once() + mock_inference_container_list.assert_called_once() + @patch.object(AquaModelApp, "list") def test_list(self, mock_list): with patch(