|
14 | 14 |
|
15 | 15 | from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
|
16 | 16 | from ads.aqua.app import AquaApp
|
17 |
| -from ads.aqua.common.enums import Tags |
| 17 | +from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags |
18 | 18 | from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
|
19 | 19 | from ads.aqua.common.utils import (
|
20 | 20 | LifecycleStatus,
|
@@ -933,139 +933,186 @@ def _validate_model(
|
933 | 933 | # now as we know that at least one type of model files exist, validate the content of oss path.
|
934 | 934 | # for safetensors, we check if config.json files exist, and for gguf format we check if files with
|
935 | 935 | # gguf extension exist.
|
936 |
| - for model_format in model_formats: |
| 936 | + if {ModelFormat.SAFETENSORS, ModelFormat.GGUF}.issubset(set(model_formats)): |
937 | 937 | if (
|
938 |
| - model_format == ModelFormat.SAFETENSORS |
939 |
| - and len(safetensors_model_files) > 0 |
| 938 | + import_model_details.inference_container.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY |
940 | 939 | ):
|
941 |
| - if import_model_details.download_from_hf: |
942 |
| - # validates config.json exists for safetensors model from hugginface |
943 |
| - if not hf_download_config_present: |
944 |
| - raise AquaRuntimeError( |
945 |
| - f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required " |
946 |
| - f"by {ModelFormat.SAFETENSORS.value} format model." |
947 |
| - f" Please check if the model name is correct in Hugging Face repository." |
948 |
| - ) |
949 |
| - else: |
950 |
| - try: |
951 |
| - model_config = load_config( |
952 |
| - file_path=import_model_details.os_path, |
953 |
| - config_file_name=AQUA_MODEL_ARTIFACT_CONFIG, |
954 |
| - ) |
955 |
| - except Exception as ex: |
956 |
| - logger.error( |
957 |
| - f"Exception occurred while loading config file from {import_model_details.os_path}" |
958 |
| - f"Exception message: {ex}" |
959 |
| - ) |
960 |
| - raise AquaRuntimeError( |
961 |
| - f"The model path {import_model_details.os_path} does not contain the file config.json. " |
962 |
| - f"Please check if the path is correct or the model artifacts are available at this location." |
963 |
| - ) from ex |
964 |
| - else: |
965 |
| - try: |
966 |
| - metadata_model_type = ( |
967 |
| - verified_model.custom_metadata_list.get( |
968 |
| - AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE |
969 |
| - ).value |
970 |
| - ) |
971 |
| - if metadata_model_type: |
972 |
| - if ( |
973 |
| - AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE |
974 |
| - in model_config |
975 |
| - ): |
976 |
| - if ( |
977 |
| - model_config[ |
978 |
| - AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE |
979 |
| - ] |
980 |
| - != metadata_model_type |
981 |
| - ): |
982 |
| - raise AquaRuntimeError( |
983 |
| - f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}" |
984 |
| - f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for " |
985 |
| - f"the model {model_name}. Please check if the path is correct or " |
986 |
| - f"the correct model artifacts are available at this location." |
987 |
| - f"" |
988 |
| - ) |
989 |
| - else: |
990 |
| - logger.debug( |
991 |
| - f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in " |
992 |
| - f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration." |
993 |
| - ) |
994 |
| - except Exception: |
995 |
| - pass |
996 |
| - if verified_model: |
997 |
| - validation_result.telemetry_model_name = ( |
998 |
| - verified_model.display_name |
999 |
| - ) |
1000 |
| - elif ( |
1001 |
| - model_config is not None |
1002 |
| - and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config |
1003 |
| - ): |
1004 |
| - validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}" |
1005 |
| - elif ( |
1006 |
| - model_config is not None |
1007 |
| - and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config |
| 940 | + self._validate_gguf_format( |
| 941 | + import_model_details=import_model_details, |
| 942 | + verified_model=verified_model, |
| 943 | + gguf_model_files=gguf_model_files, |
| 944 | + validation_result=validation_result, |
| 945 | + model_name=model_name |
| 946 | + ) |
| 947 | + else: |
| 948 | + self._validate_safetensor_format( |
| 949 | + import_model_details=import_model_details, |
| 950 | + verified_model=verified_model, |
| 951 | + validation_result=validation_result, |
| 952 | + hf_download_config_present=hf_download_config_present, |
| 953 | + model_name=model_name |
| 954 | + ) |
| 955 | + elif ModelFormat.SAFETENSORS in model_formats: |
| 956 | + self._validate_safetensor_format( |
| 957 | + import_model_details=import_model_details, |
| 958 | + verified_model=verified_model, |
| 959 | + validation_result=validation_result, |
| 960 | + hf_download_config_present=hf_download_config_present, |
| 961 | + model_name=model_name |
| 962 | + ) |
| 963 | + elif ModelFormat.GGUF in model_formats: |
| 964 | + self._validate_gguf_format( |
| 965 | + import_model_details=import_model_details, |
| 966 | + verified_model=verified_model, |
| 967 | + gguf_model_files=gguf_model_files, |
| 968 | + validation_result=validation_result, |
| 969 | + model_name=model_name |
| 970 | + ) |
| 971 | + |
| 972 | + return validation_result |
| 973 | + |
| 974 | + @staticmethod |
| 975 | + def _validate_safetensor_format( |
| 976 | + import_model_details: ImportModelDetails = None, |
| 977 | + verified_model: DataScienceModel = None, |
| 978 | + validation_result: ModelValidationResult = None, |
| 979 | + hf_download_config_present: bool = None, |
| 980 | + model_name: str = None |
| 981 | + ): |
| 982 | + if import_model_details.download_from_hf: |
| 983 | + # validates config.json exists for safetensors model from hugginface |
| 984 | + if not hf_download_config_present: |
| 985 | + raise AquaRuntimeError( |
| 986 | + f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required " |
| 987 | + f"by {ModelFormat.SAFETENSORS.value} format model." |
| 988 | + f" Please check if the model name is correct in Hugging Face repository." |
| 989 | + ) |
| 990 | + else: |
| 991 | + try: |
| 992 | + model_config = load_config( |
| 993 | + file_path=import_model_details.os_path, |
| 994 | + config_file_name=AQUA_MODEL_ARTIFACT_CONFIG, |
| 995 | + ) |
| 996 | + except Exception as ex: |
| 997 | + logger.error( |
| 998 | + f"Exception occurred while loading config file from {import_model_details.os_path}" |
| 999 | + f"Exception message: {ex}" |
| 1000 | + ) |
| 1001 | + raise AquaRuntimeError( |
| 1002 | + f"The model path {import_model_details.os_path} does not contain the file config.json. " |
| 1003 | + f"Please check if the path is correct or the model artifacts are available at this location." |
| 1004 | + ) from ex |
| 1005 | + else: |
| 1006 | + try: |
| 1007 | + metadata_model_type = ( |
| 1008 | + verified_model.custom_metadata_list.get( |
| 1009 | + AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE |
| 1010 | + ).value |
| 1011 | + ) |
| 1012 | + if metadata_model_type: |
| 1013 | + if ( |
| 1014 | + AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE |
| 1015 | + in model_config |
1008 | 1016 | ):
|
1009 |
| - validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}" |
| 1017 | + if ( |
| 1018 | + model_config[ |
| 1019 | + AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE |
| 1020 | + ] |
| 1021 | + != metadata_model_type |
| 1022 | + ): |
| 1023 | + raise AquaRuntimeError( |
| 1024 | + f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}" |
| 1025 | + f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for " |
| 1026 | + f"the model {model_name}. Please check if the path is correct or " |
| 1027 | + f"the correct model artifacts are available at this location." |
| 1028 | + f"" |
| 1029 | + ) |
1010 | 1030 | else:
|
1011 |
| - validation_result.telemetry_model_name = ( |
1012 |
| - AQUA_MODEL_TYPE_CUSTOM |
| 1031 | + logger.debug( |
| 1032 | + f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in " |
| 1033 | + f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration." |
1013 | 1034 | )
|
1014 |
| - elif model_format == ModelFormat.GGUF and len(gguf_model_files) > 0: |
1015 |
| - if import_model_details.finetuning_container and not safetensors_model_files: |
1016 |
| - raise AquaValueError( |
1017 |
| - "Fine-tuning is currently not supported with GGUF model format." |
1018 |
| - ) |
| 1035 | + except Exception: |
| 1036 | + pass |
1019 | 1037 | if verified_model:
|
1020 |
| - try: |
1021 |
| - model_file = verified_model.custom_metadata_list.get( |
1022 |
| - AQUA_MODEL_ARTIFACT_FILE |
1023 |
| - ).value |
1024 |
| - except ValueError as err: |
1025 |
| - raise AquaRuntimeError( |
1026 |
| - f"The model {verified_model.display_name} does not contain the custom metadata {AQUA_MODEL_ARTIFACT_FILE}. " |
1027 |
| - f"Please check if the model has the valid metadata." |
1028 |
| - ) from err |
1029 |
| - else: |
1030 |
| - model_file = import_model_details.model_file |
1031 |
| - |
1032 |
| - model_files = gguf_model_files |
1033 |
| - # todo: have a separate error validation class for different type of error messages. |
1034 |
| - if model_file: |
1035 |
| - if model_file not in model_files: |
1036 |
| - raise AquaRuntimeError( |
1037 |
| - f"The model path {import_model_details.os_path} or the Hugging Face " |
1038 |
| - f"model repository for {model_name} does not contain the file " |
1039 |
| - f"{model_file}. Please check if the path is correct or the model " |
1040 |
| - f"artifacts are available at this location." |
1041 |
| - ) |
1042 |
| - else: |
1043 |
| - validation_result.model_file = model_file |
1044 |
| - elif len(model_files) == 0: |
1045 |
| - raise AquaRuntimeError( |
1046 |
| - f"The model path {import_model_details.os_path} or the Hugging Face model " |
1047 |
| - f"repository for {model_name} does not contain any GGUF format files. " |
1048 |
| - f"Please check if the path is correct or the model artifacts are available " |
1049 |
| - f"at this location." |
1050 |
| - ) |
1051 |
| - elif len(model_files) > 1: |
1052 |
| - raise AquaRuntimeError( |
1053 |
| - f"The model path {import_model_details.os_path} or the Hugging Face model " |
1054 |
| - f"repository for {model_name} contains multiple GGUF format files. " |
1055 |
| - f"Please specify the file that needs to be deployed using the model_file " |
1056 |
| - f"parameter." |
| 1038 | + validation_result.telemetry_model_name = ( |
| 1039 | + verified_model.display_name |
1057 | 1040 | )
|
| 1041 | + elif ( |
| 1042 | + model_config is not None |
| 1043 | + and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config |
| 1044 | + ): |
| 1045 | + validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}" |
| 1046 | + elif ( |
| 1047 | + model_config is not None |
| 1048 | + and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config |
| 1049 | + ): |
| 1050 | + validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}" |
1058 | 1051 | else:
|
1059 |
| - validation_result.model_file = model_files[0] |
| 1052 | + validation_result.telemetry_model_name = ( |
| 1053 | + AQUA_MODEL_TYPE_CUSTOM |
| 1054 | + ) |
1060 | 1055 |
|
1061 |
| - if verified_model: |
1062 |
| - validation_result.telemetry_model_name = verified_model.display_name |
1063 |
| - elif import_model_details.download_from_hf: |
1064 |
| - validation_result.telemetry_model_name = model_name |
1065 |
| - else: |
1066 |
| - validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM |
| 1056 | + @staticmethod |
| 1057 | + def _validate_gguf_format( |
| 1058 | + import_model_details: ImportModelDetails = None, |
| 1059 | + verified_model: DataScienceModel = None, |
| 1060 | + gguf_model_files: List[str] = None, |
| 1061 | + validation_result: ModelValidationResult = None, |
| 1062 | + model_name: str = None, |
| 1063 | + ): |
| 1064 | + if import_model_details.finetuning_container: |
| 1065 | + raise AquaValueError( |
| 1066 | + "Fine-tuning is currently not supported with GGUF model format." |
| 1067 | + ) |
| 1068 | + if verified_model: |
| 1069 | + try: |
| 1070 | + model_file = verified_model.custom_metadata_list.get( |
| 1071 | + AQUA_MODEL_ARTIFACT_FILE |
| 1072 | + ).value |
| 1073 | + except ValueError as err: |
| 1074 | + raise AquaRuntimeError( |
| 1075 | + f"The model {verified_model.display_name} does not contain the custom metadata {AQUA_MODEL_ARTIFACT_FILE}. " |
| 1076 | + f"Please check if the model has the valid metadata." |
| 1077 | + ) from err |
| 1078 | + else: |
| 1079 | + model_file = import_model_details.model_file |
1067 | 1080 |
|
1068 |
| - return validation_result |
| 1081 | + model_files = gguf_model_files |
| 1082 | + # todo: have a separate error validation class for different type of error messages. |
| 1083 | + if model_file: |
| 1084 | + if model_file not in model_files: |
| 1085 | + raise AquaRuntimeError( |
| 1086 | + f"The model path {import_model_details.os_path} or the Hugging Face " |
| 1087 | + f"model repository for {model_name} does not contain the file " |
| 1088 | + f"{model_file}. Please check if the path is correct or the model " |
| 1089 | + f"artifacts are available at this location." |
| 1090 | + ) |
| 1091 | + else: |
| 1092 | + validation_result.model_file = model_file |
| 1093 | + elif len(model_files) == 0: |
| 1094 | + raise AquaRuntimeError( |
| 1095 | + f"The model path {import_model_details.os_path} or the Hugging Face model " |
| 1096 | + f"repository for {model_name} does not contain any GGUF format files. " |
| 1097 | + f"Please check if the path is correct or the model artifacts are available " |
| 1098 | + f"at this location." |
| 1099 | + ) |
| 1100 | + elif len(model_files) > 1: |
| 1101 | + raise AquaRuntimeError( |
| 1102 | + f"The model path {import_model_details.os_path} or the Hugging Face model " |
| 1103 | + f"repository for {model_name} contains multiple GGUF format files. " |
| 1104 | + f"Please specify the file that needs to be deployed using the model_file " |
| 1105 | + f"parameter." |
| 1106 | + ) |
| 1107 | + else: |
| 1108 | + validation_result.model_file = model_files[0] |
| 1109 | + |
| 1110 | + if verified_model: |
| 1111 | + validation_result.telemetry_model_name = verified_model.display_name |
| 1112 | + elif import_model_details.download_from_hf: |
| 1113 | + validation_result.telemetry_model_name = model_name |
| 1114 | + else: |
| 1115 | + validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM |
1069 | 1116 |
|
1070 | 1117 | @staticmethod
|
1071 | 1118 | def _download_model_from_hf(
|
|
0 commit comments