|
10 | 10 | import os
|
11 | 11 | import random
|
12 | 12 | import re
|
| 13 | +import shlex |
| 14 | +import subprocess |
13 | 15 | from datetime import datetime, timedelta
|
14 | 16 | from functools import wraps
|
15 | 17 | from pathlib import Path
|
|
19 | 21 | import fsspec
|
20 | 22 | import oci
|
21 | 23 | from cachetools import TTLCache, cached
|
| 24 | +from huggingface_hub.hf_api import HfApi, ModelInfo |
| 25 | +from huggingface_hub.utils import ( |
| 26 | + GatedRepoError, |
| 27 | + HfHubHTTPError, |
| 28 | + RepositoryNotFoundError, |
| 29 | + RevisionNotFoundError, |
| 30 | +) |
22 | 31 | from oci.data_science.models import JobRun, Model
|
23 | 32 | from oci.object_storage.models import ObjectSummary
|
24 | 33 |
|
|
37 | 46 | COMPARTMENT_MAPPING_KEY,
|
38 | 47 | CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
|
39 | 48 | CONTAINER_INDEX,
|
| 49 | + HF_LOGIN_DEFAULT_TIMEOUT, |
40 | 50 | MAXIMUM_ALLOWED_DATASET_IN_BYTE,
|
41 | 51 | MODEL_BY_REFERENCE_OSS_PATH_KEY,
|
42 | 52 | SERVICE_MANAGED_CONTAINER_URI_SCHEME,
|
|
47 | 57 | VLLM_INFERENCE_RESTRICTED_PARAMS,
|
48 | 58 | )
|
49 | 59 | from ads.aqua.data import AquaResourceIdentifier
|
50 |
| -from ads.common.auth import default_signer |
| 60 | +from ads.common.auth import AuthState, default_signer |
51 | 61 | from ads.common.extended_enum import ExtendedEnumMeta
|
52 | 62 | from ads.common.object_storage_details import ObjectStorageDetails
|
53 | 63 | from ads.common.oci_resource import SEARCH_TYPE, OCIResource
|
@@ -771,6 +781,33 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
|
771 | 781 | return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
|
772 | 782 |
|
773 | 783 |
|
| 784 | +def upload_folder(os_path: str, local_dir: str, model_name: str) -> str: |
| 785 | + """Upload the local folder to the object storage |
| 786 | +
|
| 787 | + Args: |
| 788 | + os_path (str): object storage URI with prefix. This is the path to upload |
| 789 | + local_dir (str): Local directory where the object is downloaded |
| 790 | + model_name (str): Name of the huggingface model |
| 791 | + Retuns: |
| 792 | + str: Object name inside the bucket |
| 793 | + """ |
| 794 | + os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path) |
| 795 | + if not os_details.is_bucket_versioned(): |
| 796 | + raise ValueError(f"Version is not enabled at object storage location {os_path}") |
| 797 | + auth_state = AuthState() |
| 798 | + object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/" |
| 799 | + command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite" |
| 800 | + try: |
| 801 | + logger.info(f"Running: {command}") |
| 802 | + subprocess.check_call(shlex.split(command)) |
| 803 | + except subprocess.CalledProcessError as e: |
| 804 | + logger.error( |
| 805 | + f"Error uploading the object. Exit code: {e.returncode} with error {e.stdout}" |
| 806 | + ) |
| 807 | + |
| 808 | + return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path |
| 809 | + |
| 810 | + |
774 | 811 | def is_service_managed_container(container):
|
775 | 812 | return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
|
776 | 813 |
|
@@ -935,3 +972,93 @@ def get_restricted_params_by_container(container_type_name: str) -> set:
|
935 | 972 | return TGI_INFERENCE_RESTRICTED_PARAMS
|
936 | 973 | else:
|
937 | 974 | return set()
|
| 975 | + |
| 976 | + |
| 977 | +def get_huggingface_login_timeout() -> int: |
| 978 | + """This helper function returns the huggingface login timeout, returns default if not set via |
| 979 | + env var. |
| 980 | + Returns |
| 981 | + ------- |
| 982 | + timeout: int |
| 983 | + huggingface login timeout. |
| 984 | +
|
| 985 | + """ |
| 986 | + timeout = HF_LOGIN_DEFAULT_TIMEOUT |
| 987 | + try: |
| 988 | + timeout = int( |
| 989 | + os.environ.get("HF_LOGIN_DEFAULT_TIMEOUT", HF_LOGIN_DEFAULT_TIMEOUT) |
| 990 | + ) |
| 991 | + except ValueError: |
| 992 | + pass |
| 993 | + return timeout |
| 994 | + |
| 995 | + |
| 996 | +def format_hf_custom_error_message(error: HfHubHTTPError): |
| 997 | + """ |
| 998 | + Formats a custom error message based on the Hugging Face error response. |
| 999 | +
|
| 1000 | + Parameters |
| 1001 | + ---------- |
| 1002 | + error (HfHubHTTPError): The caught exception. |
| 1003 | +
|
| 1004 | + Raises |
| 1005 | + ------ |
| 1006 | + AquaRuntimeError: A user-friendly error message. |
| 1007 | + """ |
| 1008 | + # Extract the repository URL from the error message if present |
| 1009 | + match = re.search(r"(https://huggingface.co/[^\s]+)", str(error)) |
| 1010 | + url = match.group(1) if match else "the requested Hugging Face URL." |
| 1011 | + |
| 1012 | + if isinstance(error, RepositoryNotFoundError): |
| 1013 | + raise AquaRuntimeError( |
| 1014 | + reason=f"Failed to access `{url}`. Please check if the provided repository name is correct. " |
| 1015 | + "If the repo is private, make sure you are authenticated and have a valid HF token registered. " |
| 1016 | + "To register your token, run this command in your terminal: `huggingface-cli login`", |
| 1017 | + service_payload={"error": "RepositoryNotFoundError"}, |
| 1018 | + ) |
| 1019 | + |
| 1020 | + if isinstance(error, GatedRepoError): |
| 1021 | + raise AquaRuntimeError( |
| 1022 | + reason=f"Access denied to `{url}` " |
| 1023 | + "This repository is gated. Access is restricted to authorized users. " |
| 1024 | + "Please request access or check with the repository administrator. " |
| 1025 | + "If you are trying to access a gated repository, ensure you have a valid HF token registered. " |
| 1026 | + "To register your token, run this command in your terminal: `huggingface-cli login`", |
| 1027 | + service_payload={"error": "GatedRepoError"}, |
| 1028 | + ) |
| 1029 | + |
| 1030 | + if isinstance(error, RevisionNotFoundError): |
| 1031 | + raise AquaRuntimeError( |
| 1032 | + reason=f"The specified revision could not be found at `{url}` " |
| 1033 | + "Please check the revision identifier and try again.", |
| 1034 | + service_payload={"error": "RevisionNotFoundError"}, |
| 1035 | + ) |
| 1036 | + |
| 1037 | + raise AquaRuntimeError( |
| 1038 | + reason=f"An error occurred while accessing `{url}` " |
| 1039 | + "Please check your network connection and try again. " |
| 1040 | + "If you are trying to access a gated repository, ensure you have a valid HF token registered. " |
| 1041 | + "To register your token, run this command in your terminal: `huggingface-cli login`", |
| 1042 | + service_payload={"error": "Error"}, |
| 1043 | + ) |
| 1044 | + |
| 1045 | + |
| 1046 | +@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now)) |
| 1047 | +def get_hf_model_info(repo_id: str) -> ModelInfo: |
| 1048 | + """Gets the model information object for the given model repository name. For models that requires a token, |
| 1049 | + this method assumes that the token validation is already done. |
| 1050 | +
|
| 1051 | + Parameters |
| 1052 | + ---------- |
| 1053 | + repo_id: str |
| 1054 | + hugging face model repository name |
| 1055 | +
|
| 1056 | + Returns |
| 1057 | + ------- |
| 1058 | + instance of ModelInfo object |
| 1059 | +
|
| 1060 | + """ |
| 1061 | + try: |
| 1062 | + return HfApi().model_info(repo_id=repo_id) |
| 1063 | + except HfHubHTTPError as err: |
| 1064 | + raise format_hf_custom_error_message(err) from err |
0 commit comments