Skip to content

Commit eb785ca

Browse files
authored
Fix validation for model containing both formats. (#948)
All tests have passed: https://github.com/oracle/accelerated-data-science/actions
2 parents 9e8e5d7 + 7d966b5 commit eb785ca

File tree

1 file changed

+171
-124
lines changed

1 file changed

+171
-124
lines changed

ads/aqua/model/model.py

Lines changed: 171 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1616
from ads.aqua.app import AquaApp
17-
from ads.aqua.common.enums import Tags
17+
from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags
1818
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1919
from ads.aqua.common.utils import (
2020
LifecycleStatus,
@@ -933,139 +933,186 @@ def _validate_model(
933933
# Now that we know at least one type of model file exists, validate the content of the OSS path.
934934
# for safetensors, we check if config.json files exist, and for gguf format we check if files with
935935
# gguf extension exist.
936-
for model_format in model_formats:
936+
if {ModelFormat.SAFETENSORS, ModelFormat.GGUF}.issubset(set(model_formats)):
937937
if (
938-
model_format == ModelFormat.SAFETENSORS
939-
and len(safetensors_model_files) > 0
938+
import_model_details.inference_container.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
940939
):
941-
if import_model_details.download_from_hf:
942-
# validates config.json exists for safetensors model from Hugging Face
943-
if not hf_download_config_present:
944-
raise AquaRuntimeError(
945-
f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
946-
f"by {ModelFormat.SAFETENSORS.value} format model."
947-
f" Please check if the model name is correct in Hugging Face repository."
948-
)
949-
else:
950-
try:
951-
model_config = load_config(
952-
file_path=import_model_details.os_path,
953-
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
954-
)
955-
except Exception as ex:
956-
logger.error(
957-
f"Exception occurred while loading config file from {import_model_details.os_path}"
958-
f"Exception message: {ex}"
959-
)
960-
raise AquaRuntimeError(
961-
f"The model path {import_model_details.os_path} does not contain the file config.json. "
962-
f"Please check if the path is correct or the model artifacts are available at this location."
963-
) from ex
964-
else:
965-
try:
966-
metadata_model_type = (
967-
verified_model.custom_metadata_list.get(
968-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
969-
).value
970-
)
971-
if metadata_model_type:
972-
if (
973-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
974-
in model_config
975-
):
976-
if (
977-
model_config[
978-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
979-
]
980-
!= metadata_model_type
981-
):
982-
raise AquaRuntimeError(
983-
f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
984-
f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
985-
f"the model {model_name}. Please check if the path is correct or "
986-
f"the correct model artifacts are available at this location."
987-
f""
988-
)
989-
else:
990-
logger.debug(
991-
f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
992-
f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
993-
)
994-
except Exception:
995-
pass
996-
if verified_model:
997-
validation_result.telemetry_model_name = (
998-
verified_model.display_name
999-
)
1000-
elif (
1001-
model_config is not None
1002-
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
1003-
):
1004-
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
1005-
elif (
1006-
model_config is not None
1007-
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
940+
self._validate_gguf_format(
941+
import_model_details=import_model_details,
942+
verified_model=verified_model,
943+
gguf_model_files=gguf_model_files,
944+
validation_result=validation_result,
945+
model_name=model_name
946+
)
947+
else:
948+
self._validate_safetensor_format(
949+
import_model_details=import_model_details,
950+
verified_model=verified_model,
951+
validation_result=validation_result,
952+
hf_download_config_present=hf_download_config_present,
953+
model_name=model_name
954+
)
955+
elif ModelFormat.SAFETENSORS in model_formats:
956+
self._validate_safetensor_format(
957+
import_model_details=import_model_details,
958+
verified_model=verified_model,
959+
validation_result=validation_result,
960+
hf_download_config_present=hf_download_config_present,
961+
model_name=model_name
962+
)
963+
elif ModelFormat.GGUF in model_formats:
964+
self._validate_gguf_format(
965+
import_model_details=import_model_details,
966+
verified_model=verified_model,
967+
gguf_model_files=gguf_model_files,
968+
validation_result=validation_result,
969+
model_name=model_name
970+
)
971+
972+
return validation_result
973+
974+
@staticmethod
975+
def _validate_safetensor_format(
976+
import_model_details: ImportModelDetails = None,
977+
verified_model: DataScienceModel = None,
978+
validation_result: ModelValidationResult = None,
979+
hf_download_config_present: bool = None,
980+
model_name: str = None
981+
):
982+
if import_model_details.download_from_hf:
983+
# validates config.json exists for safetensors model from Hugging Face
984+
if not hf_download_config_present:
985+
raise AquaRuntimeError(
986+
f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
987+
f"by {ModelFormat.SAFETENSORS.value} format model."
988+
f" Please check if the model name is correct in Hugging Face repository."
989+
)
990+
else:
991+
try:
992+
model_config = load_config(
993+
file_path=import_model_details.os_path,
994+
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
995+
)
996+
except Exception as ex:
997+
logger.error(
998+
f"Exception occurred while loading config file from {import_model_details.os_path}"
999+
f"Exception message: {ex}"
1000+
)
1001+
raise AquaRuntimeError(
1002+
f"The model path {import_model_details.os_path} does not contain the file config.json. "
1003+
f"Please check if the path is correct or the model artifacts are available at this location."
1004+
) from ex
1005+
else:
1006+
try:
1007+
metadata_model_type = (
1008+
verified_model.custom_metadata_list.get(
1009+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1010+
).value
1011+
)
1012+
if metadata_model_type:
1013+
if (
1014+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1015+
in model_config
10081016
):
1009-
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
1017+
if (
1018+
model_config[
1019+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1020+
]
1021+
!= metadata_model_type
1022+
):
1023+
raise AquaRuntimeError(
1024+
f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
1025+
f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
1026+
f"the model {model_name}. Please check if the path is correct or "
1027+
f"the correct model artifacts are available at this location."
1028+
f""
1029+
)
10101030
else:
1011-
validation_result.telemetry_model_name = (
1012-
AQUA_MODEL_TYPE_CUSTOM
1031+
logger.debug(
1032+
f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
1033+
f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
10131034
)
1014-
elif model_format == ModelFormat.GGUF and len(gguf_model_files) > 0:
1015-
if import_model_details.finetuning_container and not safetensors_model_files:
1016-
raise AquaValueError(
1017-
"Fine-tuning is currently not supported with GGUF model format."
1018-
)
1035+
except Exception:
1036+
pass
10191037
if verified_model:
1020-
try:
1021-
model_file = verified_model.custom_metadata_list.get(
1022-
AQUA_MODEL_ARTIFACT_FILE
1023-
).value
1024-
except ValueError as err:
1025-
raise AquaRuntimeError(
1026-
f"The model {verified_model.display_name} does not contain the custom metadata {AQUA_MODEL_ARTIFACT_FILE}. "
1027-
f"Please check if the model has the valid metadata."
1028-
) from err
1029-
else:
1030-
model_file = import_model_details.model_file
1031-
1032-
model_files = gguf_model_files
1033-
# todo: have a separate error validation class for different types of error messages.
1034-
if model_file:
1035-
if model_file not in model_files:
1036-
raise AquaRuntimeError(
1037-
f"The model path {import_model_details.os_path} or the Hugging Face "
1038-
f"model repository for {model_name} does not contain the file "
1039-
f"{model_file}. Please check if the path is correct or the model "
1040-
f"artifacts are available at this location."
1041-
)
1042-
else:
1043-
validation_result.model_file = model_file
1044-
elif len(model_files) == 0:
1045-
raise AquaRuntimeError(
1046-
f"The model path {import_model_details.os_path} or the Hugging Face model "
1047-
f"repository for {model_name} does not contain any GGUF format files. "
1048-
f"Please check if the path is correct or the model artifacts are available "
1049-
f"at this location."
1050-
)
1051-
elif len(model_files) > 1:
1052-
raise AquaRuntimeError(
1053-
f"The model path {import_model_details.os_path} or the Hugging Face model "
1054-
f"repository for {model_name} contains multiple GGUF format files. "
1055-
f"Please specify the file that needs to be deployed using the model_file "
1056-
f"parameter."
1038+
validation_result.telemetry_model_name = (
1039+
verified_model.display_name
10571040
)
1041+
elif (
1042+
model_config is not None
1043+
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
1044+
):
1045+
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
1046+
elif (
1047+
model_config is not None
1048+
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
1049+
):
1050+
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
10581051
else:
1059-
validation_result.model_file = model_files[0]
1052+
validation_result.telemetry_model_name = (
1053+
AQUA_MODEL_TYPE_CUSTOM
1054+
)
10601055

1061-
if verified_model:
1062-
validation_result.telemetry_model_name = verified_model.display_name
1063-
elif import_model_details.download_from_hf:
1064-
validation_result.telemetry_model_name = model_name
1065-
else:
1066-
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
1056+
@staticmethod
1057+
def _validate_gguf_format(
1058+
import_model_details: ImportModelDetails = None,
1059+
verified_model: DataScienceModel = None,
1060+
gguf_model_files: List[str] = None,
1061+
validation_result: ModelValidationResult = None,
1062+
model_name: str = None,
1063+
):
1064+
if import_model_details.finetuning_container:
1065+
raise AquaValueError(
1066+
"Fine-tuning is currently not supported with GGUF model format."
1067+
)
1068+
if verified_model:
1069+
try:
1070+
model_file = verified_model.custom_metadata_list.get(
1071+
AQUA_MODEL_ARTIFACT_FILE
1072+
).value
1073+
except ValueError as err:
1074+
raise AquaRuntimeError(
1075+
f"The model {verified_model.display_name} does not contain the custom metadata {AQUA_MODEL_ARTIFACT_FILE}. "
1076+
f"Please check if the model has the valid metadata."
1077+
) from err
1078+
else:
1079+
model_file = import_model_details.model_file
10671080

1068-
return validation_result
1081+
model_files = gguf_model_files
1082+
# todo: have a separate error validation class for different types of error messages.
1083+
if model_file:
1084+
if model_file not in model_files:
1085+
raise AquaRuntimeError(
1086+
f"The model path {import_model_details.os_path} or the Hugging Face "
1087+
f"model repository for {model_name} does not contain the file "
1088+
f"{model_file}. Please check if the path is correct or the model "
1089+
f"artifacts are available at this location."
1090+
)
1091+
else:
1092+
validation_result.model_file = model_file
1093+
elif len(model_files) == 0:
1094+
raise AquaRuntimeError(
1095+
f"The model path {import_model_details.os_path} or the Hugging Face model "
1096+
f"repository for {model_name} does not contain any GGUF format files. "
1097+
f"Please check if the path is correct or the model artifacts are available "
1098+
f"at this location."
1099+
)
1100+
elif len(model_files) > 1:
1101+
raise AquaRuntimeError(
1102+
f"The model path {import_model_details.os_path} or the Hugging Face model "
1103+
f"repository for {model_name} contains multiple GGUF format files. "
1104+
f"Please specify the file that needs to be deployed using the model_file "
1105+
f"parameter."
1106+
)
1107+
else:
1108+
validation_result.model_file = model_files[0]
1109+
1110+
if verified_model:
1111+
validation_result.telemetry_model_name = verified_model.display_name
1112+
elif import_model_details.download_from_hf:
1113+
validation_result.telemetry_model_name = model_name
1114+
else:
1115+
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
10691116

10701117
@staticmethod
10711118
def _download_model_from_hf(

0 commit comments

Comments
 (0)