Skip to content

Commit 6f53733

Browse files
mayoor, VipulMascarenhas, kumar-shivam-ranjan, lu-ohai, harsh97
authored
ADS 2.11.17 Release (#928)
Co-authored-by: Vipul <[email protected]> Co-authored-by: Kumar Shivam Ranjan <[email protected]> Co-authored-by: Lu Peng <[email protected]> Co-authored-by: Lu Peng <[email protected]> Co-authored-by: Harsh Rai <[email protected]>
1 parent dc168a5 commit 6f53733

File tree

16 files changed

+1307
-292
lines changed

16 files changed

+1307
-292
lines changed

ads/aqua/app.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -175,7 +174,7 @@ def create_model_version_set(
175174
f"Invalid model version set name. Please provide a model version set with `{tag}` in tags."
176175
)
177176

178-
except:
177+
except Exception:
179178
logger.debug(
180179
f"Model version set {model_version_set_name} doesn't exist. "
181180
"Creating new model version set."
@@ -254,7 +253,7 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
254253

255254
try:
256255
response = self.ds_client.head_model_artifact(model_id=model_id, **kwargs)
257-
return True if response.status == 200 else False
256+
return response.status == 200
258257
except oci.exceptions.ServiceError as ex:
259258
if ex.status == 404:
260259
logger.info(f"Artifact not found in model {model_id}.")
@@ -302,15 +301,15 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
302301
config_path,
303302
config_file_name=config_file_name,
304303
)
305-
except:
304+
except Exception:
306305
# todo: temp fix for issue related to config load for byom models, update logic to choose the right path
307306
try:
308307
config_path = f"{artifact_path.rstrip('/')}/config/"
309308
config = load_config(
310309
config_path,
311310
config_file_name=config_file_name,
312311
)
313-
except:
312+
except Exception:
314313
pass
315314

316315
if not config:
@@ -343,7 +342,7 @@ def build_cli(self) -> str:
343342
params = [
344343
f"--{field.name} {getattr(self,field.name)}"
345344
for field in fields(self.__class__)
346-
if getattr(self, field.name)
345+
if getattr(self, field.name) is not None
347346
]
348347
cmd = f"{cmd} {' '.join(params)}"
349348
return cmd

ads/aqua/common/enums.py

+9
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class Tags(str, metaclass=ExtendedEnumMeta):
3939
BASE_MODEL_CUSTOM = "aqua_custom_base_model"
4040
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
4141
MODEL_FORMAT = "model_format"
42+
MODEL_ARTIFACT_FILE = "model_file"
4243

4344

4445
class InferenceContainerType(str, metaclass=ExtendedEnumMeta):
@@ -59,6 +60,14 @@ class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
5960
PARAM_TYPE_LLAMA_CPP = "LLAMA_CPP_PARAMS"
6061

6162

63+
class EvaluationContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
64+
AQUA_EVALUATION_CONTAINER_FAMILY = "odsc-llm-evaluate"
65+
66+
67+
class FineTuningContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
68+
AQUA_FINETUNING_CONTAINER_FAMILY = "odsc-llm-fine-tuning"
69+
70+
6271
class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):
6372
TEXT_GENERATION_INFERENCE = "text-generation-inference"
6473

ads/aqua/common/utils.py

+128-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import os
1111
import random
1212
import re
13+
import shlex
14+
import subprocess
1315
from datetime import datetime, timedelta
1416
from functools import wraps
1517
from pathlib import Path
@@ -19,6 +21,13 @@
1921
import fsspec
2022
import oci
2123
from cachetools import TTLCache, cached
24+
from huggingface_hub.hf_api import HfApi, ModelInfo
25+
from huggingface_hub.utils import (
26+
GatedRepoError,
27+
HfHubHTTPError,
28+
RepositoryNotFoundError,
29+
RevisionNotFoundError,
30+
)
2231
from oci.data_science.models import JobRun, Model
2332
from oci.object_storage.models import ObjectSummary
2433

