Implementation of various tools for multi-head attention explainability in transformer models.
from explainable_attention import self_attention_attribution as saa
...
def objective(batch):
    # Must return a scalar loss; keep the targets separate from the model
    # output so the loss compares predictions against labels.
    x, y = batch
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    return loss
# Attribute the objective to the attention heads of each encoder layer;
# integration_steps controls how finely the attribution integral is approximated.
attribution = saa.compute(
    model.transformer_encoder.layers,
    objective,
    batch,
    integration_steps=20)
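One way to inspect the result is sketched below. This is a minimal, hypothetical example that assumes compute returns one attribution tensor per encoder layer, shaped (num_heads, seq_len, seq_len); the actual return type may differ.

# Hypothetical inspection of the attribution scores; assumes one torch tensor
# per encoder layer with shape (num_heads, seq_len, seq_len).
for layer_idx, layer_attr in enumerate(attribution):
    # Aggregate each head's attribution magnitude over all token pairs.
    head_scores = layer_attr.abs().sum(dim=(-2, -1))
    top_head = int(head_scores.argmax())
    print(f"layer {layer_idx}: most attributed head = {top_head}")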