Using LayerIntegratedGradients with an explicit LSTM layer #1528

Open
ian-grover opened this issue Mar 14, 2025 · 1 comment

ian-grover commented Mar 14, 2025

Hi

I am exploring using Captum to get some interpretability for a hybrid model I am building.

My basic model calculates an embedding, which is combined with other inputs and passed to an nn.LSTM module.

Based on the documentation, since the embedding is computed inside the model, I should use LayerIntegratedGradients, because my raw inputs are integer indices that only become floating-point features after the embedding layer.
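
For concreteness, here is a minimal sketch of the kind of setup I mean (the module names, sizes, and the attribution call are illustrative rather than my actual code):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from captum.attr import LayerIntegratedGradients

class HybridModel(nn.Module):
    def __init__(self, vocab_size=100, embed_dim=16, extra_dim=4, hidden_dim=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim + extra_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, token_ids, extra_feats, lengths):
        # token_ids: (batch, seq) integer indices; extra_feats: (batch, seq, extra_dim) floats
        x = torch.cat([self.embedding(token_ids), extra_feats], dim=-1)
        packed_input = pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_output, (h_n, c_n) = self.lstm(packed_input)
        return self.fc(h_n[-1]).squeeze(-1)

model = HybridModel()
token_ids = torch.randint(0, 100, (2, 5))
extra_feats = torch.randn(2, 5, 4)
lengths = torch.tensor([5, 3])

# Attribute with respect to the LSTM layer's output; the AttributeError in the
# traceback below is raised inside Captum's forward hook during this call.
lig = LayerIntegratedGradients(model, model.lstm)
attrs = lig.attribute((token_ids, extra_feats), additional_forward_args=(lengths,))
```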

My issue is that when I specify the LSTM layer as a layer of interest, I am getting this error:

---> 70 packed_output, (h_n, c_n) = self.lstm(packed_input)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py:1616, in Module._call_impl(self, *args, **kwargs)
   1614     hook_result = hook(self, args, kwargs, result)
   1615 else:
-> 1616     hook_result = hook(self, args, result)
   1618 if hook_result is not None:
   1619     result = hook_result

File /opt/conda/lib/python3.11/site-packages/captum/_utils/gradient.py:277, in _forward_layer_distributed_eval.<locals>.hook_wrapper.<locals>.forward_hook(module, inp, out)
    275     return eval_tsrs_to_return
    276 else:
--> 277     saved_layer[original_module][eval_tsrs[0].device] = tuple(
    278         eval_tsr.clone() for eval_tsr in eval_tsrs
    279     )

File /opt/conda/lib/python3.11/site-packages/captum/_utils/gradient.py:278, in <genexpr>(.0)
    275     return eval_tsrs_to_return
    276 else:
    277     saved_layer[original_module][eval_tsrs[0].device] = tuple(
--> 278         eval_tsr.clone() for eval_tsr in eval_tsrs
    279     )

AttributeError: 'tuple' object has no attribute 'clone'

If I interpret this correctly, a forward hook is attempting to clone each element of the LSTM layer's output, but by construction nn.LSTM returns the output tensor together with a tuple of hidden-state tensors (h_n, c_n). I assume it is this inner tuple that the hook trips over and that raises the error. Is there any way to work around this?
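
For illustration, a standalone snippet (independent of my actual model) showing the output structure the hook receives:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
out = lstm(torch.randn(4, 10, 8))

# nn.LSTM's forward output is a 2-tuple:
#   out[0] -> Tensor of shape (4, 10, 16)        (the sequence output)
#   out[1] -> (h_n, c_n), a plain tuple of two Tensors
# A hook doing `tuple(t.clone() for t in out)` therefore reaches out[1],
# which has no .clone(), and raises the AttributeError above.
print(type(out[0]), type(out[1]))  # <class 'torch.Tensor'> <class 'tuple'>
```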

jjuncho (Contributor) commented Mar 18, 2025

@ian-grover Hello! As you've noted, the forward hook was written only with layers whose outputs are a tensor or a tuple of tensors in mind. We don't currently have anything on our roadmap for supporting nested outputs like (tensor, (tensor, tensor)), but a hacky way around your issue would be to fork the repo and replace line 278 of `captum/_utils/gradient.py` with:

`eval_tsr.clone() for eval_tsr in eval_tsrs if isinstance(eval_tsr, torch.Tensor)`

so that the hook ignores non-tensor outputs.
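
To illustrate what that filter does with a tuple-structured layer output like the LSTM's (a standalone toy example, not the actual Captum code):

```python
import torch

# A mixed output in the spirit of nn.LSTM: (output tensor, (h_n, c_n)).
eval_tsrs = (torch.randn(2, 3), (torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)))

cloned = tuple(
    eval_tsr.clone()
    for eval_tsr in eval_tsrs
    if isinstance(eval_tsr, torch.Tensor)  # skips the (h_n, c_n) tuple
)
print(len(cloned))  # 1 -- only the tensor element is cloned and saved
```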
