Skip to content

Commit 7d966b5

Browse files
committed
Merge branch 'main' of https://github.com/oracle/accelerated-data-science into ODSC-61938/fix_validation_for_gguf_safetensors
2 parents dc7b47d + 9e8e5d7 commit 7d966b5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2932
-1146
lines changed

.github/workflows/run-unittests-py38-cov-report.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ jobs:
115115
cov-${{ matrix.name }}/htmlcov/
116116
cov-${{ matrix.name }}/.coverage
117117
cov-${{ matrix.name }}/coverage.xml
118+
include-hidden-files: true
118119

119120
coverage-report:
120121
name: "Coverage report"

THIRD_PARTY_LICENSES.txt

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ fastavro
7272
* Source code: https://github.com/fastavro/fastavro
7373
* Project home: https://github.com/fastavro/fastavro
7474

75+
fiona
76+
* Copyright (c) 2007, Sean C. Gillies
77+
* License: BSD 3-Clause "New" or "Revised" License
78+
* Source code: https://github.com/Toblerity/Fiona
79+
* Project home: https://github.com/Toblerity/Fiona
80+
7581
folium
7682
* Copyright (C) 2013, Rob Story
7783
* License: MIT License
@@ -151,6 +157,18 @@ langchain
151157
* Source code: https://github.com/langchain-ai/langchain
152158
* Project home: https://www.langchain.com/
153159

160+
langchain-community
161+
* Copyright (c) 2023 LangChain, Inc.
162+
* License: MIT license
163+
* Source code: https://github.com/langchain-ai/langchain/tree/master/libs/community
164+
* Project home: https://github.com/langchain-ai/langchain/tree/master/libs/community
165+
166+
langchain-openai
167+
* Copyright (c) 2023 LangChain, Inc.
168+
* License: MIT license
169+
* Source code: https://github.com/langchain-ai/langchain/tree/master/libs/partners/openai
170+
* Project home: https://github.com/langchain-ai/langchain/tree/master/libs/partners/openai
171+
154172
lightgbm
155173
* Copyright (c) 2023 Microsoft Corporation
156174
* License: MIT license
@@ -459,7 +477,13 @@ pydantic
459477
* Source code: https://github.com/pydantic/pydantic
460478
* Project home: https://docs.pydantic.dev/latest/
461479

462-
=======
480+
rrcf
481+
* Copyright 2018 kLabUM
482+
* License: MIT License
483+
* Source code: https://github.com/kLabUM/rrcf
484+
* Project home: https://github.com/kLabUM/rrcf
485+
486+
463487
=============================== Licenses ===============================
464488
------------------------------------------------------------------------
465489

ads/aqua/common/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
)
5959
from ads.aqua.data import AquaResourceIdentifier
6060
from ads.common.auth import AuthState, default_signer
61+
from ads.common.decorator.threaded import threaded
6162
from ads.common.extended_enum import ExtendedEnumMeta
6263
from ads.common.object_storage_details import ObjectStorageDetails
6364
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
@@ -225,6 +226,7 @@ def read_file(file_path: str, **kwargs) -> str:
225226
return UNKNOWN
226227

227228

229+
@threaded()
228230
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
229231
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
230232
signer = default_signer() if artifact_path.startswith("oci://") else {}
@@ -1065,11 +1067,15 @@ def get_hf_model_info(repo_id: str) -> ModelInfo:
10651067

10661068

10671069
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
1068-
def list_hf_models(query:str) -> List[str]:
1070+
def list_hf_models(query: str) -> List[str]:
10691071
try:
1070-
models= HfApi().list_models(model_name=query,task="text-generation",sort="downloads",direction=-1,limit=20)
1072+
models = HfApi().list_models(
1073+
model_name=query,
1074+
task="text-generation",
1075+
sort="downloads",
1076+
direction=-1,
1077+
limit=20,
1078+
)
10711079
return [model.id for model in models if model.disabled is None]
10721080
except HfHubHTTPError as err:
10731081
raise format_hf_custom_error_message(err) from err
1074-
1075-

ads/aqua/extension/common_handler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,15 @@
1111
from huggingface_hub.utils import LocalTokenNotFoundError
1212
from tornado.web import HTTPError
1313

14-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
1514
from ads.aqua.common.decorator import handle_exceptions
1615
from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError
1716
from ads.aqua.common.utils import (
18-
fetch_service_compartment,
1917
get_huggingface_login_timeout,
2018
known_realm,
2119
)
2220
from ads.aqua.extension.base_handler import AquaAPIhandler
2321
from ads.aqua.extension.errors import Errors
22+
from ads.aqua.extension.utils import ui_compatability_check
2423

2524