@@ -37,6 +46,7 @@
3746
COMPARTMENT_MAPPING_KEY,
3847
CONSOLE_LINK_RESOURCE_TYPE_MAPPING,
3948
CONTAINER_INDEX,
49+
HF_LOGIN_DEFAULT_TIMEOUT,
4050
MAXIMUM_ALLOWED_DATASET_IN_BYTE,
4151
MODEL_BY_REFERENCE_OSS_PATH_KEY,
4252
SERVICE_MANAGED_CONTAINER_URI_SCHEME,
@@ -47,7 +57,7 @@
4757
VLLM_INFERENCE_RESTRICTED_PARAMS,
4858
)
4959
from ads.aqua.data import AquaResourceIdentifier
50-
from ads.common.auth import default_signer
60+
from ads.common.auth import AuthState, default_signer
5161
from ads.common.extended_enum import ExtendedEnumMeta
5262
from ads.common.object_storage_details import ObjectStorageDetails
5363
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
@@ -771,6 +781,33 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
771781
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
772782

773783

784+
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
    """Upload the local folder to the object storage using the OCI CLI.

    Args:
        os_path (str): object storage URI with prefix. This is the path to upload
        local_dir (str): Local directory where the object is downloaded
        model_name (str): Name of the huggingface model
    Returns:
        str: ``oci://`` URI of the prefix the folder was uploaded under
    Raises:
        ValueError: If versioning is not enabled on the target bucket.
    """
    os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path)
    if not os_details.is_bucket_versioned():
        raise ValueError(f"Version is not enabled at object storage location {os_path}")
    auth_state = AuthState()
    object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
    # Build the argument vector directly instead of formatting one string and
    # re-tokenising it: a directory, prefix or model name containing whitespace
    # or quotes would otherwise be split into the wrong CLI arguments.
    command = [
        "oci",
        "os",
        "object",
        "bulk-upload",
        "--src-dir",
        str(local_dir),
        "--prefix",
        object_path,
        "-bn",
        str(os_details.bucket),
        "-ns",
        str(os_details.namespace),
        "--auth",
        str(auth_state.oci_iam_type),
        "--profile",
        str(auth_state.oci_key_profile),
        "--no-overwrite",
    ]
    try:
        logger.info(f"Running: {' '.join(command)}")
        subprocess.check_call(command)
    except subprocess.CalledProcessError as e:
        # NOTE(review): a failed upload is only logged and the target URI is
        # still returned, as before. check_call does not capture output, so
        # e.stdout is expected to be None here - confirm whether callers
        # should see an exception instead.
        logger.error(
            f"Error uploading the object. Exit code: {e.returncode} with error {e.stdout}"
        )

    return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
809+
810+
774811
def is_service_managed_container(container):
775812
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
776813

@@ -935,3 +972,93 @@ def get_restricted_params_by_container(container_type_name: str) -> set:
935972
return TGI_INFERENCE_RESTRICTED_PARAMS
936973
else:
937974
return set()
975+
976+
977+
def get_huggingface_login_timeout() -> int:
    """Return the Hugging Face login timeout.

    The value is read from the ``HF_LOGIN_DEFAULT_TIMEOUT`` environment
    variable. When the variable is unset, or its value cannot be parsed as an
    integer, the module-level ``HF_LOGIN_DEFAULT_TIMEOUT`` constant is used.

    Returns
    -------
    timeout: int
        huggingface login timeout.

    """
    configured = os.environ.get("HF_LOGIN_DEFAULT_TIMEOUT", HF_LOGIN_DEFAULT_TIMEOUT)
    try:
        return int(configured)
    except ValueError:
        # A non-numeric override is ignored in favour of the default.
        return HF_LOGIN_DEFAULT_TIMEOUT
