Add Remote LLM Support for Perturbation-Based Attribution via RemoteLLMAttribution and VLLMProvider #1544

Open: saichandrapandraju wants to merge 2 commits into master

Conversation

@saichandrapandraju commented Apr 15, 2025

This PR introduces support for applying Captum's perturbation-based attribution algorithms to remotely hosted large language models (LLMs). It enables users to perform interpretability analyses on models served via APIs, such as those using vLLM, without requiring access to model internals.

Motivation:

Captum’s current LLM attribution framework requires access to local models, limiting its usability in production and hosted environments. With the rise of scalable remote inference backends and OpenAI-compatible APIs, this PR allows Captum to be used for black-box interpretability with hosted models, as long as they return token-level log probabilities.

This integration also aligns with ongoing efforts like llama-stack, which aims to provide a unified API layer for inference (and also for RAG, Agents, Tools, Safety, Evals, and Telemetry) across multiple backends—further expanding Captum’s reach for model explainability.

Key Additions:

  • RemoteLLMProvider Interface:
    A generic interface for fetching log probabilities from remote LLMs, making it easy to plug in various inference backends.
  • VLLMProvider Implementation:
    A concrete subclass of RemoteLLMProvider tailored for models served using vLLM, handling the specifics of communicating with vLLM endpoints to retrieve necessary data for attribution.
  • RemoteLLMAttribution class:
    A subclass of LLMAttribution that overrides internal methods to work with remote providers. It enables all perturbation-based algorithms (e.g., Feature Ablation, Shapley Values, KernelSHAP) using only the output logprobs from a remote LLM.
  • OpenAI-Compatible API Support:
    Uses the openai client under the hood for querying remote models, since many LLM serving solutions now support the OpenAI-compatible API format (e.g., the vLLM OpenAI-compatible server and projects like llama-stack; see the ongoing work related to this). A minimal usage sketch follows this list.
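
A minimal usage sketch of the new pieces working together, assuming a vLLM OpenAI-compatible server is running locally; the exact constructor signatures (in particular RemoteLLMAttribution's) and the model/tokenizer names are illustrative and may differ from the final merged API:

import torch
from transformers import AutoTokenizer

from captum.attr import FeatureAblation, TextTokenInput
# New in this PR:
from captum.attr import RemoteLLMAttribution, VLLMProvider

# Provider that fetches token logprobs from a remote vLLM endpoint (example URL).
provider = VLLMProvider(api_url="http://localhost:8000/v1")

# The model is queried over the API, so the attribution algorithm only needs a
# placeholder standing in for a local model.
placeholder_model = torch.nn.Module()
fa = FeatureAblation(placeholder_model)

# Tokenizer matching the remotely served model (example model name).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

llm_attr = RemoteLLMAttribution(fa, tokenizer, provider)

inp = TextTokenInput("The capital of France is", tokenizer)
attr_result = llm_attr.attribute(inp, target="Paris")
print(attr_result.seq_attr)  # per-input-token attribution scores

Any other perturbation-based method (e.g., ShapleyValueSampling, KernelShap) can be swapped in for FeatureAblation in the same way.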

Issue(s) related to this:

… hosted models that provide logprobs (like vLLM)
@facebook-github-bot
Contributor

Hi @saichandrapandraju!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@saichandrapandraju
Author

Hi @vivekmig @yucu @aobo-y, could you please review this PR? Let me know if I need to change anything.

@aobo-y
Contributor

aobo-y commented Apr 18, 2025

Thank you @saichandrapandraju for the great effort! Generally, I agree this idea makes a lot of sense.
But our team may need some time to look into the code changes and get back to you.

@craymichael, can you take a look at it, since you have studied the integration with llama-stack before?

@saichandrapandraju
Author

saichandrapandraju commented Apr 18, 2025

Thank you for the positive feedback @aobo-y ! Happy to hear the direction makes sense. Please take your time reviewing — I’ll be around to clarify or iterate on anything as needed. Looking forward to it!

@emaadmanzoor

Is there any update on this? Seems like a great idea!

@craymichael
Contributor

Hi, sorry for the delay on this! I actually have on my todos to dive into this PR next week.

Contributor

@craymichael left a comment

I just wrapped up review of the diff - it's really in great shape, thank you for contributing this! Moreover, the remote provider abstraction is generic enough to support other frameworks, including Llama Stack (which includes support for log prob configs and prompt_logprobs depending on provider). I left some comments and questions (and some thinking out loud...). Let's discuss/iterate and we can land this contribution soon!

attr_kws["n_samples"] = n_samples

# In remote mode, we don't need the actual model, this is just a placeholder
placeholder_model = torch.nn.Module()

Contributor

For a test this is fine. In release notes/an example notebook we can instead use a simpler placeholder like

lambda *_: 0

since the forward function has signature Callable[..., Union[int, float, Tensor, Future[Tensor]]]. I'm thinking as well that for convenience we can have the above lambda available as a class attribute of RemoteLLMAttribution with name placeholder_model so the syntax looks like:

AttrClass(RemoteLLMAttribution.placeholder_model)
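
For reference, a rough sketch of how that class attribute could look (illustrative only; the final implementation may differ):

from captum.attr import LLMAttribution

class RemoteLLMAttribution(LLMAttribution):
    # Dummy forward function: perturbation-based methods only need a callable
    # matching Callable[..., Union[int, float, Tensor, Future[Tensor]]], since
    # the real model lives behind the remote provider.
    placeholder_model = staticmethod(lambda *_: 0)

Accessing it through the class (RemoteLLMAttribution.placeholder_model) returns the plain function, so AttrClass(RemoteLLMAttribution.placeholder_model) works as suggested above.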

Contributor

As part of these changes, we should update (or create a counterpart) to the LLM attribution tutorial and update the docs. I'm happy to do this once this feature lands as well.

from abc import ABC, abstractmethod
from typing import Any, List, Optional
from captum._utils.typing import TokenizerLike
from openai import OpenAI

Contributor

Since this is being imported into attr.py, let's move this import into the VLLMProvider __init__, wrapped in a try-except that tells the user to install the openai package if it isn't already installed. We'll keep it as an optional dependency, as you have it in setup.py.
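
A sketch of what that deferred import could look like (the error message, attribute names, and __init__ signature are illustrative):

from typing import Optional

class VLLMProvider(RemoteLLMProvider):
    def __init__(self, api_url: str, model_name: Optional[str] = None) -> None:
        try:
            # Imported here so `openai` stays an optional dependency of captum.
            from openai import OpenAI
        except ImportError as e:
            raise ImportError(
                "VLLMProvider requires the `openai` package. "
                "Install it with `pip install openai`."
            ) from e
        self.api_url = api_url
        self.client = OpenAI(base_url=self.api_url)
        self.model_name = model_name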

"""
# Parameter normalization
if 'max_tokens' not in gen_args:
gen_args['max_tokens'] = gen_args.pop('max_new_tokens', 25)

Contributor

nit: add default to docstring


try:
    self.client = OpenAI(base_url=self.api_url,
                         api_key=os.getenv("OPENAI_API_KEY", "EMPTY")

Contributor

nit: add the env var to the docstring.
Also, it looks like we don't need to handle this fallback ourselves; the client already reads the OPENAI_API_KEY env var:
https://github.com/openai/openai-python/blob/main/src/openai/_client.py#L114

Initialize a vLLM provider.

Args:
api_url: The URL of the vLLM API

Contributor

I would let api_url be an optional type without a default. With OpenAI API it will default to openAI's base URL which is what users may expect, but we can wrap VLLMProvider and call it OpenAIProvider for convenience.

Suggestion:

class VLLMProvider(RemoteLLMProvider):
    def __init__(self, api_url: Optional[str], model_name: Optional[str] = None):

Additional class (also imported into attr):

class OpenAIProvider(VLLMProvider):
    def __init__(self, api_url: Optional[str] = None, model_name: Optional[str] = None):

Contributor

Actually looks like OpenAI isn't letting us force logprobs with new models...feel free to ignore this for now

Contributor

This class is bound to the OpenAI SDK. I am wondering how much of it is general enough to be extracted into another VLLMProvider. Maybe just name this OpenAIProvider?

Contributor

One thing to consider is that OpenAI doesn't support prompt logprobs, at least with newer models; it's a vLLM engine parameter, which is why it needs to be passed in an unusual way. I actually don't think prompt_logprobs is even supported by OpenAI, and there is a different way to grab logprobs, but I'm not certain.

Also, the vLLM server is its own library, but it uses the OpenAI API and so is compatible with OpenAI's SDKs as long as the base URL is set correctly. vLLM doesn't have its own SDK, afaik.

models = self.client.models.list().data
if not models:
    raise ValueError("No models available from the vLLM API")
self.model_name = models[0].id

Contributor

nit: let's log an informational message about which model is being used since none was provided
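
For example, something along these lines, where _resolve_model_name is a hypothetical helper called from VLLMProvider.__init__ when no model_name is given (logger setup is illustrative):

import logging

logger = logging.getLogger(__name__)

def _resolve_model_name(self) -> None:
    # Fall back to the first model exposed by the endpoint and say so in the logs.
    models = self.client.models.list().data
    if not models:
        raise ValueError("No models available from the vLLM API")
    self.model_name = models[0].id
    logger.info(
        "No model_name provided; defaulting to the first available model: %s",
        self.model_name,
    )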

prompt=prompt,
temperature=0.0,
max_tokens=1,
extra_body={"prompt_logprobs": 0}

Contributor

Should prompt_logprobs be 1 here? I thought 0 would result in an exception as it needs to be > 1

for probs in response.choices[0].prompt_logprobs[1:]:
    if not probs:
        raise ValueError("Empty probability data in API response")
    prompt_logprobs.append(list(probs.values())[0]['logprob'])

Contributor

If I understand correctly, prompt_logprobs is a list[dict[str, dict[str, float]]] where each element corresponds to a prompt token, and its contents map each of the k candidate tokens (per the prompt_logprobs setting) to that token's generation data, including the logprob. If so, maybe we should assert that the length of probs is always 1?

Also, iiuc, I can nitpick (both points are combined in the sketch below):

  • We can iterate over the final num_target_str_tokens only
  • Can replace list(probs.values())[0] with next(iter(probs.values()))
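
Putting both nitpicks together with the single-entry assertion mentioned above, the loop could look roughly like this (a sketch against the quoted diff; response and num_target_str_tokens come from the surrounding PR code):

prompt_logprobs = []
# Only the logprobs of the target tokens are needed for attribution.
for probs in response.choices[0].prompt_logprobs[-num_target_str_tokens:]:
    if not probs:
        raise ValueError("Empty probability data in API response")
    # Each per-token map should hold exactly one entry at the minimum prompt_logprobs setting.
    assert len(probs) == 1, f"Expected one logprob entry per token, got {len(probs)}"
    prompt_logprobs.append(next(iter(probs.values()))["logprob"])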

if not target_str:
    raise ValueError("Target string cannot be empty")

num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False))

Contributor

I think each response contains token-level information. It's possible that, when we generate a response if none is provided, we might not need to require a tokenizer at all, which would minimize tokenizer-model mismatch. However, I think we can just make note of this and keep it as future work.

@@ -669,3 +670,362 @@ def test_llm_attr_with_skip_tensor_target(self) -> None:
        self.assertEqual(token_attr.shape, (5, 4))
        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

class DummyRemoteLLMProvider(RemoteLLMProvider):

Contributor

It would be great to also have a test of VLLMProvider with a mocked openai client (mainly the client.completions.create method) that solidifies the expected values of prompt_logprobs and verifies the logprob logic.
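
A rough sketch of such a test: the get_logprobs method name is a placeholder for whichever API the provider actually exposes, DummyTokenizer/BaseTest refer to the existing test helpers, and the mocked response shape follows the vLLM format discussed above:

from unittest.mock import MagicMock

class TestVLLMProviderMockedClient(BaseTest):
    def test_prompt_logprob_extraction(self) -> None:
        provider = VLLMProvider(
            api_url="http://localhost:8000/v1", model_name="dummy-model"
        )
        # Swap the real OpenAI client for a mock that returns vLLM-style
        # prompt_logprobs: None for the first prompt token, then one
        # {token_id: {"logprob": ...}} map per remaining token.
        provider.client = MagicMock()
        provider.client.completions.create.return_value = MagicMock(
            choices=[
                MagicMock(
                    prompt_logprobs=[
                        None,
                        {"5": {"logprob": -0.5}},
                        {"7": {"logprob": -1.2}},
                    ]
                )
            ]
        )
        # Target "b" is a single token, so only the last prompt logprob is expected.
        logprobs = provider.get_logprobs("a b", "b", DummyTokenizer())
        provider.client.completions.create.assert_called_once()
        self.assertEqual(logprobs, [-1.2])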

@saichandrapandraju
Author

saichandrapandraju commented May 9, 2025

Thank you for the comments @craymichael . I'll work on these and get back to you soon!