2625
class ADSVersionHandler(AquaAPIhandler):
@@ -51,7 +50,7 @@ def get(self):
5150
AquaResourceAccessError: raised when aqua is not accessible in the given session/region.
5251
5352
"""
54-
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
53+
if ui_compatability_check():
5554
return self.finish({"status": "ok"})
5655
elif known_realm():
5756
return self.finish({"status": "compatible"})

ads/aqua/extension/common_ws_msg_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from importlib import metadata
88
from typing import List, Union
99

10-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
1110
from ads.aqua.common.decorator import handle_exceptions
1211
from ads.aqua.common.errors import AquaResourceAccessError
1312
from ads.aqua.common.utils import known_realm
@@ -17,6 +16,7 @@
1716
CompatibilityCheckResponse,
1817
RequestResponseType,
1918
)
19+
from ads.aqua.extension.utils import ui_compatability_check
2020

2121

2222
class AquaCommonWsMsgHandler(AquaWSMsgHandler):
@@ -39,7 +39,7 @@ def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
3939
)
4040
return response
4141
if request.get("kind") == "CompatibilityCheck":
42-
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
42+
if ui_compatability_check():
4343
return CompatibilityCheckResponse(
4444
message_id=request.get("message_id"),
4545
kind=RequestResponseType.CompatibilityCheck,

ads/aqua/extension/utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
from dataclasses import fields
5+
from datetime import datetime, timedelta
66
from typing import Dict, Optional
77

8+
from cachetools import TTLCache, cached
89
from tornado.web import HTTPError
910

11+
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
12+
from ads.aqua.common.utils import fetch_service_compartment
1013
from ads.aqua.extension.errors import Errors
1114

1215

@@ -21,3 +24,11 @@ def validate_function_parameters(data_class, input_data: Dict):
2124
raise HTTPError(
2225
400, Errors.MISSING_REQUIRED_PARAMETER.format(required_parameter)
2326
)
27+
28+
29+
@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
30+
def ui_compatability_check():
31+
"""This method caches the service compartment OCID details that is set by either the environment variable or if
32+
fetched from the configuration. The cached result is returned when multiple calls are made in quick succession
33+
from the UI to avoid multiple config file loads."""
34+
return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment()

ads/llm/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66

77
try:
88
import langchain
9-
from ads.llm.langchain.plugins.llm_gen_ai import GenerativeAI
10-
from ads.llm.langchain.plugins.llm_md import ModelDeploymentTGI
11-
from ads.llm.langchain.plugins.llm_md import ModelDeploymentVLLM
12-
from ads.llm.langchain.plugins.embeddings import GenerativeAIEmbeddings
9+
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
10+
OCIModelDeploymentVLLM,
11+
OCIModelDeploymentTGI,
12+
)
13+
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
14+
ChatOCIModelDeployment,
15+
ChatOCIModelDeploymentVLLM,
16+
ChatOCIModelDeploymentTGI,
17+
)
18+
from ads.llm.chat_template import ChatTemplates
1319
except ImportError as ex:
1420
if ex.name == "langchain":
1521
raise ImportError(

ads/llm/chat_template.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
8+
import os
9+
10+
11+
class ChatTemplates:
12+
"""Contains chat templates."""
13+
14+
@staticmethod
15+
def _read_template(filename):
16+
with open(
17+
os.path.join(os.path.dirname(__file__), "templates", filename),
18+
mode="r",
19+
encoding="utf-8",
20+
) as f:
21+
return f.read()
22+
23+
@staticmethod
24+
def mistral():
25+
"""Chat template for auto tool calling with Mistral model deploy with vLLM."""
26+
return ChatTemplates._read_template("tool_chat_template_mistral_parallel.jinja")
27+
28+
@staticmethod
29+
def hermes():
30+
"""Chat template for auto tool calling with Hermes model deploy with vLLM."""
31+
return ChatTemplates._read_template("tool_chat_template_hermes.jinja")

ads/llm/guardrails/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Any, List, Dict, Tuple
1515
from langchain.schema.prompt import PromptValue
1616
from langchain.tools.base import BaseTool, ToolException
17-
from langchain.pydantic_v1 import BaseModel, root_validator
17+
from pydantic import BaseModel, model_validator
1818

1919

2020
class RunInfo(BaseModel):
@@ -190,7 +190,8 @@ class Config:
190190
This is used by the ``apply_filter()`` method.
191191
"""
192192

193-
@root_validator
193+
@model_validator(mode="before")
194+
@classmethod
194195
def default_name(cls, values):
195196
"""Sets the default name of the guardrail."""
196197
if not values.get("name"):

ads/llm/guardrails/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
import evaluate
9-
from langchain.pydantic_v1 import root_validator
9+
from pydantic.v1 import root_validator
1010
from .base import Guardrail
1111

1212

0 commit comments

Comments
 (0)