
Commit e5356e4

sarahtranfb authored and facebook-github-bot committed

Update occlusion to new method of constructing ablated batches

Differential Revision: D76483214

1 parent af89779 · commit e5356e4

4 files changed: +101 −24 lines

captum/_utils/common.py
Lines changed: 0 additions & 16 deletions

```diff
@@ -913,22 +913,6 @@ def _get_max_feature_index(feature_mask: Tuple[Tensor, ...]) -> int:
     return int(max(torch.max(mask).item() for mask in feature_mask if mask.numel()))
 
 
-def _get_feature_idx_to_tensor_idx(
-    formatted_feature_mask: Tuple[Tensor, ...],
-) -> Dict[int, List[int]]:
-    """
-    For a given tuple of tensors, return dict of tensor values to list of tensor indices
-    they appear in.
-    """
-    feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
-    for i, mask in enumerate(formatted_feature_mask):
-        for feature_idx in torch.unique(mask):
-            if feature_idx.item() not in feature_idx_to_tensor_idx:
-                feature_idx_to_tensor_idx[feature_idx.item()] = []
-            feature_idx_to_tensor_idx[feature_idx.item()].append(i)
-    return feature_idx_to_tensor_idx
-
-
 def _maybe_expand_parameters(
     perturbations_per_eval: int,
     formatted_inputs: Tuple[Tensor, ...],
```
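
The helper removed here is not dropped from the codebase; it moves into `FeatureAblation` as an overridable method (see `feature_ablation.py` below). For reference, a minimal standalone sketch of the mapping it builds, using a hypothetical two-tensor feature mask:

```python
import torch

# Hypothetical feature mask over two input tensors:
# feature 0 appears in both tensors, features 1 and 2 in one each.
mask_a = torch.tensor([[0, 0, 1]])
mask_b = torch.tensor([[0, 2]])

feature_idx_to_tensor_idx = {}
for i, mask in enumerate((mask_a, mask_b)):
    for feature_idx in torch.unique(mask):
        # Record every tensor position i in which this feature id occurs.
        feature_idx_to_tensor_idx.setdefault(int(feature_idx.item()), []).append(i)

print(feature_idx_to_tensor_idx)  # {0: [0, 1], 1: [0], 2: [1]}
```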

captum/attr/_core/feature_ablation.py
Lines changed: 22 additions & 5 deletions

```diff
@@ -24,7 +24,6 @@
     _format_additional_forward_args,
     _format_feature_mask,
     _format_output,
-    _get_feature_idx_to_tensor_idx,
     _is_tuple,
     _maybe_expand_parameters,
     _run_forward,
@@ -507,8 +506,8 @@ def _attribute_with_cross_tensor_feature_masks(
         perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
-        feature_idx_to_tensor_idx = _get_feature_idx_to_tensor_idx(
-            formatted_feature_mask
+        feature_idx_to_tensor_idx = self._get_feature_idx_to_tensor_idx(
+            formatted_feature_mask, **kwargs
         )
         all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
 
@@ -575,6 +574,7 @@ def _attribute_with_cross_tensor_feature_masks(
                     current_feature_idxs,
                     feature_idx_to_tensor_idx,
                     current_num_ablated_features,
+                    **kwargs,
                 )
             )
 
@@ -613,6 +613,21 @@ def _attribute_with_cross_tensor_feature_masks(
             )
         return total_attrib, weights
 
+    def _get_feature_idx_to_tensor_idx(
+        self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
+    ) -> Dict[int, List[int]]:
+        """
+        For a given tuple of tensors, return dict of tensor values to list of tensor indices
+        they appear in.
+        """
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx:
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+        return feature_idx_to_tensor_idx
+
     def _should_skip_inputs_and_warn(
         self,
         current_feature_idxs: List[int],
@@ -656,6 +671,7 @@ def _construct_ablated_input_across_tensors(
         feature_idxs: List[int],
         feature_idx_to_tensor_idx: Dict[int, List[int]],
         current_num_ablated_features: int,
+        **kwargs: Any,
    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
         ablated_inputs = []
         current_masks: List[Optional[Tensor]] = []
@@ -946,8 +962,8 @@ def _attribute_with_cross_tensor_feature_masks_future(
         perturbations_per_eval: int,
         **kwargs: Any,
     ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
-        feature_idx_to_tensor_idx = _get_feature_idx_to_tensor_idx(
-            formatted_feature_mask
+        feature_idx_to_tensor_idx = self._get_feature_idx_to_tensor_idx(
+            formatted_feature_mask, **kwargs
         )
         all_feature_idxs = list(feature_idx_to_tensor_idx.keys())
 
@@ -1016,6 +1032,7 @@ def _attribute_with_cross_tensor_feature_masks_future(
                     current_feature_idxs,
                     feature_idx_to_tensor_idx,
                     current_num_ablated_features,
+                    **kwargs,
                 )
             )
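```

Turning the module-level helper into an instance method that receives `**kwargs` lets subclasses derive feature-to-tensor assignments from perturbation metadata (as `Occlusion` does below) rather than from the mask values; base behavior is unchanged for plain feature ablation. A minimal usage sketch with a hypothetical toy forward and a cross-tensor mask, assuming this build exposes the `enable_cross_tensor_attribution` flag on `attribute`:

```python
import torch
from captum.attr import FeatureAblation

# Hypothetical forward: the score is the sum of both inputs per example.
def forward(x1, x2):
    return x1.sum(dim=1) + x2.sum(dim=1)

fa = FeatureAblation(forward)
x1, x2 = torch.ones(4, 3), torch.ones(4, 2)

# Features 0 and 1 each span both tensors, so each is ablated across
# tensors in a single perturbation.
mask = (torch.tensor([[0, 0, 1]]), torch.tensor([[0, 1]]))
attr = fa.attribute(
    (x1, x2),
    feature_mask=mask,
    enable_cross_tensor_attribution=True,
)
print(attr[0].shape, attr[1].shape)  # torch.Size([4, 3]) torch.Size([4, 2])
```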

captum/attr/_core/feature_permutation.py
Lines changed: 1 addition & 0 deletions

```diff
@@ -377,6 +377,7 @@ def _construct_ablated_input_across_tensors(
         feature_idxs: List[int],
         feature_idx_to_tensor_idx: Dict[int, List[int]],
         current_num_ablated_features: int,
+        **kwargs: Any,
     ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
         current_masks: List[Optional[Tensor]] = []
         tensor_idxs = {
```

captum/attr/_core/occlusion.py
Lines changed: 78 additions & 3 deletions
```diff
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -267,6 +267,7 @@ def attribute(  # type: ignore
             shift_counts=tuple(shift_counts),
             strides=strides,
             show_progress=show_progress,
+            enable_cross_tensor_attribution=True,
         )
 
     def attribute_future(self) -> None:
```
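
With this flag set, `Occlusion` always routes through the parent's cross-tensor code path, so ablated batches are built by the new `_construct_ablated_input_across_tensors` override added further down; the public API is unchanged. A minimal usage sketch with a hypothetical toy forward:

```python
import torch
from captum.attr import Occlusion

# Hypothetical forward: one scalar score per example.
def forward(x):
    return x.flatten(1).sum(dim=1)

occ = Occlusion(forward)
x = torch.rand(2, 3, 8, 8)

# Slide a 3x4x4 window in steps of 4; each window position is one feature,
# occluded by replacing it with the scalar baseline 0.
attr = occ.attribute(x, sliding_window_shapes=(3, 4, 4), strides=(3, 4, 4), baselines=0)
print(attr.shape)  # torch.Size([2, 3, 8, 8])
```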
```diff
@@ -310,6 +311,7 @@ def _construct_ablated_input(
                     kwargs["sliding_window_tensors"],
                     kwargs["strides"],
                     kwargs["shift_counts"],
+                    is_expanded_input=True,
                 )
                 for j in range(start_feature, end_feature)
             ],
@@ -327,11 +329,12 @@
 
     def _occlusion_mask(
         self,
-        expanded_input: Tensor,
+        input: Tensor,
         ablated_feature_num: int,
         sliding_window_tsr: Tensor,
         strides: Union[int, Tuple[int, ...]],
         shift_counts: Tuple[int, ...],
+        is_expanded_input: bool,
     ) -> Tensor:
         """
         This constructs the current occlusion mask, which is the appropriate
@@ -365,8 +368,9 @@ def _occlusion_mask(
             current_index.append((remaining_total % shift_count) * stride)
             remaining_total = remaining_total // shift_count
 
+        dim = 2 if is_expanded_input else 1
         remaining_padding = np.subtract(
-            expanded_input.shape[2:], np.add(current_index, sliding_window_tsr.shape)
+            input.shape[dim:], np.add(current_index, sliding_window_tsr.shape)
         )
         pad_values = [
             val for pair in zip(remaining_padding, current_index) for val in pair
```
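
The `dim` switch encodes where the feature dimensions start: the legacy path hands this method an input expanded with an extra leading dimension (hence `shape[2:]`), while the new cross-tensor path passes the input as-is (hence `shape[1:]`). A minimal sketch of the padding arithmetic for one window position along a single feature axis (toy sizes, not the Captum API):

```python
import numpy as np
import torch
import torch.nn.functional as F

# Length-6 feature axis, window of length 2, stride 2 -> 3 window positions.
feature_axis, window, stride, shift_count = (6,), (2,), 2, 3
ablated_feature_num = 2  # occlude the last window position

current_index = [(ablated_feature_num % shift_count) * stride]  # left offset
remaining_padding = np.subtract(feature_axis, np.add(current_index, window))

# Interleave (right, left) pads per dim, then reverse into F.pad's
# (left, right) order for the last dimension.
pad_values = [int(v) for pair in zip(remaining_padding, current_index) for v in pair]
pad_values.reverse()

mask = F.pad(torch.ones(window, dtype=torch.long), tuple(pad_values))
print(mask)  # tensor([0, 0, 0, 0, 1, 1]) -- ones mark the occluded window
```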
```diff
@@ -391,3 +395,74 @@ def _get_feature_counts(
     ) -> Tuple[int, ...]:
         """return the numbers of possible input features"""
         return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"])
+
+    def _get_feature_idx_to_tensor_idx(
+        self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
+    ) -> Dict[int, List[int]]:
+        feature_idx_to_tensor_idx = {}
+        curr_feature_idx = 0
+        for i, shift_count in enumerate(kwargs["shift_counts"]):
+            num_features = int(np.prod(shift_count))
+            for _ in range(num_features):
+                feature_idx_to_tensor_idx[curr_feature_idx] = [i]
+                curr_feature_idx += 1
+        return feature_idx_to_tensor_idx
+
```
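Unlike the base implementation, which reads feature ids out of the mask tensors, occlusion can enumerate them directly: window positions are numbered contiguously per input tensor, and each feature belongs to exactly one tensor. A small worked example of the loop above:

```python
import numpy as np

# Two input tensors: the first has 2x2 window positions, the second has 3.
shift_counts = [(2, 2), (3,)]

feature_idx_to_tensor_idx, curr_feature_idx = {}, 0
for i, shift_count in enumerate(shift_counts):
    for _ in range(int(np.prod(shift_count))):
        feature_idx_to_tensor_idx[curr_feature_idx] = [i]
        curr_feature_idx += 1

print(feature_idx_to_tensor_idx)
# {0: [0], 1: [0], 2: [0], 3: [0], 4: [1], 5: [1], 6: [1]}
```

The hunk continues with the new batch-construction override: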
```diff
+    def _construct_ablated_input_across_tensors(
+        self,
+        inputs: Tuple[Tensor, ...],
+        input_mask: Tuple[Tensor, ...],
+        baselines: BaselineType,
+        feature_idxs: List[int],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
+        current_num_ablated_features: int,
+        **kwargs: Any,
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
+        ablated_inputs = []
+        current_masks: List[Optional[Tensor]] = []
+        tensor_idxs = {
+            tensor_idx
+            for sublist in (
+                feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
+            )
+            for tensor_idx in sublist
+        }
+
+        for i, input_tensor in enumerate(inputs):
+            if i not in tensor_idxs:
+                ablated_inputs.append(input_tensor)
+                current_masks.append(None)
+                continue
+            tensor_mask = []
+            ablated_input = input_tensor.clone()
+            baseline = baselines[i] if isinstance(baselines, tuple) else baselines
+            for j, feature_idx in enumerate(feature_idxs):
+                original_input_size = (
+                    input_tensor.shape[0] // current_num_ablated_features
+                )
+                start_idx = j * original_input_size
+                end_idx = (j + 1) * original_input_size
+
+                no_mask = feature_idx_to_tensor_idx[feature_idx][0] != i
+                if j > 0 and no_mask:
+                    tensor_mask.append(torch.zeros_like(tensor_mask[-1]))
+                    continue
+                mask = self._occlusion_mask(
+                    ablated_input,
+                    feature_idx,
+                    kwargs["sliding_window_tensors"][i],
+                    kwargs["strides"][i],
+                    kwargs["shift_counts"][i],
+                    is_expanded_input=False,
+                )
+                if no_mask:
+                    tensor_mask.append(torch.zeros_like(mask))
+                    continue
+                tensor_mask.append(mask)
+                assert baseline is not None, "baseline must be provided"
+                ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
+                    torch.ones(1, dtype=torch.long, device=input_tensor.device) - mask
+                ) + (baseline * mask.to(input_tensor.dtype))
+            current_masks.append(torch.stack(tensor_mask, dim=0))
+            ablated_inputs.append(ablated_input)
+        return tuple(ablated_inputs), tuple(current_masks)
```
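
This new construction ablates several features in a single forward pass: the incoming batch has already been repeated once per feature (`perturbations_per_eval`), and each feature's occlusion is written into its own contiguous slice of the repeated batch. A minimal sketch of that slicing arithmetic (toy tensors, not the Captum internals):

```python
import torch

batch_size, num_ablated_features = 2, 3
x = torch.ones(batch_size, 4)

# The caller has already expanded the batch: one copy per ablated feature.
expanded = x.repeat(num_ablated_features, 1)
baseline = 0.0

for j in range(num_ablated_features):
    original_input_size = expanded.shape[0] // num_ablated_features
    start_idx, end_idx = j * original_input_size, (j + 1) * original_input_size

    mask = torch.zeros(4, dtype=torch.long)
    mask[j] = 1  # hypothetical one-element window for feature j
    expanded[start_idx:end_idx] = (
        expanded[start_idx:end_idx] * (1 - mask) + baseline * mask
    )

# Rows 0-1 occlude feature 0, rows 2-3 feature 1, rows 4-5 feature 2.
print(expanded)
```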
