Skip to content

Commit 1f3aa0c

Browse files
craymichaelfacebook-github-bot
authored andcommitted
Attribution API refactor: Update LLM attr typing + minor naming (#1658)
Summary: Update LLM attr definition to accommodate other typing considerations. Clean up some variable names as well. Differential Revision: D84721071
1 parent 7017df8 commit 1f3aa0c

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

captum/attr/_core/llm_attr.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
*,
8686
input_values: List[TInputValue],
8787
target_names: List[str],
88-
target_values: Optional[npt.ArrayLike] = None,
88+
target_values: Optional[Union[npt.ArrayLike, List[TTargetValue]]] = None,
8989
aggregate_attr: npt.ArrayLike,
9090
element_attr: Optional[npt.ArrayLike] = None,
9191
aggregate_descriptor: str = "Aggregate",
@@ -104,11 +104,11 @@ def aggregate_attr(self) -> Tensor:
104104
return self._aggregate_attr
105105

106106
@aggregate_attr.setter
107-
def aggregate_attr(self, seq_attr: npt.ArrayLike) -> None:
108-
if isinstance(seq_attr, Tensor):
109-
self._aggregate_attr = seq_attr
107+
def aggregate_attr(self, aggregate_attr: npt.ArrayLike) -> None:
108+
if isinstance(aggregate_attr, Tensor):
109+
self._aggregate_attr = aggregate_attr
110110
else:
111-
self._aggregate_attr = torch.tensor(seq_attr)
111+
self._aggregate_attr = torch.tensor(aggregate_attr)
112112
# IDEA: in the future we might want to support higher dim seq_attr
113113
# (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
114114
assert len(self._aggregate_attr.shape) == 1, "seq_attr must be a 1D tensor"
@@ -121,13 +121,13 @@ def element_attr(self) -> Optional[Tensor]:
121121
return self._element_attr
122122

123123
@element_attr.setter
124-
def element_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
125-
if token_attr is None:
124+
def element_attr(self, element_attr: Optional[npt.ArrayLike]) -> None:
125+
if element_attr is None:
126126
self._element_attr = None
127-
elif isinstance(token_attr, Tensor):
128-
self._element_attr = token_attr
127+
elif isinstance(element_attr, Tensor):
128+
self._element_attr = element_attr
129129
else:
130-
self._element_attr = torch.tensor(token_attr)
130+
self._element_attr = torch.tensor(element_attr)
131131

132132
if self._element_attr is not None:
133133
# IDEA: in the future we might want to support higher dim seq_attr
@@ -146,7 +146,9 @@ def target_values(self) -> Optional[List[TTargetValue]]:
146146
return self._target_values
147147

148148
@target_values.setter
149-
def target_values(self, target_values: Optional[npt.ArrayLike]) -> None:
149+
def target_values(
150+
self, target_values: Optional[Union[npt.ArrayLike, List[TTargetValue]]]
151+
) -> None:
150152
if target_values is None:
151153
self._target_values = None
152154
elif isinstance(target_values, (Tensor, np.ndarray)):

0 commit comments

Comments
 (0)