994+
995+
996+
def format_hf_custom_error_message(error: HfHubHTTPError):
    """
    Formats a custom error message based on the Hugging Face error response.

    This function never returns normally: every path ends in a raise.

    Parameters
    ----------
    error (HfHubHTTPError): The caught exception.

    Raises
    ------
    AquaRuntimeError: A user-friendly error message.
    """
    # Extract the repository URL from the error message if present
    match = re.search(r"(https://huggingface.co/[^\s]+)", str(error))
    url = match.group(1) if match else "the requested Hugging Face URL."

    # GatedRepoError derives from RepositoryNotFoundError in huggingface_hub,
    # so it must be tested first; otherwise the gated-repo message below can
    # never be produced.
    if isinstance(error, GatedRepoError):
        raise AquaRuntimeError(
            reason=f"Access denied to `{url}` "
            "This repository is gated. Access is restricted to authorized users. "
            "Please request access or check with the repository administrator. "
            "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
            "To register your token, run this command in your terminal: `huggingface-cli login`",
            service_payload={"error": "GatedRepoError"},
        )

    if isinstance(error, RepositoryNotFoundError):
        raise AquaRuntimeError(
            reason=f"Failed to access `{url}`. Please check if the provided repository name is correct. "
            "If the repo is private, make sure you are authenticated and have a valid HF token registered. "
            "To register your token, run this command in your terminal: `huggingface-cli login`",
            service_payload={"error": "RepositoryNotFoundError"},
        )

    if isinstance(error, RevisionNotFoundError):
        raise AquaRuntimeError(
            reason=f"The specified revision could not be found at `{url}` "
            "Please check the revision identifier and try again.",
            service_payload={"error": "RevisionNotFoundError"},
        )

    # Any other Hub HTTP error gets the generic message.
    raise AquaRuntimeError(
        reason=f"An error occurred while accessing `{url}` "
        "Please check your network connection and try again. "
        "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
        "To register your token, run this command in your terminal: `huggingface-cli login`",
        service_payload={"error": "Error"},
    )
1044+
1045+
1046+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
def get_hf_model_info(repo_id: str) -> ModelInfo:
    """Gets the model information object for the given model repository name. For models that requires a token,
    this method assumes that the token validation is already done.

    The result is cached in a single-entry TTL cache with a five hour expiry.

    Parameters
    ----------
    repo_id: str
        hugging face model repository name

    Returns
    -------
    instance of ModelInfo object

    Raises
    ------
    AquaRuntimeError
        Raised by ``format_hf_custom_error_message`` when the Hub request
        fails with an ``HfHubHTTPError``.

    """
    try:
        return HfApi().model_info(repo_id=repo_id)
    except HfHubHTTPError as err:
        # The helper always raises AquaRuntimeError itself, so it is called
        # directly rather than used as the operand of `raise ... from`.
        # Because it raises inside this handler, the original error stays
        # attached as the exception context.
        format_hf_custom_error_message(err)

ads/aqua/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME = "_name_or_path"
3535
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE = "model_type"
3636
AQUA_MODEL_ARTIFACT_FILE = "model_file"
37+
HF_LOGIN_DEFAULT_TIMEOUT = 2
3738

3839
TRAINING_METRICS_FINAL = "training_metrics_final"
3940
VALIDATION_METRICS_FINAL = "validation_metrics_final"

ads/aqua/evaluation/evaluation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def create(
191191
enable_spec=True
192192
).inference
193193
for container in inference_config.values():
194-
if container.name == runtime.image.split(":")[0]:
194+
if container.name == runtime.image[:runtime.image.rfind(":")]:
195195
eval_inference_configuration = (
196196
container.spec.evaluation_configuration
197197
)

ads/aqua/extension/common_handler.py

+75-5
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65

76
from importlib import metadata
87

8+
import huggingface_hub
99
import requests
10+
from huggingface_hub import HfApi
11+
from huggingface_hub.utils import LocalTokenNotFoundError
1012
from tornado.web import HTTPError
1113

1214
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
1315
from ads.aqua.common.decorator import handle_exceptions
1416
from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError
15-
from ads.aqua.common.utils import fetch_service_compartment, known_realm
17+
from ads.aqua.common.utils import (
18+
fetch_service_compartment,
19+
get_huggingface_login_timeout,
20+
known_realm,
21+
)
1622
from ads.aqua.extension.base_handler import AquaAPIhandler
1723
from ads.aqua.extension.errors import Errors
1824

