diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 06190af6e..2c2a77d8a 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -685,7 +685,7 @@ def _construct_ablated_input_across_tensors(
         for i, input_tensor in enumerate(inputs):
             if i not in tensor_idxs:
-                ablated_inputs.append(input_tensor)
+                ablated_inputs.append(input_tensor.clone())
                 current_masks.append(None)
                 continue
             tensor_mask = []
diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py
index b9630ab73..896de1935 100644
--- a/captum/attr/_core/feature_permutation.py
+++ b/captum/attr/_core/feature_permutation.py
@@ -20,6 +20,7 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
     while (perm == no_perm).all():
         perm = torch.randperm(n)
+    perm = perm.to(x.device)

     return (x[perm] * feature_mask.to(dtype=x.dtype)) + (
         x * feature_mask.bitwise_not().to(dtype=x.dtype)
     )
@@ -391,7 +392,7 @@ def _construct_ablated_input_across_tensors(
         for i, input_tensor in enumerate(inputs):
             if i not in tensor_idxs:
                 current_masks.append(None)
-                permuted_inputs.append(input_tensor)
+                permuted_inputs.append(input_tensor.clone())
                 continue
             tensor_mask = []
             permuted_input = input_tensor.clone()
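
Both .clone() hunks address the same aliasing hazard: appending the caller's tensor directly means the returned ablated/permuted tuple shares storage with the original inputs, so any later in-place write would silently mutate them. The randperm hunk fixes a device mismatch: torch.randperm allocates on the CPU by default, so for a CUDA input the subsequent x[perm] indexing mixes devices (which, depending on the PyTorch version, either raises a device-mismatch error or forces an implicit transfer). A minimal sketch of both effects, with made-up shapes, runnable on CPU:

import torch

inp = torch.arange(6.0).reshape(2, 3)

# Aliasing: without .clone(), the output list holds a reference to the
# caller's tensor, so in-place edits leak back into the original.
aliased = inp            # old behavior: shares storage with `inp`
cloned = inp.clone()     # patched behavior: independent storage
aliased[0, 0] = -1.0
print(inp[0, 0].item())  # -1.0 -- the original input was mutated
cloned[0, 1] = -1.0
print(inp[0, 1].item())  # 1.0 -- the original input is untouched

# Device mismatch: torch.randperm lands on the CPU unless told otherwise,
# so the permutation index must follow the input onto its device before
# x[perm] is evaluated, as the patched _permute_feature now does.
if torch.cuda.is_available():
    x = torch.randn(4, 3, device="cuda")
    perm = torch.randperm(x.size(0)).to(x.device)
    shuffled = x[perm]   # indexing stays on one device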