llm-service/app/config.py (3 changes: 2 additions & 1 deletion)

@@ -49,6 +49,7 @@
 from enum import Enum
 from typing import cast, Optional, Literal
 
+MODEL_PROVIDER_ENV_VAR_NAME = "MODEL_PROVIDER"
 
 SummaryStorageProviderType = Literal["Local", "S3"]
 ChatStoreProviderType = Literal["Local", "S3"]

@@ -196,7 +197,7 @@ def model_provider(self) -> Optional[ModelSource]:
         """The preferred model provider to use.
         Options: 'AZURE', 'CAII', 'OPENAI', 'BEDROCK'
        If not set, will use the first available provider in priority order."""
-        provider = os.environ.get("MODEL_PROVIDER")
+        provider = os.environ.get(MODEL_PROVIDER_ENV_VAR_NAME)
         try:
            return ModelSource(provider)
        except ValueError:
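
A minimal sketch of the lookup this hunk introduces; the import path and the BEDROCK member name are inferred from the docstring above, not confirmed elsewhere in the diff:

import os

from app.config import MODEL_PROVIDER_ENV_VAR_NAME, ModelSource  # paths assumed

# Pin the provider explicitly; with the variable unset, model_provider
# falls back to the first available provider in priority order.
os.environ[MODEL_PROVIDER_ENV_VAR_NAME] = "BEDROCK"

# model_provider then resolves it just as the diff shows:
provider = os.environ.get(MODEL_PROVIDER_ENV_VAR_NAME)
assert ModelSource(provider) is ModelSource.BEDROCK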
llm-service/app/main.py (9 changes: 9 additions & 0 deletions)

@@ -89,6 +89,15 @@ def _configure_logger() -> None:
 
 _configure_logger()
 
+
+def _enable_http_debug() -> None:
+    urllib3_logger = logging.getLogger("urllib3")
+    urllib3_logger.setLevel(logging.DEBUG)
+    urllib3_logger.propagate = True
+
+
+_enable_http_debug()
+
 if os.environ.get("ENABLE_OPIK") == "True":
     opik.configure(
         use_local=True, url=os.environ.get("OPIK_URL", "http://localhost:5174")

Comment on lines +93 to +99 (Collaborator, Author): ad85f43 should be reverted before this is merged.
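
For context on what _enable_http_debug turns on: urllib3 logs connection setup and request details at DEBUG, and those records surface through whatever handlers the app's logger setup installed on the root logger. A standalone sketch, with basicConfig standing in for _configure_logger:

import logging

import urllib3

logging.basicConfig(level=logging.INFO)  # stand-in for _configure_logger()

# Same effect as _enable_http_debug(): urllib3's logger now passes DEBUG
# records, which propagate up to the root logger's handlers.
logging.getLogger("urllib3").setLevel(logging.DEBUG)

# Emits lines like "Starting new HTTP connection (1): example.com:80".
urllib3.PoolManager().request("GET", "http://example.com")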
llm-service/app/services/models/providers/_model_provider.py (10 changes: 10 additions & 0 deletions)

@@ -36,6 +36,7 @@
 # DATA.
 #
 import abc
+import itertools
 import os
 
 from llama_index.core.base.embeddings.base import BaseEmbedding

@@ -109,3 +110,12 @@ def get_embedding_model(name: str) -> BaseEmbedding:
     def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor:
         """Return reranking model with `name`."""
         raise NotImplementedError
+
+
+def get_all_env_var_names() -> set[str]:
+    """Return the names of all the env vars required by all model providers."""
+    return set(
+        itertools.chain.from_iterable(
+            subcls.get_env_var_names() for subcls in _ModelProvider.__subclasses__()
+        )
+    )
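
A minimal sketch of how the aggregation behaves; FakeProvider is hypothetical, and get_env_var_names is assumed to be a staticmethod returning a set, as the chained call implies:

from app.services.models.providers._model_provider import (
    _ModelProvider,
    get_all_env_var_names,
)


class FakeProvider(_ModelProvider):  # hypothetical provider, illustration only
    @staticmethod
    def get_env_var_names() -> set[str]:
        return {"FAKE_API_KEY", "FAKE_REGION"}


# __subclasses__() picks up every provider subclass defined so far, so the
# union now includes FakeProvider's variables alongside the real providers'.
assert {"FAKE_API_KEY", "FAKE_REGION"} <= get_all_env_var_names()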
llm-service/app/services/models/reranking.py (13 changes: 6 additions & 7 deletions)

@@ -67,22 +67,21 @@ def get_noop() -> BaseNodePostprocessor:
     def list_available() -> list[ModelResponse]:
         return get_provider_class().list_reranking_models()
 
+    _TEST_NODES = [
+        NodeWithScore(node=TextNode(text="test node"), score=0.5),
+        NodeWithScore(node=TextNode(text="another test node"), score=0.4),
+    ]
+
     @classmethod
     def test(cls, model_name: str) -> str:
         models = cls.list_available()
         for model in models:
             if model.model_id == model_name:
-                node = NodeWithScore(node=TextNode(text="test"), score=0.5)
-                another_test_node = NodeWithScore(
-                    node=TextNode(text="another test node"), score=0.4
-                )
                 reranking_model: BaseNodePostprocessor | None = cls.get(
                     model_name=model_name
                 )
                 if reranking_model:
-                    reranking_model.postprocess_nodes(
-                        [node, another_test_node], None, "test"
-                    )
+                    reranking_model.postprocess_nodes(cls._TEST_NODES, None, "test")
                     return "ok"
         raise HTTPException(status_code=404, detail="Model not found")
 
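
For reference, a self-contained sketch of the call shape the test method exercises; IdentityPostprocessor is a hypothetical stand-in for this module's get_noop(), built on llama_index's postprocessor base class:

from typing import List, Optional

from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode


class IdentityPostprocessor(BaseNodePostprocessor):
    """Hypothetical stand-in for get_noop(): returns nodes unchanged."""

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        return nodes


test_nodes = [
    NodeWithScore(node=TextNode(text="test node"), score=0.5),
    NodeWithScore(node=TextNode(text="another test node"), score=0.4),
]

# Same (nodes, query_bundle, query_str) positional shape as in the diff above.
result = IdentityPostprocessor().postprocess_nodes(test_nodes, None, "test")
assert [n.score for n in result] == [0.5, 0.4]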
llm-service/app/tests/conftest.py (37 changes: 17 additions & 20 deletions)

@@ -35,7 +35,6 @@
 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 # ##############################################################################
-
 import os
 import pathlib
 import uuid

@@ -53,9 +52,7 @@
 from app.ai.vector_stores.qdrant import QdrantVectorStore
 from app.main import app
 from app.services.metadata_apis import data_sources_metadata_api
-from app.services import models
 from app.services.metadata_apis.data_sources_metadata_api import RagDataSource
-from app.services.models.providers import BedrockModelProvider
 
 
 @dataclass

@@ -177,16 +174,16 @@ def get_datasource_metadata(data_source_id: int) -> RagDataSource:
     )
 
 
-@pytest.fixture(autouse=True)
-def embedding_model(monkeypatch: pytest.MonkeyPatch) -> None:
-    model = DummyEmbeddingModel()
-    monkeypatch.setattr(models.Embedding, "get", lambda cls, model_name=None: model)
-
-
-@pytest.fixture(autouse=True)
-def llm(monkeypatch: pytest.MonkeyPatch) -> None:
-    model = models.LLM.get_noop()
-    monkeypatch.setattr(models.LLM, "get", lambda cls, model_name=None: model)
+# @pytest.fixture(autouse=True)
+# def embedding_model(monkeypatch: pytest.MonkeyPatch) -> None:
+#     model = DummyEmbeddingModel()
+#     monkeypatch.setattr(models.Embedding, "get", lambda cls, model_name=None: model)
+#
+#
+# @pytest.fixture(autouse=True)
+# def llm(monkeypatch: pytest.MonkeyPatch) -> None:
+#     model = models.LLM.get_noop()
+#     monkeypatch.setattr(models.LLM, "get", lambda cls, model_name=None: model)
 
 
 @pytest.fixture

@@ -219,10 +216,10 @@ def client() -> Iterator[TestClient]:
     yield test_client
 
 
-@pytest.fixture(autouse=True)
-def _get_model_arn_by_suffix(monkeypatch: pytest.MonkeyPatch) -> None:
-    monkeypatch.setattr(
-        BedrockModelProvider,
-        "_get_model_arns",
-        lambda: [],
-    )
+# @pytest.fixture(autouse=True)
+# def _get_model_arn_by_suffix(monkeypatch: pytest.MonkeyPatch) -> None:
+#     monkeypatch.setattr(
+#         BedrockModelProvider,
+#         "_get_model_arns",
+#         lambda: [],
+#     )
llm-service/app/tests/model_provider_mocks/__init__.py (38 changes: 38 additions & 0 deletions)

@@ -0,0 +1,38 @@
+#
+#  CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
+#  (C) Cloudera, Inc. 2025
+#  All rights reserved.
+#
+#  Applicable Open Source License: Apache 2.0
+#
+#  NOTE: Cloudera open source products are modular software products
+#  made up of hundreds of individual components, each of which was
+#  individually copyrighted.  Each Cloudera open source product is a
+#  collective work under U.S. Copyright Law. Your license to use the
+#  collective work is as provided in your written agreement with
+#  Cloudera.  Used apart from the collective work, this file is
+#  licensed for your use pursuant to the open source license
+#  identified above.
+#
+#  This code is provided to you pursuant a written agreement with
+#  (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
+#  this code. If you do not have a written agreement with Cloudera nor
+#  with an authorized and properly licensed third party, you do not
+#  have any rights to access nor to use this code.
+#
+#  Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
+#  contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
+#  KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
+#  WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
+#  IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
+#  FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
+#  AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
+#  ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
+#  OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
+#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
+#  CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
+#  RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
+#  BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
+#  DATA.
+#
+