Making LLMAttribute work with BertForMultipleChoice models #1524


Open
rbelew opened this issue Mar 7, 2025 · 3 comments



rbelew commented Mar 7, 2025

🚀 Feature

Allow LLMAttribution goodness to be applied to BERT models for multiple choice tasks

Motivation

Following up on suggestions from @aobo-y.

Pitch

Integrated gradient attribution techniques work over BertForMultipleChoice; it would be great if
FeatureAblation / LLMAttribution did, too.
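
For concreteness, the gradient-based setup that already works looks roughly like this (a sketch, reusing the model / tst / targetIdxTensor objects from the snippets below, not the exact script from #1523):

    import torch
    from captum.attr import LayerIntegratedGradients

    def choice_forward(input_ids, token_type_ids, attention_mask, target):
        out = model(input_ids, token_type_ids=token_type_ids,
                    attention_mask=attention_mask)
        # log-probability of each choice; select the target choice
        return torch.log_softmax(out.logits, 1)[:, target]

    lig = LayerIntegratedGradients(choice_forward, model.bert.embeddings)
    attributions_ig = lig.attribute(
        tst['input_ids'],
        additional_forward_args=(tst['token_type_ids'],
                                 tst['attention_mask'],
                                 targetIdxTensor),
    )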

Alternatives

Two suggestions were made:

First approach:

  • code

    from captum.attr import FeatureAblation, LLMAttribution, TextTokenInput

    fa = FeatureAblation(model)
    llm_attr = LLMAttribution(fa, tokenizer)

    inp = TextTokenInput(promptTxt, tokenizer)

    attributions_fa = llm_attr.attribute(
        inp,
        target=targetIdxTensor,
        additional_forward_args=dict(
            token_type_ids=tst['token_type_ids'],
            # position_ids=position_ids,
            attention_mask=tst['attention_mask'],
        ),
    )

  • throws error:

      File ".../captumPerturb_min.py", line 160, in captumPerturbOne
      attributions_fa = llm_attr.attribute(
      ^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/llm_attr.py", line 674, in attribute
      cur_attr = self.attr_method.attribute(
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      TypeError: captum.attr._core.feature_ablation.FeatureAblation.attribute() got multiple values for keyword argument 'additional_forward_args'
    
  • dropping the additional_forward_args parameter gets farther
    (LLMAttribution constructs its own additional_forward_args internally,
    so passing one explicitly collides), but throws:

      File ".../captumPerturb_min.py", line 160, in captumPerturbOne
      attributions_fa = llm_attr.attribute(
      ^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/llm_attr.py", line 674, in attribute
      cur_attr = self.attr_method.attribute(
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/log/dummy_log.py", line 39, in wrapper
      return func(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute
      initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/_utils/common.py", line 588, in _run_forward
      output = forward_func(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/llm_attr.py", line 574, in _forward_func
      model_inputs = prep_inputs_for_generation(  # type: ignore
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/transformers/generation/utils.py", line 376, in prepare_inputs_for_generation
      raise NotImplementedError(
      NotImplementedError: A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`.
    
  • looking in _forward_func variables:

    self.model.prepare_inputs_for_generation
    <bound method GenerationMixin.prepare_inputs_for_generation of BertForMultipleChoice(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): BertAttention(
                (self): BertSdpaSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
                (intermediate_act_fn): GELUActivation()
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (pooler): BertPooler(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (activation): Tanh()
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (classifier): Linear(in_features=768, out_features=1, bias=True)
    )>
  • model_inp: tensor, torch.Size([1, 112])

  • model_kwargs.keys()

      dict_keys(['attention_mask', 'cache_position', 'use_cache'])
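
One way to see the mismatch up front: HuggingFace's PreTrainedModel exposes a can_generate() check (a sketch, assuming a recent transformers version and the standard bert-base-uncased checkpoint):

    from transformers import BertForMultipleChoice

    model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
    # BertForMultipleChoice is a classification head, not a generative LM,
    # so LLMAttribution's generation-oriented forward path has nothing to call
    print(model.can_generate())   # expected: False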
    

Second approach:

  • code

    import torch
    from captum.attr import FeatureAblation

    def multChoice_forward(inputs, token_type_ids=None, position_ids=None,
                           attention_mask=None, target=None):
        output = model(inputs, token_type_ids=token_type_ids,
                       position_ids=position_ids, attention_mask=attention_mask)
        log_probs = torch.log_softmax(output.logits, 1)
        # select the target choice's log-probability
        return log_probs[target]

    fa = FeatureAblation(multChoice_forward)

    attributions_fa = fa.attribute(
        tst['input_ids'],
        additional_forward_args=dict(
            token_type_ids=tst['token_type_ids'],
            attention_mask=tst['attention_mask'],
            target=targetIdxTensor,
        ),
    )

  • throws

      File ".../captumPerturb_min.py", line 294, in main
      captumPerturbOne(model,tokenizer,tstDict,tstTarget)
      File ".../captumPerturb_min.py", line 184, in captumPerturbOne
      attributions_fa = fa.attribute(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/log/dummy_log.py", line 39, in wrapper
      return func(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/captum/attr/_core/feature_ablation.py", line 288, in attribute
      initial_eval: Union[Tensor, Future[Tensor]] = _run_forward(
      ^^^^^^^^^^^^^
      File ".../site-packages/captum/_utils/common.py", line 588, in _run_forward
      output = forward_func(
      ^^^^^^^^^^^^^
      File ".../captumPerturb_min.py", line 175, in multChoice_forward
      output = model(inputs, token_type_ids=token_type_ids,
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
      return forward_call(*args, **kwargs)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File ".../site-packages/transformers/models/bert/modeling_bert.py", line 1799, in forward
      token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
      ^^^^^^^^^^^^^^^^^^^
      AttributeError: 'dict' object has no attribute 'view'
    

This is too far into Transformer API-land for me to follow.

Additional context

Additional details are in the original issue #1523.


rbelew commented Mar 15, 2025

Continuing to pursue...

First approach

The exception

  File ".../site-packages/captum/attr/_core/llm_attr.py", line 574, in _forward_func
  model_inputs = prep_inputs_for_generation(  # type: ignore

has had me looking into all the machinery around text generation. But I don't see why text generation should be involved in a multiple-choice task at all.

I also ran across recent work on multiple choice in logits-processor-zoo, but that logits processing doesn't seem to be what I need for use with TextTokenInput.

Second approach

  • In multChoice_forward, I need to retrieve the token_type_ids element from the DICTIONARY token_type_ids (why? see the sketch after this list), but then it works!

    ttID = token_type_ids['token_type_ids']
    output = model(inputs, token_type_ids=ttID,
                   position_ids=position_ids, attention_mask=attention_mask)

  • now attributions_fa returns, with shape = [5, 5, 128], the same as tst['input_ids']

  • ¿Now, how does one interpret attributions_fa?! This must be documented somewhere.
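
My reading of why the dict arrives whole: Captum's _format_additional_forward_args wraps any non-tuple value into a one-element tuple, so the entire dict lands in the first extra positional parameter (token_type_ids) and the rest default to None. Passing additional_forward_args as a tuple should route each value to its own parameter (a sketch, untested here):

    attributions_fa = fa.attribute(
        tst['input_ids'],
        additional_forward_args=(
            tst['token_type_ids'],   # -> token_type_ids
            None,                    # -> position_ids
            tst['attention_mask'],   # -> attention_mask
            targetIdxTensor,         # -> target
        ),
    )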


aobo-y commented Apr 1, 2025

@rbelew

1st approach

NotImplementedError: A model class needs to define a prepare_inputs_for_generation method in order to use .generate().

I think this shows that, on the HuggingFace side, the implementation of the BERT model you are using differs from more common LLMs like Llama; their APIs are not compatible.

For this specific case, you can try setting this flag to False to disable calling HuggingFace's prepare_inputs_for_generation. But I wouldn't suggest it, because you may still run into other incompatibilities.

2nd approach

Glad to see you made the 2nd approach work. As you found, the original error has nothing to do with Captum; it means the token_type_ids argument your model needs was being passed incorrectly.

attributions_fa tells you how much each element in your input_ids of shape [5, 5, 128] impacts the target defined as log_probs[target]. You can compare their signs, i.e., whether they push the log_prob positively or negatively, or their magnitudes, i.e., which elements are more impactful than others.
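
Concretely, FeatureAblation replaces each feature (group) with the baseline, one at a time, and records the change in the forward output. Not Captum's actual implementation, just the core idea (a sketch, assuming a scalar forward output and a zero baseline):

    import torch

    def ablation_attr(forward, x, baseline=0.0):
        # per-element version of what FeatureAblation computes:
        # attr_i = forward(x) - forward(x with element i set to the baseline)
        base_out = forward(x)
        attrs = torch.zeros_like(x, dtype=torch.float)
        flat = attrs.view(-1)
        for i in range(x.numel()):
            x_abl = x.clone().view(-1)
            x_abl[i] = baseline
            flat[i] = base_out - forward(x_abl.view_as(x))
        return attrs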

For the shape [5, 5, 128], the 1st dim should be batch_size, but I don't know exactly what the other 2 dims mean for your model. Are they the seq_length and token_embedding_size?


rbelew commented Apr 1, 2025

@aobo-y thank you once again!

Re: 1st approach, I agree that BertForMultipleChoice is too different a model from what Captum expects, starting with prepare_inputs_for_generation but perhaps more broadly.

Re: 2nd approach,

attributions_fa means how much each element in your input_ids of [5, 5, 128] impact the target defined as log_probs[target]

First, you're right about the input's shape (batch_size, seq_length, embedding_size). But can you unpack what "impact" means? As an ablation technique, can I assume it has to do with the change in log_probs[target] WITHOUT that input? I'm looking for the formula / documentation for these specifics.
