From 6ed036c326c33b8af35b9a06f8ed23f5fad1c398 Mon Sep 17 00:00:00 2001
From: Yang Jin
Date: Mon, 7 Jul 2025 17:20:02 -0700
Subject: [PATCH] Resolve CUDA illegal memory access

Summary: Fix the CUDA illegal memory access caused by a race condition
between concurrent reads and writes to the feature index input_tensor.

Copy the index input_tensor before use for all tensors, both permuted
and non-permuted.

Differential Revision: D77902870
---
 captum/attr/_core/feature_ablation.py    | 2 +-
 captum/attr/_core/feature_permutation.py | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 06190af6e..2c2a77d8a 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -685,7 +685,7 @@ def _construct_ablated_input_across_tensors(
 
         for i, input_tensor in enumerate(inputs):
             if i not in tensor_idxs:
-                ablated_inputs.append(input_tensor)
+                ablated_inputs.append(input_tensor.clone())
                 current_masks.append(None)
                 continue
             tensor_mask = []
diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py
index b9630ab73..896de1935 100644
--- a/captum/attr/_core/feature_permutation.py
+++ b/captum/attr/_core/feature_permutation.py
@@ -20,6 +20,7 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
     while (perm == no_perm).all():
         perm = torch.randperm(n)
 
+    perm = perm.to(x.device)
     return (x[perm] * feature_mask.to(dtype=x.dtype)) + (
         x * feature_mask.bitwise_not().to(dtype=x.dtype)
     )
@@ -391,7 +392,7 @@ def _construct_ablated_input_across_tensors(
         for i, input_tensor in enumerate(inputs):
             if i not in tensor_idxs:
                 current_masks.append(None)
-                permuted_inputs.append(input_tensor)
+                permuted_inputs.append(input_tensor.clone())
                 continue
             tensor_mask = []
             permuted_input = input_tensor.clone()
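
For reference, a minimal standalone sketch of the two ideas the patch applies, kept outside the patch itself; the function names permute_on_device and copy_inputs are illustrative only and are not part of captum:

    import torch

    def permute_on_device(x: torch.Tensor) -> torch.Tensor:
        # torch.randperm builds its index tensor on the CPU by default; moving
        # the index onto x's device before indexing keeps the gather on a
        # single device, mirroring the perm.to(x.device) line added to
        # _permute_feature above.
        perm = torch.randperm(x.size(0)).to(x.device)
        return x[perm]

    def copy_inputs(inputs):
        # Clone every input before use, as the pass-through branches now do,
        # so the returned tensors never alias the caller's originals and later
        # in-place writes cannot race with concurrent reads of them.
        return [t.clone() for t in inputs]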