Skip to content

Commit 9542a9c

Browse files
craymichaelfacebook-github-bot
authored andcommitted
Attribution API refactor: Introduce an optional agg/element-wise variance to LLM attribution results (#1659)
Summary: As title. It is possible for attr to be computed as an estimated amount over multiple samples of the response, so the estimate has variance. This adds an attribute to store this variance in the results, if we have it. Differential Revision: D84970183
1 parent 1e44c99 commit 9542a9c

File tree

1 file changed

+85
-10
lines changed

1 file changed

+85
-10
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
Dict,
1515
Generic,
1616
List,
17+
Literal,
1718
Optional,
19+
overload,
1820
Tuple,
1921
Type,
2022
TYPE_CHECKING,
@@ -62,6 +64,26 @@
6264
TTargetValue = TypeVar("TTargetValue")
6365

6466

67+
@overload
68+
def _to_tensor(
69+
name: str, arr: Optional[npt.ArrayLike], none_ok: Literal[True] = ...
70+
) -> Optional[Tensor]: ...
71+
@overload
72+
def _to_tensor(
73+
name: str, arr: Optional[npt.ArrayLike], none_ok: Literal[False] = ...
74+
) -> Tensor: ...
75+
def _to_tensor(
76+
name: str, arr: Optional[npt.ArrayLike], none_ok: bool = False
77+
) -> Optional[Tensor]:
78+
if arr is None:
79+
if none_ok:
80+
return None
81+
raise TypeError(f"Expected array-like for `{name}` but received None!")
82+
if not isinstance(arr, Tensor):
83+
arr = torch.tensor(arr)
84+
return arr
85+
86+
6587
@dataclass(kw_only=True)
6688
class BaseLLMAttributionResult(ABC, Generic[TInputValue, TTargetValue]):
6789
"""
@@ -77,6 +99,8 @@ class BaseLLMAttributionResult(ABC, Generic[TInputValue, TTargetValue]):
7799
] # value for each target name e.g. token prob
78100
_aggregate_attr: Tensor # 1D [# input_values]
79101
_element_attr: Optional[Tensor] = None # 2D [# target_names, # input_values]
102+
_aggregate_attr_var: Optional[Tensor] = None # 1D [# input_values]
103+
_element_attr_var: Optional[Tensor] = None # 2D [# target_names, # input_values]
80104
aggregate_descriptor: str = "Aggregate"
81105
element_descriptor: str = "Element"
82106

@@ -88,6 +112,8 @@ def __init__(
88112
target_values: Optional[Union[npt.ArrayLike, List[TTargetValue]]] = None,
89113
aggregate_attr: npt.ArrayLike,
90114
element_attr: Optional[npt.ArrayLike] = None,
115+
aggregate_attr_var: Optional[npt.ArrayLike] = None,
116+
element_attr_var: Optional[npt.ArrayLike] = None,
91117
aggregate_descriptor: str = "Aggregate",
92118
element_descriptor: str = "Element",
93119
) -> None:
@@ -96,6 +122,8 @@ def __init__(
96122
self.target_values = target_values
97123
self.aggregate_attr = aggregate_attr
98124
self.element_attr = element_attr
125+
self.aggregate_attr_var = aggregate_attr_var
126+
self.element_attr_var = element_attr_var
99127
self.aggregate_descriptor = aggregate_descriptor
100128
self.element_descriptor = element_descriptor
101129

@@ -105,10 +133,9 @@ def aggregate_attr(self) -> Tensor:
105133

106134
@aggregate_attr.setter
107135
def aggregate_attr(self, aggregate_attr: npt.ArrayLike) -> None:
108-
if isinstance(aggregate_attr, Tensor):
109-
self._aggregate_attr = aggregate_attr
110-
else:
111-
self._aggregate_attr = torch.tensor(aggregate_attr)
136+
self._aggregate_attr = _to_tensor(
137+
"aggregate_attr", aggregate_attr, none_ok=False
138+
)
112139
# IDEA: in the future we might want to support higher dim seq_attr
113140
# (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
114141
assert len(self._aggregate_attr.shape) == 1, "seq_attr must be a 1D tensor"
@@ -122,12 +149,7 @@ def element_attr(self) -> Optional[Tensor]:
122149

123150
@element_attr.setter
124151
def element_attr(self, element_attr: Optional[npt.ArrayLike]) -> None:
125-
if element_attr is None:
126-
self._element_attr = None
127-
elif isinstance(element_attr, Tensor):
128-
self._element_attr = element_attr
129-
else:
130-
self._element_attr = torch.tensor(element_attr)
152+
self._element_attr = _to_tensor("element_attr", element_attr, none_ok=True)
131153

132154
if self._element_attr is not None:
133155
# IDEA: in the future we might want to support higher dim seq_attr
@@ -141,6 +163,39 @@ def element_attr(self, element_attr: Optional[npt.ArrayLike]) -> None:
141163
f"got {self._element_attr.shape}"
142164
)
143165

166+
@property
167+
def aggregate_attr_var(self) -> Optional[Tensor]:
168+
return self._aggregate_attr_var
169+
170+
@aggregate_attr_var.setter
171+
def aggregate_attr_var(self, aggregate_attr_var: Optional[npt.ArrayLike]) -> None:
172+
self._aggregate_attr_var = _to_tensor(
173+
"aggregate_attr_var", aggregate_attr_var, none_ok=True
174+
)
175+
if self._aggregate_attr_var is not None:
176+
assert self._aggregate_attr_var.shape == self._aggregate_attr.shape, (
177+
f"aggregate_attr ({self._aggregate_attr.shape}) must have same shape "
178+
f"as aggregate_attr_var ({self._aggregate_attr_var.shape})"
179+
)
180+
181+
@property
182+
def element_attr_var(self) -> Optional[Tensor]:
183+
return self._element_attr_var
184+
185+
@element_attr_var.setter
186+
def element_attr_var(self, element_attr_var: Optional[npt.ArrayLike]) -> None:
187+
self._element_attr_var = _to_tensor(
188+
"element_attr_var", element_attr_var, none_ok=True
189+
)
190+
if self._element_attr_var is not None:
191+
assert (
192+
self._element_attr is not None
193+
), "element_attr must be set before setting element_attr_var"
194+
assert self._element_attr_var.shape == self._element_attr.shape, (
195+
f"element_attr ({self._element_attr.shape}) must have same shape "
196+
f"as element_attr_var ({self._element_attr_var.shape})"
197+
)
198+
144199
@property
145200
def target_values(self) -> Optional[List[TTargetValue]]:
146201
return self._target_values
@@ -377,6 +432,22 @@ def token_attr(self) -> Optional[Tensor]:
377432
def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
378433
self.element_attr = token_attr
379434

435+
@property
436+
def seq_attr_var(self) -> Optional[Tensor]:
437+
return self.aggregate_attr_var
438+
439+
@seq_attr_var.setter
440+
def seq_attr_var(self, seq_attr_var: Optional[npt.ArrayLike]) -> None:
441+
self.aggregate_attr_var = seq_attr_var
442+
443+
@property
444+
def token_attr_var(self) -> Optional[Tensor]:
445+
return self.element_attr_var
446+
447+
@token_attr_var.setter
448+
def token_attr_var(self, token_attr_var: Optional[npt.ArrayLike]) -> None:
449+
self.element_attr_var = token_attr_var
450+
380451
@property
381452
def seq_attr_dict(self) -> Dict[TInputValue, float]:
382453
return self.aggregate_attr_dict
@@ -402,6 +473,8 @@ def __init__(
402473
output_tokens: List[str],
403474
seq_attr: npt.ArrayLike,
404475
token_attr: Optional[npt.ArrayLike] = None,
476+
seq_attr_var: Optional[npt.ArrayLike] = None,
477+
token_attr_var: Optional[npt.ArrayLike] = None,
405478
output_probs: Optional[npt.ArrayLike] = None,
406479
) -> None:
407480
super().__init__(
@@ -410,6 +483,8 @@ def __init__(
410483
target_values=output_probs,
411484
aggregate_attr=seq_attr,
412485
element_attr=token_attr,
486+
aggregate_attr_var=seq_attr_var,
487+
element_attr_var=token_attr_var,
413488
aggregate_descriptor="Sequence",
414489
element_descriptor="Token",
415490
)

0 commit comments

Comments
 (0)