Skip to content
Open
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ commands:
sudo apt-key add /var/cuda-repo-ubuntu2004-11-4-local/7fa2af80.pub
sudo apt-get update
sudo dpkg --configure -a
sudo apt-get --yes --force-yes install cuda
sudo apt-get --yes --allow-downgrades --allow-remove-essential --allow-change-held-packages install cuda

jobs:

Expand Down
20 changes: 14 additions & 6 deletions captum/insights/attr_vis/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from io import BytesIO
from typing import Callable, List, Optional, Union

import numpy as np
from captum._utils.common import safe_div
from captum.attr._utils import visualization as viz
from captum.insights.attr_vis._utils.transforms import format_transforms
Expand Down Expand Up @@ -117,12 +118,19 @@ def visualization_type() -> str:
def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
if self.visualization_transform:
data = self.visualization_transform(data)

data_t, attribution_t = [
t.detach().squeeze().permute((1, 2, 0)).cpu().numpy()
for t in (data, attribution)
]

# [N, C, H, W] if C==3, its expected to be in RGB format
if data.shape[:-2][-1] == 3:
data_t, attribution_t = [
t.detach().squeeze().permute((1, 2, 0)).cpu().numpy()
for t in (data, attribution)
]
# [N, C, H, W] if C==1, its assumed to be a greyscale image
if data.shape[:-2][-1] == 1:
data_t, attribution_t = [
t.detach().squeeze().cpu().numpy() for t in (data, attribution)
]
data_t = np.expand_dims(data_t, axis=-1)
attribution_t = np.expand_dims(attribution_t, axis=-1)
orig_fig, _ = viz.visualize_image_attr(
attribution_t, data_t, method="original_image", use_pyplot=False
)
Expand Down