diff --git a/CHANGELOG.md b/CHANGELOG.md
index 27ed4d6..e0914ef 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+- Update response validation methods to use the TLM endpoint in the Codex backend rather than a TLM model.
+- Update response validation methods to support accepting and propagating `TLMOptions` from the `cleanlab_tlm` library.
+
 ## [1.0.1] - 2025-02-26
 
 - Updates to logic for `is_unhelpful_response` util method.
diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py
index 677ec64..6b9b2eb 100644
--- a/src/cleanlab_codex/response_validation.py
+++ b/src/cleanlab_codex/response_validation.py
@@ -4,6 +4,9 @@
 
 from __future__ import annotations
 
+from typing import (
+    TYPE_CHECKING as _TYPE_CHECKING,
+)
 from typing import (
     Any,
     Callable,
@@ -13,13 +16,33 @@
     cast,
 )
 
+from codex import AuthenticationError, BadRequestError
 from pydantic import BaseModel, ConfigDict, Field
 
+from cleanlab_codex.internal.sdk_client import (
+    MissingAuthKeyError,
+    client_from_access_key,
+    client_from_api_key,
+    is_access_key,
+)
 from cleanlab_codex.internal.utils import generate_pydantic_model_docstring
-from cleanlab_codex.types.tlm import TLM
 from cleanlab_codex.utils.errors import MissingDependencyError
 from cleanlab_codex.utils.prompt import default_format_prompt
 
+if _TYPE_CHECKING:
+    from cleanlab_tlm.tlm import TLMOptions
+    from codex import Codex as _Codex
+
+    from cleanlab_codex.types.tlm import TlmScoreResponse
+
+
+class MissingAuthError(ValueError):
+    """Raised when the untrustworthiness or unhelpfulness checks are run without an API key or access key."""
+
+    def __str__(self) -> str:
+        return "A valid Codex API key or access key must be provided when using the TLM for untrustworthy or unhelpfulness checks."
+
+
 _DEFAULT_FALLBACK_ANSWER: str = (
     "Based on the available information, I cannot provide a complete answer to this question."
 )
@@ -73,9 +96,14 @@ class BadResponseDetectionConfig(BaseModel):
     )
 
     # Shared config (for untrustworthiness and unhelpfulness checks)
-    tlm: Optional[TLM] = Field(
+    tlm_options: Optional[TLMOptions] = Field(
+        default=None,
+        description="Configuration options for the TLM model used for evaluation.",
+    )
+
+    codex_key: Optional[str] = Field(
         default=None,
-        description="TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks).",
+        description="Codex access key or API key to use when querying TLM for untrustworthiness and unhelpfulness checks.",
     )
 
 
@@ -95,6 +123,8 @@ def is_bad_response(
     context: Optional[str] = None,
     query: Optional[str] = None,
     config: Union[BadResponseDetectionConfig, Dict[str, Any]] = _DEFAULT_CONFIG,
+    run_untrustworthy_check: Optional[bool] = True,
+    run_unhelpful_check: Optional[bool] = True,
 ) -> bool:
     """Run a series of checks to determine if a response is bad.
 
@@ -113,6 +143,8 @@ def is_bad_response(
         context (str, optional): Optional context/documents used for answering. Required for untrustworthy check.
         query (str, optional): Optional user question. Required for untrustworthy and unhelpful checks.
         config (BadResponseDetectionConfig, optional): Optional, configuration parameters for validation checks. See [BadResponseDetectionConfig](#class-badresponsedetectionconfig) for details. If not provided, default values will be used.
+        run_untrustworthy_check (bool, optional): Optional flag to specify whether to run the untrustworthy check. This check is run by default.
+        run_unhelpful_check (bool, optional): Optional flag to specify whether to run the unhelpfulness check. This check is run by default.
 
     Returns:
         bool: `True` if any validation check fails, `False` if all pass.
@@ -130,7 +162,7 @@ def is_bad_response(
         )
     )
 
-    can_run_untrustworthy_check = query is not None and context is not None and config.tlm is not None
+    can_run_untrustworthy_check = query is not None and context is not None and run_untrustworthy_check
     if can_run_untrustworthy_check:
         # The if condition guarantees these are not None
         validation_checks.append(
@@ -138,20 +170,22 @@
             response=response,
             context=cast(str, context),
             query=cast(str, query),
-            tlm=cast(TLM, config.tlm),
+            tlm_options=config.tlm_options,
             trustworthiness_threshold=config.trustworthiness_threshold,
             format_prompt=config.format_prompt,
+            codex_key=config.codex_key,
         )
     )
 
-    can_run_unhelpful_check = query is not None and config.tlm is not None
+    can_run_unhelpful_check = query is not None and run_unhelpful_check
     if can_run_unhelpful_check:
         validation_checks.append(
             lambda: is_unhelpful_response(
                 response=response,
                 query=cast(str, query),
-                tlm=cast(TLM, config.tlm),
+                tlm_options=config.tlm_options,
                 confidence_score_threshold=config.unhelpfulness_confidence_threshold,
+                codex_key=config.codex_key,
             )
         )
 
@@ -189,13 +223,76 @@ def is_fallback_response(
     return bool(partial_ratio >= threshold)
 
 
+def _create_codex_client(codex_key_arg: str | None) -> _Codex:
+    """
+    Helper method to create a Codex client for proxying TLM requests.
+
+    Args:
+        codex_key_arg (str | None): A Codex API key or access key to use when querying TLM.
+
+    Returns:
+        _Codex: A Codex client to use to proxy TLM requests.
+    """
+    if codex_key_arg is None:
+        try:
+            return client_from_access_key()
+        except MissingAuthKeyError:
+            pass
+        try:
+            return client_from_api_key()
+        except (MissingAuthKeyError, BadRequestError):
+            pass
+        raise MissingAuthError from None
+
+    try:
+        if is_access_key(codex_key_arg):
+            return client_from_access_key(codex_key_arg)
+
+        return client_from_api_key(codex_key_arg)
+    except (MissingAuthKeyError, BadRequestError):
+        raise MissingAuthError from None
+
+
+def _try_tlm_score(
+    client: _Codex,
+    prompt: str,
+    response: str,
+    options: Optional[TLMOptions] = None,
+    constrain_outputs: Optional[list[str]] = None,
+) -> TlmScoreResponse:
+    """
+    Helper method to call the TLM scoring endpoint via the Codex client, catching any
+    authentication issues and raising our own MissingAuthError instead.
+
+    Args:
+        client (_Codex): The (authenticated) Codex client to use.
+        prompt (str): The prompt to pass to tlm.score.
+        response (str): The response to pass to tlm.score.
+        options (TLMOptions): The TLMOptions to pass to the TLM.
+        constrain_outputs (list[str]): The constrain_outputs keyword argument to pass to tlm.score.
+
+    Notes:
+        We need the try-except here because authenticating via an access key performs no eager
+        validity check (unlike authenticating via an API key, which is checked immediately). The
+        Codex client may therefore raise an AuthenticationError, which we catch and re-raise as our own MissingAuthError.
+
+    Returns:
+        TlmScoreResponse: The score response from TLM. Raises MissingAuthError if the user is not correctly authenticated.
+ """ + try: + return client.tlm.score(prompt=prompt, response=response, options=options, constrain_outputs=constrain_outputs) + except AuthenticationError: + raise MissingAuthError from None + + def is_untrustworthy_response( response: str, context: str, query: str, - tlm: TLM, + tlm_options: Optional[TLMOptions] = None, trustworthiness_threshold: float = _DEFAULT_TRUSTWORTHINESS_THRESHOLD, format_prompt: Callable[[str, str], str] = default_format_prompt, + codex_key: Optional[str] = None, ) -> bool: """Check if a response is untrustworthy. @@ -207,27 +304,20 @@ def is_untrustworthy_response( response (str): The response to check from the assistant. context (str): The context information available for answering the query. query (str): The user's question or request. - tlm (TLM): The TLM model to use for evaluation. + tlm_options (TLMOptions): The options to pass to the TLM model used for evaluation. trustworthiness_threshold (float): Score threshold (0.0-1.0) under which a response is considered untrustworthy. Lower values allow less trustworthy responses. Default 0.5 means responses with scores less than 0.5 are considered untrustworthy. format_prompt (Callable[[str, str], str]): Function that takes (query, context) and returns a formatted prompt string. Users should provide the prompt formatting function for their RAG application here so that the response can be evaluated using the same prompt that was used to generate the response. + codex_key (str): A Codex API or Access key to use when querying TLM. Returns: bool: `True` if the response is deemed untrustworthy by TLM, `False` otherwise. """ - try: - from cleanlab_tlm import TLM # noqa: F401 - except ImportError as e: - raise MissingDependencyError( - import_name=e.name or "cleanlab_tlm", - package_name="cleanlab-tlm", - package_url="https://github.com/cleanlab/cleanlab-tlm", - ) from e - prompt = format_prompt(query, context) - result = tlm.get_trustworthiness_score(prompt, response) + client = _create_codex_client(codex_key) + result = _try_tlm_score(client=client, prompt=prompt, response=response, options=tlm_options) score: float = result["trustworthiness_score"] return score < trustworthiness_threshold @@ -235,8 +325,9 @@ def is_untrustworthy_response( def is_unhelpful_response( response: str, query: str, - tlm: TLM, + tlm_options: Optional[TLMOptions] = None, confidence_score_threshold: float = _DEFAULT_UNHELPFULNESS_CONFIDENCE_THRESHOLD, + codex_key: Optional[str] = None, ) -> bool: """Check if a response is unhelpful by asking [TLM](/tlm) to evaluate it. @@ -248,23 +339,15 @@ def is_unhelpful_response( Args: response (str): The response to check. query (str): User query that will be used to evaluate if the response is helpful. - tlm (TLM): The TLM model to use for evaluation. + tlm_options (TLMOptions): The options to pass to the TLM model used for evaluation. confidence_score_threshold (float): Confidence threshold (0.0-1.0) above which a response is considered unhelpful. E.g. if confidence_score_threshold is 0.5, then responses with scores higher than 0.5 are considered unhelpful. + codex_key (str): A Codex API or Access key to use when querying TLM. Returns: bool: `True` if TLM determines the response is unhelpful with sufficient confidence, `False` otherwise. 
""" - try: - from cleanlab_tlm import TLM # noqa: F401 - except ImportError as e: - raise MissingDependencyError( - import_name=e.name or "cleanlab_tlm", - package_name="cleanlab-tlm", - package_url="https://github.com/cleanlab/cleanlab-tlm", - ) from e - # IMPORTANT: The current implementation couples three things that must stay in sync: # 1. The question phrasing ("is unhelpful?") # 2. The expected_unhelpful_response ("Yes") @@ -300,8 +383,13 @@ def is_unhelpful_response( f"{question}" ) - output = tlm.get_trustworthiness_score( - prompt, response=expected_unhelpful_response, constrain_outputs=["Yes", "No"] + client = _create_codex_client(codex_key) + output = _try_tlm_score( + client=client, + prompt=prompt, + response=expected_unhelpful_response, + options=tlm_options, + constrain_outputs=["Yes", "No"], ) # Current implementation assumes question is phrased to expect "Yes" for unhelpful responses diff --git a/src/cleanlab_codex/types/tlm.py b/src/cleanlab_codex/types/tlm.py index 48d071f..2007cb3 100644 --- a/src/cleanlab_codex/types/tlm.py +++ b/src/cleanlab_codex/types/tlm.py @@ -1,18 +1,17 @@ -from typing import Any, Dict, Protocol, Sequence, Union, runtime_checkable - - -@runtime_checkable -class TLM(Protocol): - def get_trustworthiness_score( - self, - prompt: Union[str, Sequence[str]], - response: Union[str, Sequence[str]], - **kwargs: Any, - ) -> Dict[str, Any]: ... - - def prompt( - self, - prompt: Union[str, Sequence[str]], - /, - **kwargs: Any, - ) -> Dict[str, Any]: ... +"""Types for Codex TLM endpoint.""" + +from codex.types.tlm_score_response import TlmScoreResponse as _TlmScoreResponse + +from cleanlab_codex.internal.utils import generate_class_docstring + + +class TlmScoreResponse(_TlmScoreResponse): ... + + +TlmScoreResponse.__doc__ = f""" +Type representing a TLM score response in a Codex project. This is the complete data structure returned from the Codex API, including system-generated fields like ID and timestamps. 
+
+{generate_class_docstring(_TlmScoreResponse, name=TlmScoreResponse.__name__)}
+"""
+
+__all__ = ["TlmScoreResponse"]
diff --git a/tests/conftest.py b/tests/conftest.py
index 364c053..d54ec4b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,3 @@
-from tests.fixtures.client import mock_client_from_access_key, mock_client_from_api_key
+from tests.fixtures.client import mock_client_from_access_key, mock_client_from_access_key_tlm, mock_client_from_api_key
 
-__all__ = ["mock_client_from_access_key", "mock_client_from_api_key"]
+__all__ = ["mock_client_from_access_key", "mock_client_from_access_key_tlm", "mock_client_from_api_key"]
diff --git a/tests/fixtures/client.py b/tests/fixtures/client.py
index 16be75e..d5442c8 100644
--- a/tests/fixtures/client.py
+++ b/tests/fixtures/client.py
@@ -18,3 +18,11 @@ def mock_client_from_api_key() -> Generator[MagicMock, None, None]:
         mock_client = MagicMock()
         mock_init.return_value = mock_client
         yield mock_client
+
+
+@pytest.fixture
+def mock_client_from_access_key_tlm() -> Generator[MagicMock, None, None]:
+    with patch("cleanlab_codex.response_validation.client_from_access_key") as mock_init:
+        mock_client = MagicMock()
+        mock_init.return_value = mock_client
+        yield mock_client
diff --git a/tests/test_response_validation.py b/tests/test_response_validation.py
index 3d2e1d4..47795f1 100644
--- a/tests/test_response_validation.py
+++ b/tests/test_response_validation.py
@@ -2,13 +2,14 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, Sequence, Union
-from unittest.mock import Mock, patch
+from typing import Callable
+from unittest.mock import MagicMock, Mock, patch
 
 import pytest
 
 from cleanlab_codex.response_validation import (
     _DEFAULT_UNHELPFULNESS_CONFIDENCE_THRESHOLD,
+    MissingAuthError,
     is_bad_response,
     is_fallback_response,
     is_unhelpful_response,
@@ -20,48 +21,7 @@
 BAD_RESPONSE = "Based on the available information, I cannot provide a complete answer."
 QUERY = "What is the capital of France?"
 CONTEXT = "Paris is the capital and largest city of France."
-
-
-class MockTLM(Mock):
-    _trustworthiness_score: float = 0.8
-    _response: str = "No"
-
-    @property
-    def trustworthiness_score(self) -> float:
-        return self._trustworthiness_score
-
-    @trustworthiness_score.setter
-    def trustworthiness_score(self, value: float) -> None:
-        self._trustworthiness_score = value
-
-    @property
-    def response(self) -> str:
-        return self._response
-
-    @response.setter
-    def response(self, value: str) -> None:
-        self._response = value
-
-    def get_trustworthiness_score(
-        self,
-        prompt: Union[str, Sequence[str]],  # noqa: ARG002
-        response: Union[str, Sequence[str]],  # noqa: ARG002
-        **kwargs: Any,  # noqa: ARG002
-    ) -> Dict[str, Any]:
-        return {"trustworthiness_score": self._trustworthiness_score}
-
-    def prompt(
-        self,
-        prompt: Union[str, Sequence[str]],  # noqa: ARG002
-        /,
-        **kwargs: Any,  # noqa: ARG002
-    ) -> Dict[str, Any]:
-        return {"response": self._response, "trustworthiness_score": self._trustworthiness_score}
-
-
-@pytest.fixture
-def mock_tlm() -> MockTLM:
-    return MockTLM()
+DUMMY_ACCESS_KEY = "sk-1-EMOh6UrRo7exTEbEi8_azzACAEdtNiib2LLa1IGo6kA"
 
 
 @pytest.mark.parametrize(
@@ -95,15 +55,25 @@ def test_is_fallback_response(
     assert is_fallback_response(response, **kwargs) is expected  # type: ignore
 
 
-def test_is_untrustworthy_response(mock_tlm: Mock) -> None:
+def test_is_untrustworthy_response(mock_client_from_access_key_tlm: MagicMock) -> None:
     """Test untrustworthy response detection."""
     # Test trustworthy response
-    mock_tlm.trustworthiness_score = 0.8
-    assert is_untrustworthy_response(GOOD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is False
+    mock_client_from_access_key_tlm.tlm.score.return_value = {"trustworthiness_score": 0.8}
+    assert (
+        is_untrustworthy_response(
+            GOOD_RESPONSE, CONTEXT, QUERY, tlm_options=None, trustworthiness_threshold=0.5, codex_key=DUMMY_ACCESS_KEY
+        )
+        is False
+    )
 
     # Test untrustworthy response
-    mock_tlm.trustworthiness_score = 0.3
-    assert is_untrustworthy_response(BAD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is True
+    mock_client_from_access_key_tlm.tlm.score.return_value = {"trustworthiness_score": 0.3}
+    assert (
+        is_untrustworthy_response(
+            BAD_RESPONSE, CONTEXT, QUERY, tlm_options=None, trustworthiness_threshold=0.5, codex_key=DUMMY_ACCESS_KEY
+        )
+        is True
+    )
 
 
 @pytest.mark.parametrize(
@@ -124,7 +94,7 @@ def test_is_untrustworthy_response(mock_tlm: Mock) -> None:
     ],
 )
 def test_is_unhelpful_response(
-    mock_tlm: Mock,
+    mock_client_from_access_key_tlm: MagicMock,
     tlm_score: float,
     threshold: float | None,
     *,
@@ -136,13 +106,15 @@ def test_is_unhelpful_response(
     This may seem counter-intuitive, but higher scores indicate more
     similar responses to known unhelpful patterns.
""" - mock_tlm.trustworthiness_score = tlm_score + mock_client_from_access_key_tlm.tlm.score.return_value = {"trustworthiness_score": tlm_score} # The response content doesn't affect the result, only the score matters if threshold is not None: - result = is_unhelpful_response(GOOD_RESPONSE, QUERY, mock_tlm, confidence_score_threshold=threshold) + result = is_unhelpful_response( + GOOD_RESPONSE, QUERY, tlm_options=None, confidence_score_threshold=threshold, codex_key=DUMMY_ACCESS_KEY + ) else: - result = is_unhelpful_response(GOOD_RESPONSE, QUERY, mock_tlm) + result = is_unhelpful_response(GOOD_RESPONSE, QUERY, tlm_options=None, codex_key=DUMMY_ACCESS_KEY) assert result is expected_unhelpful @@ -157,7 +129,7 @@ def test_is_unhelpful_response( ], ) def test_is_bad_response( - mock_tlm: Mock, + mock_client_from_access_key_tlm: MagicMock, response: str, trustworthiness_score: float, prompt_score: float, @@ -166,42 +138,33 @@ def test_is_bad_response( ) -> None: """Test the main is_bad_response function.""" # Create a new Mock object for get_trustworthiness_score - mock_tlm.get_trustworthiness_score = Mock(return_value={"trustworthiness_score": trustworthiness_score}) + mock_client_from_access_key_tlm.tlm.score.return_value = {"trustworthiness_score": trustworthiness_score} # Set up the second call to return prompt_score - mock_tlm.get_trustworthiness_score.side_effect = [ + mock_client_from_access_key_tlm.tlm.score.side_effect = [ {"trustworthiness_score": trustworthiness_score}, # Should be called by is_untrustworthy_response {"trustworthiness_score": prompt_score}, # Should be called by is_unhelpful_response ] - assert ( - is_bad_response( - response, - context=CONTEXT, - query=QUERY, - config={"tlm": mock_tlm}, - ) - is expected - ) + assert is_bad_response(response, context=CONTEXT, query=QUERY, config={"codex_key": DUMMY_ACCESS_KEY}) is expected @pytest.mark.parametrize( - ("response", "fuzz_ratio", "prompt_score", "query", "tlm", "expected"), + ("response", "fuzz_ratio", "prompt_score", "query", "expected"), [ # Test with only fallback check (no context/query/tlm) - (BAD_RESPONSE, 90, None, None, None, True), + (BAD_RESPONSE, 90, None, None, True), # Test with fallback and unhelpful checks (no context) - (GOOD_RESPONSE, 30, 0.1, QUERY, "mock_tlm", False), + (GOOD_RESPONSE, 30, 0.1, QUERY, False), # Test with fallback and unhelpful checks (with context) (prompt_score is above threshold) - (GOOD_RESPONSE, 30, 0.6, QUERY, "mock_tlm", True), + (GOOD_RESPONSE, 30, 0.6, QUERY, True), ], ) def test_is_bad_response_partial_inputs( - mock_tlm: Mock, + mock_client_from_access_key_tlm: MagicMock, response: str, fuzz_ratio: int, prompt_score: float, query: str, - tlm: Mock, *, expected: bool, ) -> None: @@ -210,14 +173,67 @@ def test_is_bad_response_partial_inputs( mock_fuzz.partial_ratio.return_value = fuzz_ratio with patch.dict("sys.modules", {"thefuzz": Mock(fuzz=mock_fuzz)}): if prompt_score is not None: - mock_tlm.trustworthiness_score = prompt_score - tlm = mock_tlm + mock_client_from_access_key_tlm.tlm.score.return_value = {"trustworthiness_score": prompt_score} assert ( is_bad_response( response, query=query, - config={"tlm": tlm}, + config={"codex_key": DUMMY_ACCESS_KEY}, ) is expected ) + + +@pytest.mark.parametrize( + "method", + [ + lambda: is_unhelpful_response(response="", query=""), + lambda: is_untrustworthy_response(response="", context="", query=""), + ], +) +def test_tlm_access_with_no_access_key_and_no_auth( + method: Callable[[], bool], +) -> None: + """Test that attempting to 
+    with pytest.raises(
+        MissingAuthError,
+        match="A valid Codex API key or access key must be provided when using the TLM for untrustworthy or unhelpfulness checks.",
+    ):
+        method()
+
+
+@pytest.mark.parametrize(
+    "method",
+    [
+        lambda key: is_unhelpful_response(response="", query="", codex_key=key),
+        lambda key: is_untrustworthy_response(response="", context="", query="", codex_key=key),
+    ],
+)
+def test_tlm_access_with_bad_api_key(
+    method: Callable[[str], bool],
+) -> None:
+    """Test that calling is_unhelpful_response and is_untrustworthy_response with an invalid API key fails."""
+    with pytest.raises(
+        MissingAuthError,
+        match="A valid Codex API key or access key must be provided when using the TLM for untrustworthy or unhelpfulness checks.",
+    ):
+        method("MY-API-KEY")
+
+
+@pytest.mark.parametrize(
+    "method",
+    [
+        lambda key: is_unhelpful_response(response="", query="", codex_key=key),
+        lambda key: is_untrustworthy_response(response="", context="", query="", codex_key=key),
+    ],
+)
+def test_tlm_access_with_bad_access_key(
+    method: Callable[[str], bool],
+) -> None:
+    """Test that calling is_unhelpful_response and is_untrustworthy_response with an invalid access key fails."""
+    with pytest.raises(
+        MissingAuthError,
+        match="A valid Codex API key or access key must be provided when using the TLM for untrustworthy or unhelpfulness checks.",
+    ):
+        method(DUMMY_ACCESS_KEY)
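
---

For reviewers, a minimal usage sketch of the API after this change (illustrative only: the key below is a placeholder, and the `model` field is just one example of a `TLMOptions` entry from the `cleanlab_tlm` library):

```python
from cleanlab_codex.response_validation import is_bad_response

# Placeholder credentials/options -- substitute a real Codex API key or
# project access key; TLMOptions fields come from the cleanlab_tlm library.
config = {
    "codex_key": "<your-codex-api-or-access-key>",
    "tlm_options": {"model": "gpt-4o-mini"},
}

is_bad = is_bad_response(
    response="Based on the available information, I cannot provide a complete answer.",
    context="Paris is the capital and largest city of France.",
    query="What is the capital of France?",
    config=config,
    run_unhelpful_check=False,  # run only the fallback and untrustworthy checks
)
```

If `codex_key` is omitted, `_create_codex_client` falls back to ambient credentials via `client_from_access_key()` and then `client_from_api_key()`, raising `MissingAuthError` when neither yields an authenticated client.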