class AquaModelChatTemplateHandler(AquaAPIhandler):
    """Handler for reading and writing the chat template of a model.

    The chat template is kept on the model as a custom metadata artifact
    stored under the key ``AQUA_CHAT_TEMPLATE_METADATA_KEY``.
    """

    def get(self, model_id):
        """
        Handles requests for retrieving the chat template from custom metadata of a specified model.

        Expected request format: GET /aqua/models/<model_id>/chat-template

        Parameters
        ----------
        model_id:
            Identifier of the model whose chat template is requested.

        Raises
        ------
        HTTPError
            400 when the request path does not have the expected shape,
            404 when the model cannot be loaded.
        """
        path_list = urlparse(self.request.path).path.strip("/").split("/")
        # Path should be /aqua/models/ocid1.iad.ahdxxx/chat-template
        # path_list=['aqua','models','<model_id>','chat-template']
        is_chat_template_path = (
            len(path_list) == 4
            and is_valid_ocid(path_list[2])
            and path_list[3] == "chat-template"
        )
        # Guard clause: reject malformed paths before touching the service.
        if not is_chat_template_path:
            raise HTTPError(400, f"The request {self.request.path} is invalid.")

        try:
            oci_data_science_model = OCIDataScienceModel.from_id(model_id)
        except Exception as e:
            # Chain the cause so the original failure stays in the traceback.
            raise HTTPError(
                404, f"Model not found for id: {model_id}. Details: {e}"
            ) from e

        # Use the shared constant so GET and POST always agree on the key.
        return self.finish(
            oci_data_science_model.get_custom_metadata_artifact(
                AQUA_CHAT_TEMPLATE_METADATA_KEY
            )
        )

    @handle_exceptions
    def post(self, model_id: str):
        """
        Handles POST requests to add a custom chat_template metadata artifact to a model.

        Expected request format:
            POST /aqua/models/<model_id>/chat-template
            Body: { "chat_template": "<chat template string>" }

        Parameters
        ----------
        model_id:
            Identifier of the model that receives the chat template.

        Raises
        ------
        HTTPError
            400 when the body is not valid JSON, is not a JSON object, or
            does not carry a non-empty string under ``chat_template``;
            404 when the model cannot be loaded;
            500 when the metadata artifact cannot be created.
        """
        try:
            input_body = self.get_json_body()
        except Exception as e:
            raise HTTPError(400, f"Invalid JSON body: {e}") from e

        # A missing body or a JSON array/scalar has no ``.get``; report it as
        # a client error instead of letting an AttributeError escape.
        if not isinstance(input_body, dict):
            raise HTTPError(400, "Invalid JSON body: expected a JSON object.")

        chat_template = input_body.get("chat_template")
        if not chat_template:
            raise HTTPError(400, "Missing required field: 'chat_template'")
        # ``.encode()`` below is only defined for strings.
        if not isinstance(chat_template, str):
            raise HTTPError(
                400, "Invalid value for 'chat_template': expected a string."
            )

        try:
            data_science_model = DataScienceModel.from_id(model_id)
        except Exception as e:
            raise HTTPError(
                404, f"Model not found for id: {model_id}. Details: {e}"
            ) from e

        try:
            result = data_science_model.create_custom_metadata_artifact(
                metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY,
                path_type=MetadataArtifactPathType.CONTENT,
                artifact_path_or_content=chat_template.encode(),
            )
        except Exception as e:
            raise HTTPError(
                500, f"Failed to create metadata artifact: {e}"
            ) from e

        return self.finish(result)
