|
1 | 1 | #!/usr/bin/env python
|
2 |
| -# Copyright (c) 2024 Oracle and/or its affiliates. |
| 2 | +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. |
3 | 3 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4 | 4 | """AQUA utils and constants."""
|
5 | 5 |
|
|
11 | 11 | import random
|
12 | 12 | import re
|
13 | 13 | import shlex
|
| 14 | +import shutil |
14 | 15 | import subprocess
|
15 | 16 | from datetime import datetime, timedelta
|
16 | 17 | from functools import wraps
|
|
21 | 22 | import fsspec
|
22 | 23 | import oci
|
23 | 24 | from cachetools import TTLCache, cached
|
| 25 | +from huggingface_hub.constants import HF_HUB_CACHE |
| 26 | +from huggingface_hub.file_download import repo_folder_name |
24 | 27 | from huggingface_hub.hf_api import HfApi, ModelInfo
|
25 | 28 | from huggingface_hub.utils import (
|
26 | 29 | GatedRepoError,
|
|
30 | 33 | )
|
31 | 34 | from oci.data_science.models import JobRun, Model
|
32 | 35 | from oci.object_storage.models import ObjectSummary
|
| 36 | +from pydantic import ValidationError |
33 | 37 |
|
34 | 38 | from ads.aqua.common.enums import (
|
35 | 39 | InferenceContainerParamType,
|
@@ -788,7 +792,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
|
788 | 792 | return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
|
789 | 793 |
|
790 | 794 |
|
791 |
| -def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str: |
| 795 | +def upload_folder( |
| 796 | + os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None |
| 797 | +) -> str: |
792 | 798 | """Upload the local folder to the object storage
|
793 | 799 |
|
794 | 800 | Args:
|
@@ -818,6 +824,48 @@ def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern
|
818 | 824 | return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
|
819 | 825 |
|
820 | 826 |
|
def cleanup_local_hf_model_artifact(
    model_name: str,
    local_dir: str = None,
) -> None:
    """
    Helper function that deletes local artifacts downloaded from Hugging Face to free up disk space.

    Two locations are cleaned up, both on a best-effort basis (deletion errors are
    ignored and only reported through debug logs):

    1. The model folder under ``local_dir``, when ``local_dir`` is provided and exists.
       For a name that contains a path separator (e.g. ``org/model``) the parent folder
       of the last path component (``<local_dir>/org``) is removed.
    2. The model's folder inside the local Hugging Face hub cache (``HF_HUB_CACHE``).

    Parameters
    ----------
    model_name (str): Name of the huggingface model
    local_dir (str): Local directory where the object is downloaded

    Returns
    -------
    None
    """
    if local_dir and os.path.exists(local_dir):
        model_dir = os.path.join(local_dir, model_name)
        if "/" in model_name or os.sep in model_name:
            # Namespaced model name: drop the last component so the enclosing
            # folder created for the namespace is removed as well.
            model_dir = os.path.dirname(model_dir)

        # Safety guard: only delete a path that is strictly inside local_dir.
        # os.path.join discards local_dir when model_name is absolute, and ".."
        # segments (or an empty name) can make the target resolve to local_dir
        # itself or to one of its ancestors; rmtree must never touch those.
        download_root = os.path.normcase(os.path.realpath(local_dir))
        resolved_target = os.path.normcase(os.path.realpath(model_dir))
        is_inside_root = resolved_target != download_root and resolved_target.startswith(
            os.path.join(download_root, "")
        )

        if not is_inside_root:
            logger.debug(
                f"Skipped deleting {model_dir} because it is not located inside {local_dir}."
            )
        else:
            shutil.rmtree(model_dir, ignore_errors=True)
            if os.path.exists(model_dir):
                logger.debug(
                    f"Could not delete local model artifact directory: {model_dir}"
                )
            else:
                logger.debug(f"Deleted local model artifact directory: {model_dir}.")

    # The hub cache keeps each repository in a single flat folder whose name is
    # derived from the repo id by repo_folder_name.
    hf_local_path = os.path.join(
        HF_HUB_CACHE, repo_folder_name(repo_id=model_name, repo_type="model")
    )
    shutil.rmtree(hf_local_path, ignore_errors=True)

    if os.path.exists(hf_local_path):
        logger.debug(
            f"Could not clear the local Hugging Face cache directory {hf_local_path} for the model {model_name}."
        )
    else:
        logger.debug(
            f"Cleared contents of local Hugging Face cache directory {hf_local_path} for the model {model_name}."
        )
| 867 | + |
| 868 | + |
def is_service_managed_container(container):
    """Check whether ``container`` starts with the service managed container URI scheme.

    A falsy ``container`` (e.g. ``None`` or an empty string) is returned unchanged;
    otherwise the result of the prefix check is returned.
    """
    if not container:
        return container
    return container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
|
823 | 871 |
|
@@ -1159,3 +1207,15 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
|
1159 | 1207 |
|
1160 | 1208 | combined_cmd_var = cmd_var + overrides
|
1161 | 1209 | return combined_cmd_var
|
| 1210 | + |
| 1211 | + |
def build_pydantic_error_message(ex: ValidationError):
    """Build an error payload from the details of a pydantic validation error.

    Every error entry that carries a non-empty ``loc`` contributes one item to a
    dict keyed by the dotted location (``"field.subfield"``) with the error ``msg``
    as value. When no entry has a location, the messages of all entries are
    returned as a single string joined with ``"; "``.
    """
    error_details = ex.errors()

    messages_by_field = {}
    for detail in error_details:
        location = detail.get("loc")
        if location:
            field_path = ".".join(str(part) for part in location)
            messages_by_field[field_path] = detail["msg"]

    if messages_by_field:
        return messages_by_field
    return "; ".join(detail["msg"] for detail in error_details)
0 commit comments