diff --git a/captum/attr/__init__.py b/captum/attr/__init__.py
index a33cd862d..e0d9b5bd4 100644
--- a/captum/attr/__init__.py
+++ b/captum/attr/__init__.py
@@ -27,6 +27,7 @@
     LLMAttribution,
     LLMAttributionResult,
     LLMGradientAttribution,
+    RemoteLLMAttribution,
 )
 from captum.attr._core.lrp import LRP
 from captum.attr._core.neuron.neuron_conductance import NeuronConductance
@@ -43,6 +44,7 @@
 )
 from captum.attr._core.noise_tunnel import NoiseTunnel
 from captum.attr._core.occlusion import Occlusion
+from captum.attr._core.remote_provider import RemoteLLMProvider, VLLMProvider
 from captum.attr._core.saliency import Saliency
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
 from captum.attr._models.base import (
@@ -111,6 +113,9 @@
     "LLMAttribution",
     "LLMAttributionResult",
     "LLMGradientAttribution",
+    "RemoteLLMAttribution",
+    "RemoteLLMProvider",
+    "VLLMProvider",
     "InternalInfluence",
     "InterpretableInput",
     "LayerGradCam",
diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py
index 3466ad499..bc2e477ae 100644
--- a/captum/attr/_core/llm_attr.py
+++ b/captum/attr/_core/llm_attr.py
@@ -23,6 +23,7 @@
 from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
+from captum.attr._core.remote_provider import RemoteLLMProvider
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
 from captum.attr._utils.attribution import (
     Attribution,
@@ -892,3 +893,158 @@ def forward(
 
         # the attribution target is limited to the log probability
         return token_log_probs
+
+
+class _PlaceholderModel:
+    """
+    Simple placeholder model that can be used with
+    RemoteLLMAttribution without needing a real model.
+    This could be achieved by `lambda *_: 0`, but BaseLLMAttribution expects
+    a `device`, so this class exists to set the device.
+    """
+
+    def __init__(self) -> None:
+        self.device: Union[torch.device, str] = torch.device("cpu")
+
+    def __call__(self, *args: Any, **kwargs: Any) -> int:
+        return 0
+
+
+class RemoteLLMAttribution(LLMAttribution):
+    """
+    Attribution class for large language models
+    that are hosted remotely and offer logprob APIs.
+    """
+
+    placeholder_model = _PlaceholderModel()
+
+    def __init__(
+        self,
+        attr_method: PerturbationAttribution,
+        tokenizer: TokenizerLike,
+        provider: RemoteLLMProvider,
+        attr_target: str = "log_prob",
+    ) -> None:
+        """
+        Args:
+            attr_method: Instance of a supported perturbation attribution class
+            tokenizer (Tokenizer): tokenizer of the LLM model used in the attr_method
+            provider: Remote LLM provider that implements the RemoteLLMProvider protocol
+            attr_target: attribute towards log probability or probability.
+                Available values ["log_prob", "prob"]
+                Default: "log_prob"
+        """
+        super().__init__(
+            attr_method=attr_method,
+            tokenizer=tokenizer,
+            attr_target=attr_target,
+        )
+
+        self.provider = provider
+        self.attr_method.forward_func = self._remote_forward_func
+
+    def _get_target_tokens(
+        self,
+        inp: InterpretableInput,
+        target: Union[str, torch.Tensor, None] = None,
+        skip_tokens: Union[List[int], List[str], None] = None,
+        gen_args: Optional[Dict[str, Any]] = None,
+    ) -> Tensor:
+        """
+        Get the target tokens for the remote LLM provider.
+        """
+        assert isinstance(
+            inp, self.SUPPORTED_INPUTS
+        ), f"RemoteLLMAttribution does not support input type {type(inp)}"
+
+        if target is None:
+            # generate when None with remote provider
+            assert hasattr(self.provider, "generate") and callable(
+                self.provider.generate
+            ), (
+                "The provider does not have a generate function"
+                " for generating the target sequence."
+                " Target must be given for attribution."
+            )
+            if not gen_args:
+                gen_args = DEFAULT_GEN_ARGS
+
+            model_inp = self._format_remote_model_input(inp.to_model_input())
+            target_str = self.provider.generate(model_inp, **gen_args)
+            target_tokens = self.tokenizer.encode(
+                target_str, return_tensors="pt", add_special_tokens=False
+            )[0]
+
+        else:
+            target_tokens = super()._get_target_tokens(
+                inp, target, skip_tokens, gen_args
+            )
+
+        return target_tokens
+
+    def _format_remote_model_input(self, model_input: Union[str, Tensor]) -> str:
+        """
+        Format the model input for the remote LLM provider.
+        Convert a tokenized tensor to str
+        so that RemoteLLMAttribution works with model inputs of both
+        raw text and text token tensors.
+        """
+        # return str input
+        if isinstance(model_input, Tensor):
+            return self.tokenizer.decode(model_input.flatten())
+        return model_input
+
+    def _remote_forward_func(
+        self,
+        perturbed_tensor: Union[None, Tensor],
+        inp: InterpretableInput,
+        target_tokens: Tensor,
+        use_cached_outputs: bool = False,
+        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
+    ) -> Tensor:
+        """
+        Forward function for the remote LLM provider.
+
+        Raises:
+            ValueError: If the number of token logprobs doesn't match expected length
+        """
+        perturbed_input = self._format_remote_model_input(
+            inp.to_model_input(perturbed_tensor)
+        )
+
+        target_str: str = self.tokenizer.decode(target_tokens)
+
+        target_token_probs = self.provider.get_logprobs(
+            input_prompt=perturbed_input,
+            target_str=target_str,
+            tokenizer=self.tokenizer,
+        )
+
+        if len(target_token_probs) != target_tokens.size()[0]:
+            raise ValueError(
+                f"Number of token logprobs from provider ({len(target_token_probs)}) "
+                f"does not match expected target "
+                f"token length ({target_tokens.size()[0]})"
+            )
+
+        log_prob_list: List[Tensor] = list(map(torch.tensor, target_token_probs))
+
+        total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0)
+        # 1st element is the total prob, rest are the target tokens
+        # add a leading batch dim, though only a single instance is supported for now
+        if self.include_per_token_attr:
+            target_log_probs = torch.stack(
+                [total_log_prob, *log_prob_list], dim=0
+            ).unsqueeze(0)
+        else:
+            target_log_probs = total_log_prob
+        target_probs = torch.exp(target_log_probs)
+
+        if _inspect_forward:
+            prompt = perturbed_input
+            response = self.tokenizer.decode(target_tokens)
+
+            # callback for externals to inspect (prompt, response, seq_prob)
+            _inspect_forward(prompt, response, target_probs[0].tolist())
+
+        return target_probs if self.attr_target != "log_prob" else target_log_probs
diff --git a/captum/attr/_core/remote_provider.py b/captum/attr/_core/remote_provider.py
new file mode 100644
index 000000000..db2dc2a40
--- /dev/null
+++ b/captum/attr/_core/remote_provider.py
@@ -0,0 +1,241 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional
+
+from captum._utils.typing import TokenizerLike
+
+logger = logging.getLogger(__name__)
+
+
+class RemoteLLMProvider(ABC):
+    """All remote LLM providers that offer logprob APIs
+    (like vLLM) extend this class."""
+
+    api_url: str
+
+    @abstractmethod
+    def generate(self, prompt: str, **gen_args: Any) -> str:
+        """
+        Args:
+            prompt: The input prompt to generate from
+            gen_args: Additional generation arguments
+
+        Returns:
+            The generated text.
+        """
+        ...
+
+    @abstractmethod
+    def get_logprobs(
+        self,
+        input_prompt: str,
+        target_str: str,
+        tokenizer: Optional[TokenizerLike] = None,
+    ) -> List[float]:
+        """
+        Get the log probabilities for all tokens in the target string.
+
+        Args:
+            input_prompt: The input prompt
+            target_str: The target string
+            tokenizer: The tokenizer to use
+
+        Returns:
+            A list of log probabilities corresponding to
+            each token in the target string.
+            For a `target_str` of `t` tokens,
+            this method returns a list of logprobs of length `t`.
+        """
+        ...
+
+
+class VLLMProvider(RemoteLLMProvider):
+    def __init__(self, api_url: str, model_name: Optional[str] = None):
+        """
+        Initialize a vLLM provider.
+
+        Args:
+            api_url: The URL of the vLLM API
+            model_name: The name of the model to use. If None, the first model from
+                the API's model list will be used.
+
+        Environment Variables:
+            OPENAI_API_KEY: If not set, "EMPTY" will be used as the API key.
+
+        Raises:
+            ValueError: If api_url is empty or no models are available from the API
+            ConnectionError: If API connection fails
+            ImportError: If the openai package is not installed
+        """
+        try:
+            from openai import OpenAI
+        except ImportError:
+            raise ImportError(
+                "The 'openai' package is required to use the VLLMProvider. "
+                "You can install it by either:\n"
+                "1. Installing captum with remote dependencies: "
+                "`pip install captum[remote]` OR\n"
+                "2. Installing openai directly: `pip install openai`"
+            )
+
+        if not api_url.strip():
+            raise ValueError("API URL is required")
+
+        self.api_url = api_url
+
+        try:
+            self.client = OpenAI(
+                base_url=self.api_url, api_key=os.getenv("OPENAI_API_KEY", "EMPTY")
+            )
+
+            # If model_name is not provided, get the first available model from the API
+            if model_name is None:
+                models = self.client.models.list().data
+                if not models:
+                    raise ValueError("No models available from the vLLM API")
+                self.model_name = models[0].id
+                logger.info(
+                    f"No model_name is specified for VLLMProvider."
+                    f" Using first available model: {self.model_name}"
+                )
+            else:
+                self.model_name = model_name
+
+        except ValueError:
+            raise
+        except ConnectionError as e:
+            raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}")
+        except Exception as e:
+            raise Exception(
+                f"Unexpected error while initializing vLLM provider: {str(e)}"
+            )
+
+    def generate(self, prompt: str, **gen_args: Any) -> str:
+        """
+        Generate text using the vLLM API.
+
+        Args:
+            prompt: The input prompt for text generation
+            **gen_args: Additional generation arguments.
+                Supported arguments include:
+                - max_tokens: Maximum number of tokens to generate (default: 25)
+                - max_new_tokens: Alternative to max_tokens
+                  (will be converted to max_tokens)
+                - temperature, top_p, etc.: Other generation parameters
+                  supported by the OpenAI API
+
+        Returns:
+            str: The generated text
+
+        Raises:
+            KeyError: If API response is missing expected data
+            ConnectionError: If connection to API fails
+        """
+        # Parameter normalization
+        if "max_tokens" not in gen_args:
+            gen_args["max_tokens"] = gen_args.pop("max_new_tokens", 25)
+        if "do_sample" in gen_args:
+            gen_args.pop("do_sample")
+
+        try:
+            response = self.client.completions.create(
+                model=self.model_name, prompt=prompt, **gen_args
+            )
+            if not hasattr(response, "choices") or not response.choices:
+                raise KeyError("API response missing expected 'choices' data")
+
+            return response.choices[0].text
+
+        except KeyError:
+            raise
+        except ConnectionError as e:
+            raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}")
+        except Exception as e:
+            raise Exception(f"Unexpected error during text generation: {str(e)}")
+
+    def get_logprobs(
+        self,
+        input_prompt: str,
+        target_str: str,
+        tokenizer: Optional[TokenizerLike] = None,
+    ) -> List[float]:
+        """
+        Get the log probabilities for all tokens in the target string.
+
+        Args:
+            input_prompt: The input prompt
+            target_str: The target string
+            tokenizer: The tokenizer to use
+
+        Returns:
+            A list of log probabilities corresponding to each token
+            in the target string.
+            For a `target_str` of `t` tokens, this method returns
+            a list of logprobs of length `t`.
+
+        Raises:
+            ValueError: If tokenizer is None or target_str is empty
+                or response format is invalid
+            KeyError: If API response is missing expected data
+            IndexError: If response format is unexpected
+            ConnectionError: If connection to API fails
+        """
+        if tokenizer is None:
+            raise ValueError("Tokenizer is required for vLLM provider")
+        if not target_str:
+            raise ValueError("Target string cannot be empty")
+
+        num_target_str_tokens = len(
+            tokenizer.encode(target_str, add_special_tokens=False)
+        )
+
+        prompt = input_prompt + target_str
+
+        try:
+            response = self.client.completions.create(
+                model=self.model_name,
+                prompt=prompt,
+                temperature=0.0,
+                max_tokens=1,
+                extra_body={"prompt_logprobs": 0},
+            )
+
+            if not hasattr(response, "choices") or not response.choices:
+                raise KeyError("API response missing expected 'choices' data")
+
+            if (
+                not hasattr(response.choices[0], "prompt_logprobs")
+                or not response.choices[0].prompt_logprobs
+            ):
+                raise KeyError("API response missing 'prompt_logprobs' data")
+
+            prompt_logprobs = []
+            try:
+                for probs in response.choices[0].prompt_logprobs[
+                    -num_target_str_tokens:
+                ]:
+                    if not probs:
+                        raise ValueError("Empty probability data in API response")
+                    assert len(probs) == 1, "Expected exactly one token in logprobs"
+                    prompt_logprobs.append(next(iter(probs.values()))["logprob"])
+            except (IndexError, KeyError) as e:
+                raise IndexError(f"Unexpected format in log probability data: {str(e)}")
+
+            if len(prompt_logprobs) != num_target_str_tokens:
+                raise ValueError(
+                    f"Not enough logprobs received: "
+                    f"expected {num_target_str_tokens}, got {len(prompt_logprobs)}"
+                )
+
+            return prompt_logprobs
+
+        except (KeyError, IndexError, ValueError):
+            raise
+        except ConnectionError as e:
+            raise ConnectionError(
+                f"Failed to connect to vLLM API when getting logprobs: {str(e)}"
+            )
+        except Exception as e:
+            raise Exception(
+                f"Unexpected error while getting log probabilities: 
{str(e)}" + ) diff --git a/scripts/install_via_conda.sh b/scripts/install_via_conda.sh index 78426c54f..709d5d5ff 100755 --- a/scripts/install_via_conda.sh +++ b/scripts/install_via_conda.sh @@ -21,7 +21,7 @@ conda install -q -y pytorch cpuonly -c pytorch # install other deps conda install -q -y pytest ipywidgets ipython scikit-learn parameterized werkzeug -conda install -q -y -c conda-forge matplotlib pytest-cov flask flask-compress conda-build +conda install -q -y -c conda-forge matplotlib pytest-cov flask flask-compress conda-build openai conda install -q -y transformers # install captum diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 index 38cb97d5b..2f473b532 --- a/setup.py +++ b/setup.py @@ -63,9 +63,12 @@ def report(*args): TEST_REQUIRES = ["pytest", "pytest-cov", "parameterized", "flask", "flask-compress"] +REMOTE_REQUIRES = ["openai"] + DEV_REQUIRES = ( INSIGHTS_REQUIRES + TEST_REQUIRES + + REMOTE_REQUIRES + [ "black", "flake8", @@ -169,6 +172,7 @@ def get_package_files(root, subdirs): "insights": INSIGHTS_REQUIRES, "test": TEST_REQUIRES, "tutorials": TUTORIALS_REQUIRES, + "remote": REMOTE_REQUIRES, }, package_data={"captum": package_files}, data_files=[ diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py index 8f790870c..bbdeced56 100644 --- a/tests/attr/test_llm_attr.py +++ b/tests/attr/test_llm_attr.py @@ -18,17 +18,23 @@ Type, Union, ) +from unittest.mock import MagicMock, patch import torch from captum._utils.models.linear_model import SkLearnLasso -from captum._utils.typing import BatchEncodingType +from captum._utils.typing import BatchEncodingType, TokenizerLike from captum.attr._core.feature_ablation import FeatureAblation from captum.attr._core.kernel_shap import KernelShap from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients from captum.attr._core.lime import Lime -from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution +from captum.attr._core.llm_attr import ( + LLMAttribution, + LLMGradientAttribution, + RemoteLLMAttribution, +) +from captum.attr._core.remote_provider import RemoteLLMProvider, VLLMProvider from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling from captum.attr._utils.attribution import GradientAttribution, PerturbationAttribution from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput @@ -664,3 +670,708 @@ def test_llm_attr_with_skip_tensor_target(self) -> None: self.assertEqual(token_attr.shape, (5, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) + + +class TestVLLMProvider(BaseTest): + """Test suite for VLLMProvider class.""" + + def setUp(self) -> None: + super().setUp() + self.api_url = "https://test-vllm-api.com" + self.model_name = "test-model" + self.input_prompt = "a b c d" + self.target_str = "e f g h" + + self.tokenizer = DummyTokenizer() + + # Set up patch for OpenAI import + self.openai_patcher = patch("openai.OpenAI") + self.mock_openai = self.openai_patcher.start() + + # Create a mock OpenAI client + self.mock_client = MagicMock() + self.mock_openai.return_value = self.mock_client + + def tearDown(self) -> None: + self.openai_patcher.stop() + super().tearDown() + + def test_init_successful(self) -> None: + """Test successful initialization of VLLMProvider.""" 
+ model_name: str = "default-model" + + # Mock the models.list() response + mock_models_data = [MagicMock(id=model_name)] + self.mock_client.models.list.return_value = MagicMock(data=mock_models_data) + + # Create provider without specifying model name + provider = VLLMProvider(api_url=self.api_url) + + # Verify the client was initialized correctly + self.mock_openai.assert_called_once() + self.assertEqual(provider.api_url, self.api_url) + self.assertEqual(provider.model_name, model_name) + + # Verify models.list() was called + self.mock_client.models.list.assert_called_once() + + def test_init_with_model_name(self) -> None: + """Test initialization with specific model name.""" + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + + # Verify model name was set correctly + self.assertEqual(provider.model_name, self.model_name) + + # Verify models.list() was NOT called + self.mock_client.models.list.assert_not_called() + + def test_init_empty_api_url(self) -> None: + """Test initialization with empty API URL raises ValueError.""" + with self.assertRaises(ValueError) as context: + VLLMProvider(api_url=" ") + + self.assertIn("API URL is required", str(context.exception)) + + def test_init_connection_error(self) -> None: + """Test initialization handling connection error.""" + # Mock connection error + self.mock_openai.side_effect = ConnectionError("Failed to connect") + + with self.assertRaises(ConnectionError) as context: + VLLMProvider(api_url=self.api_url) + + self.assertIn("Failed to connect to vLLM API", str(context.exception)) + + def test_init_no_models(self) -> None: + """Test initialization when no models are available.""" + # Mock empty models list + self.mock_client.models.list.return_value = MagicMock(data=[]) + + with self.assertRaises(ValueError) as context: + VLLMProvider(api_url=self.api_url) + + self.assertIn("No models available", str(context.exception)) + + def test_generate_successful(self) -> None: + """Test successful text generation.""" + # Set up mock response + mock_choice = MagicMock() + mock_choice.text = self.target_str + mock_response = MagicMock() + mock_response.choices = [mock_choice] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + # Call generate + result = provider.generate(self.input_prompt, max_tokens=10) + + # Verify result + self.assertEqual(result, self.target_str) + + # Verify API was called with correct parameters + self.mock_client.completions.create.assert_called_once_with( + model=self.model_name, prompt=self.input_prompt, max_tokens=10 + ) + + def test_generate_with_max_new_tokens(self) -> None: + """Test generation with max_new_tokens parameter.""" + # Set up mock response + mock_choice = MagicMock() + mock_choice.text = self.target_str + mock_response = MagicMock() + mock_response.choices = [mock_choice] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + # Call generate with max_new_tokens instead of max_tokens + _ = provider.generate(self.input_prompt, max_new_tokens=10) + + # Verify API was called with converted max_tokens parameter + self.mock_client.completions.create.assert_called_once_with( + model=self.model_name, prompt=self.input_prompt, max_tokens=10 + ) + + def test_generate_empty_choices(self) -> None: + 
"""Test generation when response has empty choices.""" + # Set up mock response with empty choices + mock_response = MagicMock() + mock_response.choices = [] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + # Call generate and expect exception + with self.assertRaises(KeyError) as context: + provider.generate(self.input_prompt) + + self.assertIn( + "API response missing expected 'choices' data", str(context.exception) + ) + + def test_generate_connection_error(self) -> None: + """Test generation handling connection error.""" + # Mock connection error + self.mock_client.completions.create.side_effect = ConnectionError( + "Connection failed" + ) + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + # Call generate and expect exception + with self.assertRaises(ConnectionError) as context: + provider.generate(self.input_prompt) + + self.assertIn("Failed to connect to vLLM API", str(context.exception)) + + def test_get_logprobs_successful(self) -> None: + """Test successful retrieval of log probabilities.""" + # Set up test data + input_token_ids = self.tokenizer.encode( + self.input_prompt, add_special_tokens=False + ) + num_input_tokens = len(input_token_ids) + + target_token_ids = self.tokenizer.encode( + self.target_str, add_special_tokens=False + ) + expected_values = [0.1, 0.2, 0.3, 0.4] + num_target_tokens = len(target_token_ids) + + # Create mock vLLM response with prompt_logprobs + prompt_logprobs: List[Dict[str, Dict[str, Any]]] = [] + for i in range(num_input_tokens): + token_probs = { + str(input_token_ids[i]): { + "logprob": -0.5, # fixed logprob for input tokens (for testing) + "rank": i + 1, + "decoded_token": self.tokenizer.convert_ids_to_tokens( + input_token_ids[i] + ), + } + } + prompt_logprobs.append(token_probs) + for i in range(num_target_tokens): + token_probs = { + str(target_token_ids[i]): { + "logprob": expected_values[i], + "rank": i + 1, + "decoded_token": self.tokenizer.convert_ids_to_tokens( + target_token_ids[i] + ), + } + } + prompt_logprobs.append(token_probs) + + mock_choices = MagicMock() + # prompt_logprobs will be of length + # num_input_tokens + num_target_tokens + mock_choices.prompt_logprobs = prompt_logprobs + mock_response = MagicMock() + mock_response.choices = [mock_choices] + self.mock_client.completions.create.return_value = mock_response + + # Create provider and call get_logprobs + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + logprobs = provider.get_logprobs( + self.input_prompt, self.target_str, self.tokenizer + ) + + # Verify API call + self.mock_client.completions.create.assert_called_once_with( + model=self.model_name, + prompt=self.input_prompt + self.target_str, + temperature=0.0, + max_tokens=1, + extra_body={"prompt_logprobs": 0}, + ) + + # Verify results + self.assertEqual(len(logprobs), num_target_tokens) + for i, logprob in enumerate(logprobs): + self.assertEqual(logprob, expected_values[i]) + + def test_get_logprobs_missing_tokenizer(self) -> None: + """Test get_logprobs with missing tokenizer.""" + with self.assertRaises(ValueError) as context: + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.get_logprobs(self.input_prompt, self.target_str, None) + + self.assertIn("Tokenizer is required", 
str(context.exception)) + + def test_get_logprobs_empty_target(self) -> None: + """Test get_logprobs with empty target string.""" + with self.assertRaises(ValueError) as context: + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.get_logprobs(self.input_prompt, "", self.tokenizer) + + self.assertIn("Target string cannot be empty", str(context.exception)) + + def test_get_logprobs_missing_prompt_logprobs(self) -> None: + """Test get_logprobs when response is missing prompt_logprobs.""" + # Set up mock response without prompt_logprobs + mock_choices = MagicMock() + mock_choices.prompt_logprobs = None + mock_response = MagicMock() + mock_response.choices = [mock_choices] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + with self.assertRaises(KeyError) as context: + provider.get_logprobs(self.input_prompt, self.target_str, self.tokenizer) + + self.assertIn( + "API response missing 'prompt_logprobs' data", str(context.exception) + ) + + def test_get_logprobs_empty_probs(self) -> None: + """Test get_logprobs with empty probability data.""" + # Create mock response with empty probs + prompt_logprobs: List[Dict[str, Dict[str, Any]]] = [ + {} + ] # Empty dict for token probabilities + mock_choices = MagicMock() + mock_choices.prompt_logprobs = prompt_logprobs + mock_response = MagicMock() + mock_response.choices = [mock_choices] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + with self.assertRaises(ValueError) as context: + provider.get_logprobs(self.input_prompt, self.target_str, self.tokenizer) + + self.assertIn("Empty probability data", str(context.exception)) + + def test_get_logprobs_keyerror(self) -> None: + """Test get_logprobs with missing 'logprob' key in response.""" + # Create mock response with invalid token prob structure + prompt_logprobs: List[Dict[str, Dict[str, Any]]] = [ + {"1": {"wrong_logprob_key": 0.1, "rank": 1, "decoded_token": "a"}}, + {"2": {"wrong_logprob_key": 0.2, "rank": 1, "decoded_token": "b"}}, + ] + mock_choices = MagicMock() + mock_choices.prompt_logprobs = prompt_logprobs + mock_response = MagicMock() + mock_response.choices = [mock_choices] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + with self.assertRaises(IndexError) as context: + provider.get_logprobs("a", "b", self.tokenizer) + + self.assertIn( + "Unexpected format in log probability data", str(context.exception) + ) + + def test_get_logprobs_length_mismatch(self) -> None: + """Test get_logprobs with length mismatch + between expected and received tokens.""" + # Create mock response with only 1 logprobs (fewer than expected) + prompt_logprobs = [{"1": {"logprob": 0.1, "rank": 1, "decoded_token": "a"}}] + mock_choices = MagicMock() + mock_choices.prompt_logprobs = prompt_logprobs + mock_response = MagicMock() + mock_response.choices = [mock_choices] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + with self.assertRaises(ValueError) as context: + 
provider.get_logprobs(self.input_prompt, self.target_str, self.tokenizer) + + self.assertIn("Not enough logprobs received", str(context.exception)) + + def test_get_logprobs_connection_error(self) -> None: + """Test get_logprobs handling connection error.""" + # Mock connection error + self.mock_client.completions.create.side_effect = ConnectionError( + "Connection failed" + ) + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + with self.assertRaises(ConnectionError) as context: + provider.get_logprobs(self.input_prompt, self.target_str, self.tokenizer) + + self.assertIn( + "Failed to connect to vLLM API when getting logprobs", + str(context.exception), + ) + + +class DummyRemoteLLMProvider(RemoteLLMProvider): + def __init__(self, deterministic_logprobs: bool = False) -> None: + self.api_url = "https://test-api.com" + self.deterministic_logprobs = deterministic_logprobs + + def generate(self, prompt: str, **gen_args: Any) -> str: + assert ( + "mock_response" in gen_args + ), "must mock response to use DummyRemoteLLMProvider to generate" + return gen_args["mock_response"] + + def get_logprobs( + self, + input_prompt: str, + target_str: str, + tokenizer: Optional[TokenizerLike] = None, + ) -> List[float]: + assert tokenizer is not None, "Tokenizer is required" + prompt = input_prompt + target_str + tokens = tokenizer.encode(prompt, add_special_tokens=False) + num_tokens = len(tokens) + + num_target_str_tokens = len( + tokenizer.encode(target_str, add_special_tokens=False) + ) + + logprobs = [] + + for i in range(num_tokens): + # Start with a base value + logprob = -0.1 - (0.01 * i) + + # Make sensitive to key features + if "a" not in prompt: + logprob -= 0.1 + if "c" not in prompt: + logprob -= 0.2 + if "d" not in prompt: + logprob -= 0.3 + if "f" not in prompt: + logprob -= 0.4 + + logprobs.append(logprob) + + return logprobs[-num_target_str_tokens:] + + +@parameterized_class( + ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)] +) +class TestRemoteLLMAttr(BaseTest): + # pyre-fixme[13]: Attribute `device` is never initialized. 
+ device: str + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @parameterized.expand( + [ + ( + AttrClass, + delta, + n_samples, + torch.tensor(true_seq_attr), + torch.tensor(true_tok_attr), + ) + for AttrClass, delta, n_samples, true_seq_attr, true_tok_attr in zip( + (FeatureAblation, ShapleyValueSampling, ShapleyValues), # AttrClass + (0.001, 0.001, 0.001), # delta + (None, 1000, None), # n_samples + ( # true_seq_attr + [0.5, 1.0, 1.5, 2.0], # FeatureAblation + [0.5, 1.0, 1.5, 2.0], # ShapleyValueSampling + [0.5, 1.0, 1.5, 2.0], # ShapleyValues + ), + ( # true_tok_attr + [ # FeatureAblation + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + ], + [ # ShapleyValueSampling + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + ], + [ # ShapleyValues + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + ], + ), + ) + ] + ) + def test_remote_llm_attr( + self, + AttrClass: Type[PerturbationAttribution], + delta: float, + n_samples: Optional[int], + true_seq_attr: Tensor, + true_tok_attr: Tensor, + ) -> None: + attr_kws: Dict[str, int] = {} + if n_samples is not None: + attr_kws["n_samples"] = n_samples + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + # attr_method = AttrClass(RemoteLLMAttribution.placeholder_model) + placeholder_model = RemoteLLMAttribution.placeholder_model + placeholder_model.device = self.device + attr_method = AttrClass(placeholder_model) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"]) + res = remote_llm_attr.attribute( + inp, + "m n o p q", + skip_tokens=[0], + # use_cached_outputs=self.use_cached_outputs, + # pyre-fixme[6]: In call `LLMAttribution.attribute`, + # for 4th positional argument, expected + # `Optional[typing.Callable[..., typing.Any]]` but got `int`. 
+ **attr_kws, # type: ignore + ) + + self.assertEqual(res.seq_attr.shape, (4,)) + self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4)) + self.assertEqual(res.input_tokens, ["a", "c", "d", "f"]) + self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) + self.assertEqual(res.seq_attr.device.type, self.device) + self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device) + + assertTensorAlmostEqual( + self, + actual=res.seq_attr, + expected=true_seq_attr, + delta=delta, + mode="max", + ) + assertTensorAlmostEqual( + self, + actual=res.token_attr, + expected=true_tok_attr, + delta=delta, + mode="max", + ) + + def test_remote_llm_attr_without_target(self) -> None: + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + # attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model) + placeholder_model = RemoteLLMAttribution.placeholder_model + placeholder_model.device = self.device + attr_method = FeatureAblation(placeholder_model) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"]) + res = remote_llm_attr.attribute( + inp, + gen_args={"mock_response": "x y z"}, + # use_cached_outputs=self.use_cached_outputs, + ) + + self.assertEqual(res.seq_attr.shape, (4,)) + self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4)) + self.assertEqual(res.input_tokens, ["a", "c", "d", "f"]) + self.assertEqual(res.output_tokens, ["x", "y", "z"]) + self.assertEqual(res.seq_attr.device.type, self.device) + self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device) + + def test_remote_llm_attr_fa_log_prob(self) -> None: + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + attr_target="log_prob", + ) + + # from TestLLMAttr + inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"]) + res = remote_llm_attr.attribute( + inp, + "m n o p q", + skip_tokens=[0], + # use_cached_outputs=self.use_cached_outputs, + ) + + # With FeatureAblation, the seq attr in log_prob + # equals to the sum of each token attr + assertTensorAlmostEqual(self, res.seq_attr, cast(Tensor, res.token_attr).sum(0)) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @parameterized.expand( + [ + ( + AttrClass, + delta, + n_samples, + torch.tensor(true_seq_attr), + interpretable_model, + ) + for AttrClass, delta, n_samples, true_seq_attr, interpretable_model in zip( + (Lime, KernelShap), + (0.003, 0.001), + (1000, 2500), + ( + [0.4956, 0.9957, 1.4959, 1.9959], + [0.5, 1.0, 1.5, 2.0], + ), + (SkLearnLasso(alpha=0.001), None), + ) + ] + ) + def test_remote_llm_attr_without_token( + self, + AttrClass: Type[PerturbationAttribution], + delta: float, + n_samples: int, + true_seq_attr: Tensor, + interpretable_model: Optional[nn.Module] = None, + ) -> None: + init_kws = {} + if interpretable_model is not None: + init_kws["interpretable_model"] = interpretable_model + attr_kws: Dict[str, int] = {} + if n_samples is not None: + attr_kws["n_samples"] = n_samples + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + # attr_method = AttrClass(RemoteLLMAttribution.placeholder_model, **init_kws) + placeholder_model = 
RemoteLLMAttribution.placeholder_model + placeholder_model.device = self.device + attr_method = AttrClass(placeholder_model, **init_kws) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + attr_target="log_prob", + ) + + inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"]) + res = remote_llm_attr.attribute( + inp, + "m n o p q", + skip_tokens=[0], + # use_cached_outputs=self.use_cached_outputs, + **attr_kws, # type: ignore + ) + + self.assertEqual(res.seq_attr.shape, (4,)) + self.assertEqual(res.seq_attr.device.type, self.device) + self.assertEqual(res.token_attr, None) + self.assertEqual(res.input_tokens, ["a", "c", "d", "f"]) + self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) + assertTensorAlmostEqual( + self, + actual=res.seq_attr, + expected=true_seq_attr, + delta=delta, + mode="max", + ) + + def test_remote_llm_attr_futures_not_implemented(self) -> None: + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider() + attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + attributions = None + with self.assertRaises(NotImplementedError): + attributions = remote_llm_attr.attribute_future() + self.assertEqual(attributions, None) + + def test_remote_llm_attr_with_no_skip_tokens(self) -> None: + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model) + remote_llm_fa = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + inp = TextTokenInput("a b c", tokenizer) + res = remote_llm_fa.attribute(inp, "m n o p q") + + # 5 output tokens, 4 input tokens including sos + self.assertEqual(res.seq_attr.shape, (4,)) + assert res.token_attr is not None + self.assertIsNotNone(res.token_attr) + token_attr = res.token_attr + self.assertEqual(token_attr.shape, (6, 4)) + self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) + self.assertEqual(res.output_tokens, ["", "m", "n", "o", "p", "q"]) + + def test_remote_llm_attr_with_skip_tensor_target(self) -> None: + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model) + remote_llm_fa = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + inp = TextTokenInput("a b c", tokenizer) + res = remote_llm_fa.attribute( + inp, + torch.tensor(tokenizer.encode("m n o p q")), + skip_tokens=[0], + ) + + # 5 output tokens, 4 input tokens including sos + self.assertEqual(res.seq_attr.shape, (4,)) + assert res.token_attr is not None + self.assertIsNotNone(res.token_attr) + token_attr = res.token_attr + self.assertEqual(token_attr.shape, (5, 4)) + self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) + self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
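
A minimal end-to-end usage sketch of the API introduced in this patch, following the same pattern the tests above exercise: a RemoteLLMProvider talks to the remote endpoint, while RemoteLLMAttribution drives a perturbation attribution method over the placeholder model. The server URL, model name, tokenizer checkpoint, template, and target below are illustrative assumptions rather than values from the patch; any OpenAI-compatible vLLM endpoint with a tokenizer matching the served model should work the same way.

# Usage sketch (assumed values: vLLM server URL, model name, and HF tokenizer).
from transformers import AutoTokenizer

from captum.attr import FeatureAblation, RemoteLLMAttribution, VLLMProvider
from captum.attr._utils.interpretable_input import TextTemplateInput

# Provider backed by an OpenAI-compatible vLLM server; reads OPENAI_API_KEY or falls back to "EMPTY".
provider = VLLMProvider(api_url="http://localhost:8000/v1")

# The tokenizer must match the served model so token counts line up with the
# prompt_logprobs returned by the provider.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# No local weights are needed: the placeholder model only supplies a device, and
# RemoteLLMAttribution installs _remote_forward_func as the attribution forward function.
attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model)
llm_attr = RemoteLLMAttribution(
    attr_method=attr_method,
    tokenizer=tokenizer,
    provider=provider,
    attr_target="log_prob",
)

# Attribute the target continuation back to the interchangeable template features.
inp = TextTemplateInput(
    "{} lives in {}, {} and is a {}.",
    ["Dave", "Palm Coast", "FL", "lawyer"],
)
res = llm_attr.attribute(inp, "He enjoys golf and hiking.")
print(res.seq_attr)    # one attribution score per template feature
print(res.token_attr)  # per-output-token scores (available for FeatureAblation)

Because the provider only needs to return log probabilities for the target tokens, the same wiring applies to any backend that implements RemoteLLMProvider.generate and RemoteLLMProvider.get_logprobs, as the DummyRemoteLLMProvider in the tests demonstrates.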