class AquaModelChatTemplateHandlerTestCase(TestCase):
    """Unit tests for the GET and POST endpoints of AquaModelChatTemplateHandler."""

    @patch.object(IPythonHandler, "__init__")
    def setUp(self, ipython_init_mock) -> None:
        ipython_init_mock.return_value = None
        self.model_chat_template_handler = AquaModelChatTemplateHandler(
            MagicMock(), MagicMock()
        )
        self.model_chat_template_handler.finish = MagicMock()
        self.model_chat_template_handler.request = MagicMock()
        # Set once here so the individual tests do not have to repeat it.
        self.model_chat_template_handler._headers = {}

    def _assert_http_error_reported(self, mock_write_error, status_code, message):
        """Check that exactly one HTTPError with the given status and text reached write_error."""
        mock_write_error.assert_called_once()
        exc_info = mock_write_error.call_args.kwargs.get("exc_info")
        self.assertIsNotNone(exc_info)
        # exc_info is a (type, instance, traceback) triple; only the instance matters.
        _, exc_instance, _ = exc_info
        self.assertIsInstance(exc_instance, HTTPError)
        self.assertEqual(exc_instance.status_code, status_code)
        self.assertIn(message, str(exc_instance))

    @patch("ads.aqua.extension.model_handler.OCIDataScienceModel.from_id")
    @patch("ads.aqua.extension.model_handler.urlparse")
    def test_get_valid_path(self, mock_urlparse, mock_from_id):
        """A well-formed path returns the stored chat template."""
        mock_urlparse.return_value = MagicMock(
            path="/aqua/models/ocid1.xx./chat-template"
        )

        model_mock = MagicMock()
        model_mock.get_custom_metadata_artifact.return_value = "chat_template_string"
        mock_from_id.return_value = model_mock

        self.model_chat_template_handler.get(model_id="test_model_id")

        self.model_chat_template_handler.finish.assert_called_with(
            "chat_template_string"
        )
        model_mock.get_custom_metadata_artifact.assert_called_with("chat_template")

    @patch("ads.aqua.extension.model_handler.urlparse")
    def test_get_invalid_path(self, mock_urlparse):
        """A malformed request path raises HTTPError(400) and sends no response."""
        mock_urlparse.return_value = MagicMock(path="/wrong/path")

        with self.assertRaises(HTTPError) as context:
            self.model_chat_template_handler.get("ocid1.test.chat")

        self.assertEqual(context.exception.status_code, 400)
        self.model_chat_template_handler.finish.assert_not_called()

    @patch(
        "ads.aqua.extension.model_handler.OCIDataScienceModel.from_id",
        side_effect=Exception("Not found"),
    )
    @patch("ads.aqua.extension.model_handler.urlparse")
    def test_get_model_not_found(self, mock_urlparse, mock_from_id):
        """A failing model lookup raises HTTPError(404) and sends no response."""
        mock_urlparse.return_value = MagicMock(
            path="/aqua/models/ocid1.invalid/chat-template"
        )

        with self.assertRaises(HTTPError) as context:
            self.model_chat_template_handler.get("ocid1.invalid")

        self.assertEqual(context.exception.status_code, 404)
        self.model_chat_template_handler.finish.assert_not_called()

    @patch("ads.aqua.extension.model_handler.DataScienceModel.from_id")
    def test_post_valid(self, mock_from_id):
        """A valid body stores the encoded template and returns the service result."""
        model_mock = MagicMock()
        model_mock.create_custom_metadata_artifact.return_value = {
            "result": "success"
        }
        mock_from_id.return_value = model_mock
        self.model_chat_template_handler.get_json_body = MagicMock(
            return_value={"chat_template": "Hello <|user|>"}
        )

        self.model_chat_template_handler.post("ocid1.valid")

        self.model_chat_template_handler.finish.assert_called_with(
            {"result": "success"}
        )
        model_mock.create_custom_metadata_artifact.assert_called_with(
            metadata_key_name=AQUA_CHAT_TEMPLATE_METADATA_KEY,
            path_type=ANY,
            artifact_path_or_content=b"Hello <|user|>",
        )

    @patch.object(AquaModelChatTemplateHandler, "write_error")
    def test_post_invalid_json(self, mock_write_error):
        """An unparseable body is reported as HTTPError(400)."""
        self.model_chat_template_handler.get_json_body = MagicMock(
            side_effect=Exception("Invalid JSON")
        )

        self.model_chat_template_handler.post("ocid1.test.invalidjson")

        self._assert_http_error_reported(mock_write_error, 400, "Invalid JSON body")

    @patch.object(AquaModelChatTemplateHandler, "write_error")
    def test_post_missing_chat_template(self, mock_write_error):
        """A body without 'chat_template' is reported as HTTPError(400)."""
        self.model_chat_template_handler.get_json_body = MagicMock(return_value={})

        self.model_chat_template_handler.post("ocid1.test.model")

        self._assert_http_error_reported(
            mock_write_error, 400, "Missing required field: 'chat_template'"
        )

    @patch(
        "ads.aqua.extension.model_handler.DataScienceModel.from_id",
        side_effect=Exception("Not found"),
    )
    @patch.object(AquaModelChatTemplateHandler, "write_error")
    def test_post_model_not_found(self, mock_write_error, mock_from_id):
        """A failing model lookup during POST is reported as HTTPError(404)."""
        self.model_chat_template_handler.get_json_body = MagicMock(
            return_value={"chat_template": "test template"}
        )

        self.model_chat_template_handler.post("ocid1.invalid.model")

        self._assert_http_error_reported(mock_write_error, 404, "Model not found")