33# pyre-strict
44
55import math
6- from typing import Any , Callable , cast , Generator , List , Optional , Tuple , TypeVar , Union
6+ from typing import (
7+ Any ,
8+ Callable ,
9+ cast ,
10+ Dict ,
11+ Generator ,
12+ List ,
13+ Optional ,
14+ Tuple ,
15+ TypeVar ,
16+ Union ,
17+ )
718
819import torch
920from captum ._utils .common import (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
465476 attrib_type : dtype ,
466477 ** kwargs : Any ,
467478 ) -> Tuple [List [Tensor ], List [Tensor ]]:
479+ feature_idx_to_tensor_idx : Dict [int , List [int ]] = {}
480+ for i , mask in enumerate (formatted_feature_mask ):
481+ for feature_idx in torch .unique (mask ):
482+ if feature_idx .item () not in feature_idx_to_tensor_idx .keys ():
483+ feature_idx_to_tensor_idx [feature_idx .item ()] = []
484+ feature_idx_to_tensor_idx [feature_idx .item ()].append (i )
485+
468486 for (
469487 current_inputs ,
470488 current_mask ,
471489 ) in self ._ablation_generator (
472490 formatted_inputs ,
473491 baselines ,
474492 formatted_feature_mask ,
493+ feature_idx_to_tensor_idx ,
475494 ** kwargs ,
476495 ):
477496 # modified_eval has (n_feature_perturbed * n_outputs) elements
@@ -511,11 +530,12 @@ def _ablation_generator(
511530 inputs : Tuple [Tensor , ...],
512531 baselines : BaselineType ,
513532 input_mask : Tuple [Tensor , ...],
533+ feature_idx_to_tensor_idx : Dict [int , List [int ]],
514534 ** kwargs : Any ,
515535 ) -> Generator [
516536 Tuple [
517537 Tuple [Tensor , ...],
518- Tuple [Tensor , ...],
538+ Tuple [Optional [ Tensor ] , ...],
519539 ],
520540 None ,
521541 None ,
@@ -531,7 +551,11 @@ def _ablation_generator(
531551 for feature_idx in unique_feature_ids :
532552 ablated_inputs , current_masks = (
533553 self ._construct_ablated_input_across_tensors (
534- inputs , input_mask , baselines , feature_idx
554+ inputs ,
555+ input_mask ,
556+ baselines ,
557+ feature_idx ,
558+ feature_idx_to_tensor_idx [feature_idx ],
535559 )
536560 )
537561 yield ablated_inputs , current_masks
@@ -542,18 +566,17 @@ def _construct_ablated_input_across_tensors(
542566 input_mask : Tuple [Tensor , ...],
543567 baselines : BaselineType ,
544568 feature_idx : int ,
545- ) -> Tuple [Tuple [Tensor , ...], Tuple [Tensor , ...]]:
569+ tensor_idxs : List [int ],
570+ ) -> Tuple [Tuple [Tensor , ...], Tuple [Optional [Tensor ], ...]]:
546571
547572 ablated_inputs = []
548573 current_masks = []
549574 for i , input_tensor in enumerate (inputs ):
550- mask = input_mask [i ]
551- tensor_mask = mask == feature_idx
552- if not tensor_mask .any ():
575+ if i not in tensor_idxs :
553576 ablated_inputs .append (input_tensor )
554- current_masks .append (torch . zeros_like ( tensor_mask ) )
577+ current_masks .append (None )
555578 continue
556- tensor_mask = tensor_mask .to (input_tensor .device ).long ()
579+ tensor_mask = ( input_mask [ i ] == feature_idx ) .to (input_tensor .device ).long ()
557580 baseline = baselines [i ] if isinstance (baselines , tuple ) else baselines
558581 if isinstance (baseline , torch .Tensor ):
559582 baseline = baseline .reshape (
@@ -1173,7 +1196,7 @@ def _process_ablated_out(
11731196 def _process_ablated_out_full (
11741197 self ,
11751198 modified_eval : Tensor ,
1176- current_mask : Tuple [Tensor , ...],
1199+ current_mask : Tuple [Optional [ Tensor ] , ...],
11771200 flattened_initial_eval : Tensor ,
11781201 inputs : TensorOrTupleOfTensorsGeneric ,
11791202 n_outputs : int ,
@@ -1195,9 +1218,10 @@ def _process_ablated_out_full(
11951218
11961219 if self .use_weights :
11971220 for weight , mask in zip (weights , current_mask ):
1198- weight += mask .float ().sum (dim = 0 )
1221+ if mask is not None :
1222+ weight += mask .float ().sum (dim = 0 )
11991223 for i , mask in enumerate (current_mask ):
1200- if inputs [i ].numel () == 0 :
1224+ if mask is None or inputs [i ].numel () == 0 :
12011225 continue
12021226 eval_diff = eval_diff .reshape (
12031227 eval_diff_shape + (inputs [i ].dim () - 1 ) * (1 ,)
0 commit comments