@@ -46,16 +52,80 @@ def get(self):
4652
4753
"""
4854
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
49-
return self.finish(dict(status="ok"))
55+
return self.finish({"status": "ok"})
5056
elif known_realm():
51-
return self.finish(dict(status="compatible"))
57+
return self.finish({"status": "compatible"})
5258
else:
5359
raise AquaResourceAccessError(
54-
f"The AI Quick actions extension is not compatible in the given region."
60+
"The AI Quick actions extension is not compatible in the given region."
5561
)
5662

5763

64+
class NetworkStatusHandler(AquaAPIhandler):
    """Handler to check internet connection."""

    @handle_exceptions
    def get(self):
        """Issue a GET request to the probe URL and report success.

        The request uses the timeout returned by
        ``get_huggingface_login_timeout``. Any exception raised by the request
        propagates to the ``handle_exceptions`` decorator.
        """
        probe_timeout = get_huggingface_login_timeout()
        requests.get("https://huggingface.com", timeout=probe_timeout)
        result = {"status": 200, "message": "success"}
        return self.finish(result)
71+
72+
73+
class HFLoginHandler(AquaAPIhandler):
    """Handler to login to HF."""

    @handle_exceptions
    def post(self, *args, **kwargs):
        """Handles post request for the HF login.

        Reads the ``token`` field from the JSON request body and passes it to
        ``huggingface_hub.login``.

        Raises
        ------
        HTTPError
            Raises HTTPError if inputs are missing or are invalid.
        AquaRuntimeError
            Raised when ``huggingface_hub.login`` fails; the payload carries
            the name of the underlying exception type.
        """
        try:
            body = self.get_json_body()
        except Exception as ex:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex

        if not body:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        hf_token = body.get("token")
        if not hf_token:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("token"))

        # Login to HF
        try:
            huggingface_hub.login(token=hf_token, new_session=False)
        except Exception as login_error:
            failure_name = type(login_error).__name__
            raise AquaRuntimeError(
                reason=str(login_error), service_payload={"error": failure_name}
            ) from login_error

        outcome = {"status": 200, "message": "login successful"}
        return self.finish(outcome)
107+
108+
109+
class HFUserStatusHandler(AquaAPIhandler):
    """Handler to check if user logged in to the HF."""

    @handle_exceptions
    def get(self):
        """Report whether a Hugging Face token is available locally.

        Calls ``HfApi().whoami()``.

        Raises
        ------
        AquaRuntimeError
            Raised when no local token is found (``LocalTokenNotFoundError``).
        """
        try:
            HfApi().whoami()
        except LocalTokenNotFoundError as err:
            # The first literal ends with a space so the two sentences do not
            # run together when the adjacent literals are concatenated.
            raise AquaRuntimeError(
                "You are not logged in. Please log in to Hugging Face using the `huggingface-cli login` command. "
                "See https://huggingface.co/settings/tokens.",
            ) from err

        return self.finish({"status": 200, "message": "logged in"})
123+
124+
58125
__handlers__ = [
59126
("ads_version", ADSVersionHandler),
60127
("hello", CompatibilityCheckHandler),
128+
("network_status", NetworkStatusHandler),
129+
("hf_login", HFLoginHandler),
130+
("hf_logged_in", HFUserStatusHandler),
61131
]

ads/aqua/extension/deployment_handler.py

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def post(self, *args, **kwargs):
101101
container_family = input_data.get("container_family")
102102
ocpus = input_data.get("ocpus")
103103
memory_in_gbs = input_data.get("memory_in_gbs")
104+
model_file = input_data.get("model_file")
104105

105106
self.finish(
106107
AquaDeploymentApp().create(
@@ -122,6 +123,7 @@ def post(self, *args, **kwargs):
122123
container_family=container_family,
123124
ocpus=ocpus,
124125
memory_in_gbs=memory_in_gbs,
126+
model_file=model_file,
125127
)
126128
)
127129

0 commit comments

Comments
 (0)