Add Remote LLM Support for Perturbation-Based Attribution via RemoteLLMAttribution and VLLMProvider #1544

Open · wants to merge 2 commits into base: master
5 changes: 5 additions & 0 deletions captum/attr/__init__.py
@@ -27,7 +27,9 @@
LLMAttribution,
LLMAttributionResult,
LLMGradientAttribution,
RemoteLLMAttribution,
)
from captum.attr._core.remote_provider import RemoteLLMProvider, VLLMProvider
from captum.attr._core.lrp import LRP
from captum.attr._core.neuron.neuron_conductance import NeuronConductance
from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap
@@ -111,6 +113,9 @@
"LLMAttribution",
"LLMAttributionResult",
"LLMGradientAttribution",
"RemoteLLMAttribution",
"RemoteLLMProvider",
"VLLMProvider",
"InternalInfluence",
"InterpretableInput",
"LayerGradCam",
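With these exports in place, the new classes sit alongside the existing LLM attribution utilities and should be importable directly from the top-level attr namespace, e.g.:

from captum.attr import RemoteLLMAttribution, RemoteLLMProvider, VLLMProvider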
120 changes: 120 additions & 0 deletions captum/attr/_core/llm_attr.py
@@ -35,6 +35,7 @@
TextTokenInput,
)
from torch import nn, Tensor
from captum.attr._core.remote_provider import RemoteLLMProvider

DEFAULT_GEN_ARGS: Dict[str, Any] = {
"max_new_tokens": 25,
@@ -892,3 +893,122 @@ def forward(

# the attribution target is limited to the log probability
return token_log_probs


class RemoteLLMAttribution(LLMAttribution):
"""
Attribution class for large language models that are hosted remotely and offer logprob APIs.
"""
def __init__(
self,
attr_method: PerturbationAttribution,
tokenizer: TokenizerLike,
provider: RemoteLLMProvider,
attr_target: str = "log_prob",
) -> None:
"""
Args:
attr_method: Instance of a supported perturbation attribution class
tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method
provider: Remote LLM provider that implements the RemoteLLMProvider protocol
attr_target: attribute towards log probability or probability.
Available values ["log_prob", "prob"]
Default: "log_prob"
"""
super().__init__(
attr_method=attr_method,
tokenizer=tokenizer,
attr_target=attr_target,
)

self.provider = provider
self.attr_method.forward_func = self._remote_forward_func

def _get_target_tokens(
self,
inp: InterpretableInput,
target: Union[str, torch.Tensor, None] = None,
skip_tokens: Union[List[int], List[str], None] = None,
gen_args: Optional[Dict[str, Any]] = None
) -> Tensor:
"""
Get the target tokens for the remote LLM provider.
"""
assert isinstance(
inp, self.SUPPORTED_INPUTS
), f"RemoteLLMAttribution does not support input type {type(inp)}"

if target is None:
# generate when None with remote provider
assert hasattr(self.provider, "generate") and callable(self.provider.generate), (
"The provider does not have generate function for generating target sequence."
"Target must be given for attribution"
)
if not gen_args:
gen_args = DEFAULT_GEN_ARGS

model_inp = self._format_model_input(inp.to_model_input())
target_str = self.provider.generate(model_inp, **gen_args)
target_tokens = self.tokenizer.encode(target_str, return_tensors="pt", add_special_tokens=False)[0]

else:
target_tokens = super()._get_target_tokens(inp, target, skip_tokens, gen_args)

return target_tokens

def _format_model_input(self, model_input: Union[str, Tensor]) -> str:
"""
Format the model input for the remote LLM provider.
"""
# return str input
if isinstance(model_input, Tensor):
return self.tokenizer.decode(model_input.flatten())
return model_input

def _remote_forward_func(
self,
perturbed_tensor: Union[None, Tensor],
inp: InterpretableInput,
target_tokens: Tensor,
use_cached_outputs: bool = False,
_inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
) -> Tensor:
"""
Forward function for the remote LLM provider.

Raises:
ValueError: If the number of token logprobs doesn't match expected length
"""
perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))

target_str: str = self.tokenizer.decode(target_tokens)

target_token_probs = self.provider.get_logprobs(input_prompt=perturbed_input, target_str=target_str, tokenizer=self.tokenizer)

if len(target_token_probs) != target_tokens.size()[0]:
raise ValueError(
f"Number of token logprobs from provider ({len(target_token_probs)}) "
f"does not match expected target token length ({target_tokens.size()[0]})"
)

log_prob_list: List[Tensor] = list(map(torch.tensor, target_token_probs))

total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0)
# 1st element is the total log prob, the rest are the per-token log probs
# add a leading dim for batch even though we only support a single instance for now
if self.include_per_token_attr:
target_log_probs = torch.stack(
[total_log_prob, *log_prob_list], dim=0
).unsqueeze(0)
else:
target_log_probs = total_log_prob
target_probs = torch.exp(target_log_probs)

if _inspect_forward:
prompt = perturbed_input
response = self.tokenizer.decode(target_tokens)

# callback for externals to inspect (prompt, response, seq_prob)
_inspect_forward(prompt, response, target_probs[0].tolist())

return target_probs if self.attr_target != "log_prob" else target_log_probs
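For reviewers, a minimal usage sketch of how these pieces are intended to fit together. The server URL, model name, prompt, and placeholder forward function below are illustrative assumptions, not part of this PR; the device attribute on the placeholder is a workaround for the base LLMAttribution constructor, which expects the wrapped callable to expose a device.

import torch
from transformers import AutoTokenizer

from captum.attr import FeatureAblation, RemoteLLMAttribution, TextTokenInput, VLLMProvider

# Assumed vLLM deployment and model name; replace with your own.
provider = VLLMProvider(api_url="http://localhost:8000/v1", model_name="meta-llama/Llama-3.1-8B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# LLMAttribution derives a device from the wrapped forward callable, so the
# placeholder carries one; RemoteLLMAttribution swaps in _remote_forward_func anyway.
def placeholder_forward(*args, **kwargs):
    raise NotImplementedError("never called; replaced by the remote forward function")

placeholder_forward.device = torch.device("cpu")

attr_method = FeatureAblation(placeholder_forward)
llm_attr = RemoteLLMAttribution(attr_method=attr_method, tokenizer=tokenizer, provider=provider)

inp = TextTokenInput("The capital of France is", tokenizer)
result = llm_attr.attribute(inp, target=" Paris")
print(result.seq_attr)    # one attribution score per input token
print(result.token_attr)  # per-target-token attributions (FeatureAblation supports these)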
191 changes: 191 additions & 0 deletions captum/attr/_core/remote_provider.py
@@ -0,0 +1,191 @@
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from captum._utils.typing import TokenizerLike
from openai import OpenAI
Contributor:
Since this is being imported into attr.py, let's move this import into the VLLMProvider init, wrapped in a try-except that tells the user to install the openai package if it isn't already installed. We'll keep it as an optional dependency, as you have it in setup.py.
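For illustration, a rough sketch of that suggestion, i.e. deferring the import into VLLMProvider.__init__ (hypothetical, not part of this diff; the error message wording is only a suggestion):

# At the top of VLLMProvider.__init__, instead of the module-level import above:
try:
    from openai import OpenAI  # optional dependency, provided by the "remote" extra
except ImportError as e:
    raise ImportError(
        "VLLMProvider requires the openai package; install it with "
        "`pip install openai` or `pip install captum[remote]`."
    ) from e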

import os

class RemoteLLMProvider(ABC):
"""All remote LLM providers that offer logprob via API (like vLLM) extends this class."""

api_url: str

@abstractmethod
def generate(
self,
prompt: str,
**gen_args: Any
) -> str:
"""
Args:
prompt: The input prompt to generate from
gen_args: Additional generation arguments

Returns:
The generated text.
"""
...

@abstractmethod
def get_logprobs(
self,
input_prompt: str,
target_str: str,
tokenizer: Optional[TokenizerLike] = None
) -> List[float]:
"""
Get the log probabilities for all tokens in the target string.

Args:
input_prompt: The input prompt
target_str: The target string
tokenizer: The tokenizer to use

Returns:
A list of log probabilities corresponding to each token in the target string.
For a `target_str` of `t` tokens, this method returns a list of logprobs of length `t`.
"""
...

class VLLMProvider(RemoteLLMProvider):
def __init__(self, api_url: str, model_name: Optional[str] = None):
"""
Initialize a vLLM provider.

Args:
api_url: The URL of the vLLM API
Contributor:
I would let api_url be an Optional type without a default. With the OpenAI API it will default to OpenAI's base URL, which is what users may expect, but we can wrap VLLMProvider and call it OpenAIProvider for convenience.

Suggestion:

class VLLMProvider(RemoteLLMProvider):
    def __init__(self, api_url: Optional[str], model_name: Optional[str] = None):

Additional class (also imported into attr):

class OpenAIProvider(VLLMProvider):
    def __init__(self, api_url: Optional[str] = None, model_name: Optional[str] = None):

Contributor:
Actually, it looks like OpenAI isn't letting us force logprobs with new models... feel free to ignore this for now.

Contributor:
This class is bound to the OpenAI SDK. I am wondering how much of it is general enough to be extracted into a separate VLLMProvider. Maybe just name this one OpenAIProvider?

Contributor:
One thing to consider is that OpenAI doesn't support prompt logprobs, at least with newer models; it's a vLLM engine parameter, which is why it needs to be passed in this roundabout way. I actually don't think prompt_logprobs is supported by OpenAI at all, and there is a different way to grab logprobs there, but I'm not certain.

Also, the vLLM server is its own library, but it uses the OpenAI API and so is compatible with OpenAI's SDKs as long as the base URL is set correctly. vLLM doesn't have its own SDK, as far as I know.

model_name: The name of the model to use. If None, the first model from
the API's model list will be used.

Raises:
ValueError: If api_url is empty or no models are available from the API
ConnectionError: If API connection fails
"""
if not api_url.strip():
raise ValueError("API URL is required")

self.api_url = api_url

try:
self.client = OpenAI(base_url=self.api_url,
api_key=os.getenv("OPENAI_API_KEY", "EMPTY")
Contributor:
nit: add env var to docstring
Looks like we don't need to handle this logic
https://github.com/openai/openai-python/blob/main/src/openai/_client.py#L114

)

# If model_name is not provided, get the first available model from the API
if model_name is None:
models = self.client.models.list().data
if not models:
raise ValueError("No models available from the vLLM API")
self.model_name = models[0].id
Contributor:
nit: let's log an informational message about which model is being used since none was provided

else:
self.model_name = model_name

except ConnectionError as e:
raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}")
except Exception as e:
raise Exception(f"Unexpected error while initializing vLLM provider: {str(e)}")

def generate(self, prompt: str, **gen_args: Any) -> str:
"""
Generate text using the vLLM API.

Args:
prompt: The input prompt for text generation
**gen_args: Additional generation arguments

Returns:
str: The generated text

Raises:
KeyError: If API response is missing expected data
ConnectionError: If connection to API fails
"""
# Parameter normalization
if 'max_tokens' not in gen_args:
gen_args['max_tokens'] = gen_args.pop('max_new_tokens', 25)
Contributor:
nit: add default to docstring

if 'do_sample' in gen_args:
gen_args.pop('do_sample')

try:
response = self.client.completions.create(
model=self.model_name,
prompt=prompt,
**gen_args
)
if not hasattr(response, 'choices') or not response.choices:
raise KeyError("API response missing expected 'choices' data")

return response.choices[0].text

except ConnectionError as e:
raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}")
except Exception as e:
raise Exception(f"Unexpected error during text generation: {str(e)}")

def get_logprobs(
self,
input_prompt: str,
target_str: str,
tokenizer: Optional[TokenizerLike] = None
) -> List[float]:
"""
Get the log probabilities for all tokens in the target string.

Args:
input_prompt: The input prompt
target_str: The target string
tokenizer: The tokenizer to use

Returns:
A list of log probabilities corresponding to each token in the target string.
For a `target_str` of `t` tokens, this method returns a list of logprobs of length `t`.

Raises:
ValueError: If tokenizer is None or target_str is empty or response format is invalid
KeyError: If API response is missing expected data
IndexError: If response format is unexpected
ConnectionError: If connection to API fails
"""
if tokenizer is None:
raise ValueError("Tokenizer is required for vLLM provider")
if not target_str:
raise ValueError("Target string cannot be empty")

num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False))
Contributor:
I think each response contains token-level information, so when we generate the target ourselves (if none is provided) we might not need to require a tokenizer, which would minimize tokenizer-model mismatch. However, I think we can just make a note of this and keep it as future work.


prompt = input_prompt + target_str

try:
response = self.client.completions.create(
model=self.model_name,
prompt=prompt,
temperature=0.0,
max_tokens=1,
extra_body={"prompt_logprobs": 0}
Contributor:
Should prompt_logprobs be 1 here? I thought 0 would result in an exception as it needs to be > 1

)

if not hasattr(response, 'choices') or not response.choices:
raise KeyError("API response missing expected 'choices' data")

if not hasattr(response.choices[0], 'prompt_logprobs'):
raise KeyError("API response missing 'prompt_logprobs' data")

prompt_logprobs = []
try:
for probs in response.choices[0].prompt_logprobs[1:]:
if not probs:
raise ValueError("Empty probability data in API response")
prompt_logprobs.append(list(probs.values())[0]['logprob'])
Contributor:
If I understand correctly, prompt_logprobs is a list[dict[str, dict[str, float]]] where each element corresponds to a prompt token and maps each of the k requested candidate tokens to its generation data, including the logprob. If so, maybe we should assert that the length of probs is always 1?

Also, if I understand correctly, two nitpicks (see the sketch after this list):

  • We can iterate over only the final num_target_str_tokens entries
  • We can replace list(probs.values())[0] with next(iter(probs.values()))
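A sketch of the loop with both nitpicks applied (assuming the structure described above; response and num_target_str_tokens come from the surrounding get_logprobs body):

prompt_logprobs: List[float] = []
# Only the last num_target_str_tokens prompt positions belong to target_str.
for probs in response.choices[0].prompt_logprobs[-num_target_str_tokens:]:
    if not probs:
        raise ValueError("Empty probability data in API response")
    assert len(probs) == 1, "expected exactly one candidate entry per prompt token"
    prompt_logprobs.append(next(iter(probs.values()))["logprob"])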

except (IndexError, KeyError) as e:
raise IndexError(f"Unexpected format in log probability data: {str(e)}")

if len(prompt_logprobs) < num_target_str_tokens:
raise ValueError(f"Not enough logprobs received: expected {num_target_str_tokens}, got {len(prompt_logprobs)}")

return prompt_logprobs[-num_target_str_tokens:]

except ConnectionError as e:
raise ConnectionError(f"Failed to connect to vLLM API when getting logprobs: {str(e)}")
except Exception as e:
raise Exception(f"Unexpected error while getting log probabilities: {str(e)}")


4 changes: 4 additions & 0 deletions setup.py
100755 → 100644
@@ -63,9 +63,12 @@ def report(*args):

TEST_REQUIRES = ["pytest", "pytest-cov", "parameterized", "flask", "flask-compress"]

REMOTE_REQUIRES = ["openai"]

DEV_REQUIRES = (
INSIGHTS_REQUIRES
+ TEST_REQUIRES
+ REMOTE_REQUIRES
+ [
"black",
"flake8",
@@ -169,6 +172,7 @@ def get_package_files(root, subdirs):
"insights": INSIGHTS_REQUIRES,
"test": TEST_REQUIRES,
"tutorials": TUTORIALS_REQUIRES,
"remote": REMOTE_REQUIRES,
},
package_data={"captum": package_files},
data_files=[
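With the new "remote" extra, the optional openai dependency should be installable alongside captum via `pip install captum[remote]`, mirroring the existing insights/test/tutorials extras.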