1
1
#!/usr/bin/env python3
2
2
3
3
# pyre-strict
4
- from typing import Any , Callable , Optional , Tuple , Union
4
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
5
5
6
6
import numpy as np
7
7
import torch
@@ -267,6 +267,7 @@ def attribute( # type: ignore
267
267
shift_counts = tuple (shift_counts ),
268
268
strides = strides ,
269
269
show_progress = show_progress ,
270
+ enable_cross_tensor_attribution = True ,
270
271
)
271
272
272
273
def attribute_future (self ) -> None :
@@ -310,6 +311,7 @@ def _construct_ablated_input(
310
311
kwargs ["sliding_window_tensors" ],
311
312
kwargs ["strides" ],
312
313
kwargs ["shift_counts" ],
314
+ is_expanded_input = True ,
313
315
)
314
316
for j in range (start_feature , end_feature )
315
317
],
@@ -327,11 +329,12 @@ def _construct_ablated_input(
327
329
328
330
def _occlusion_mask (
329
331
self ,
330
- expanded_input : Tensor ,
332
+ input : Tensor ,
331
333
ablated_feature_num : int ,
332
334
sliding_window_tsr : Tensor ,
333
335
strides : Union [int , Tuple [int , ...]],
334
336
shift_counts : Tuple [int , ...],
337
+ is_expanded_input : bool ,
335
338
) -> Tensor :
336
339
"""
337
340
This constructs the current occlusion mask, which is the appropriate
@@ -365,8 +368,9 @@ def _occlusion_mask(
365
368
current_index .append ((remaining_total % shift_count ) * stride )
366
369
remaining_total = remaining_total // shift_count
367
370
371
+ dim = 2 if is_expanded_input else 1
368
372
remaining_padding = np .subtract (
369
- expanded_input .shape [2 :], np .add (current_index , sliding_window_tsr .shape )
373
+ input .shape [dim :], np .add (current_index , sliding_window_tsr .shape )
370
374
)
371
375
pad_values = [
372
376
val for pair in zip (remaining_padding , current_index ) for val in pair
@@ -391,3 +395,74 @@ def _get_feature_counts(
391
395
) -> Tuple [int , ...]:
392
396
"""return the numbers of possible input features"""
393
397
return tuple (np .prod (counts ).astype (int ) for counts in kwargs ["shift_counts" ])
398
+
399
+ def _get_feature_idx_to_tensor_idx (
400
+ self , formatted_feature_mask : Tuple [Tensor , ...], ** kwargs : Any
401
+ ) -> Dict [int , List [int ]]:
402
+ feature_idx_to_tensor_idx = {}
403
+ curr_feature_idx = 0
404
+ for i , shift_count in enumerate (kwargs ["shift_counts" ]):
405
+ num_features = int (np .prod (shift_count ))
406
+ for _ in range (num_features ):
407
+ feature_idx_to_tensor_idx [curr_feature_idx ] = [i ]
408
+ curr_feature_idx += 1
409
+ return feature_idx_to_tensor_idx
410
+
411
def _construct_ablated_input_across_tensors(
    self,
    inputs: Tuple[Tensor, ...],
    input_mask: Tuple[Tensor, ...],
    baselines: BaselineType,
    feature_idxs: List[int],
    feature_idx_to_tensor_idx: Dict[int, List[int]],
    current_num_ablated_features: int,
    **kwargs: Any,
) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
    """Build ablated copies of ``inputs`` for a batch of occlusion features.

    For each feature index in ``feature_idxs``, the slice of the (expanded)
    batch assigned to that feature is replaced by the baseline within the
    region selected by ``self._occlusion_mask``. Tensors that own none of
    the requested features are passed through untouched with a ``None`` mask.

    Returns a pair ``(ablated_inputs, current_masks)`` aligned with
    ``inputs``; each non-``None`` mask entry is the per-feature masks
    stacked along dim 0.

    NOTE(review): assumes ``inputs[i].shape[0]`` is the original batch size
    multiplied by ``current_num_ablated_features`` (integer division below) —
    confirm against the caller's input expansion.
    """
    ablated_inputs = []
    current_masks: List[Optional[Tensor]] = []
    # Set of tensor indices that own at least one of the requested features.
    tensor_idxs = {
        tensor_idx
        for sublist in (
            feature_idx_to_tensor_idx[feature_idx] for feature_idx in feature_idxs
        )
        for tensor_idx in sublist
    }

    for i, input_tensor in enumerate(inputs):
        if i not in tensor_idxs:
            # No requested feature touches this tensor: pass it through.
            ablated_inputs.append(input_tensor)
            current_masks.append(None)
            continue
        tensor_mask = []
        # Clone so the in-place slice assignments below never mutate the input.
        ablated_input = input_tensor.clone()
        baseline = baselines[i] if isinstance(baselines, tuple) else baselines
        for j, feature_idx in enumerate(feature_idxs):
            # Size of the original (unexpanded) batch slice for feature j.
            original_input_size = (
                input_tensor.shape[0] // current_num_ablated_features
            )
            start_idx = j * original_input_size
            end_idx = (j + 1) * original_input_size

            # Feature belongs to a different tensor: emit a zero mask so the
            # stacked mask stays aligned with feature_idxs.
            no_mask = feature_idx_to_tensor_idx[feature_idx][0] != i
            if j > 0 and no_mask:
                # A previous iteration already fixed the mask shape; reuse it.
                tensor_mask.append(torch.zeros_like(tensor_mask[-1]))
                continue
            # For j == 0 the mask is computed even when no_mask, purely to
            # establish the mask shape for the zeros_like placeholders above.
            mask = self._occlusion_mask(
                ablated_input,
                feature_idx,
                kwargs["sliding_window_tensors"][i],
                kwargs["strides"][i],
                kwargs["shift_counts"][i],
                is_expanded_input=False,
            )
            if no_mask:
                tensor_mask.append(torch.zeros_like(mask))
                continue
            tensor_mask.append(mask)
            assert baseline is not None, "baseline must be provided"
            # Blend: keep input where mask == 0, substitute baseline where
            # mask == 1, only within this feature's batch slice.
            ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
                torch.ones(1, dtype=torch.long, device=input_tensor.device) - mask
            ) + (baseline * mask.to(input_tensor.dtype))
        current_masks.append(torch.stack(tensor_mask, dim=0))
        ablated_inputs.append(ablated_input)
    return tuple(ablated_inputs), tuple(current_masks)
0 commit comments