55import itertools
66import math
77import warnings
8- from typing import Any , Callable , cast , Iterable , List , Optional , Sequence , Tuple , Union
8+ from collections import defaultdict
9+ from typing import (
10+ Any ,
11+ Callable ,
12+ cast ,
13+ Dict ,
14+ Iterable ,
15+ List ,
16+ Optional ,
17+ Sequence ,
18+ Tuple ,
19+ Union ,
20+ )
921
1022import torch
1123from captum ._utils .common import (
@@ -766,13 +778,53 @@ def _prevResultTupleToFormattedAttr(
766778 formatted_attr = _format_output (is_inputs_tuple , attrib )
767779 return formatted_attr
768780
def _update_current_tensors(
    self,
    current_tensors: Tuple[Tensor, ...],
    input_tensors: Tuple[Tensor, ...],
    feature_index: int,
    mask: Tuple[Tensor, ...],
    feat_tensor_index_map: Dict[int, List[int]],
) -> Tuple[Tensor, ...]:
    """Blend one feature from ``input_tensors`` into ``current_tensors``.

    Only tensors whose position is listed under ``feature_index`` in
    ``feat_tensor_index_map`` are recomputed; all others are passed
    through as the very same tensor objects.

    Args:
        current_tensors: The tensors built up so far.
        input_tensors: The tensors the selected feature is taken from.
        feature_index: The feature id to switch over to the input values.
        mask: Per-tensor feature-id masks; ``mask[i] == feature_index``
            marks the elements that belong to the feature. Each mask
            must broadcast against the tensor at the same position.
        feat_tensor_index_map: Maps a feature id to the positions of the
            tensors in which it occurs. Positions may be repeated.

    Returns:
        A tuple with the same length as ``current_tensors``.
    """
    # The map's value list may hold the same position many times, so
    # collapse it once to get constant-time membership tests below.
    touched = frozenset(feat_tensor_index_map[feature_index])

    updated = []
    for idx, current in enumerate(current_tensors):
        if idx not in touched:
            updated.append(current)
            continue

        original = input_tensors[idx]
        selected = mask[idx] == feature_index
        # Deliberately kept as a multiply-add (not torch.where) so the
        # numerical result is exactly what the arithmetic form produces.
        updated.append(
            current * (~selected).to(current.dtype)
            + original * selected.to(original.dtype)
        )
    return tuple(updated)
803+
def _construct_selected_mask(
    self,
    feature_index: int,
    mask: Tuple[Tensor, ...],
    empty_mask: Tuple[Tensor, ...],
    feat_tensor_index_map: Dict[int, List[int]],
    device: torch.device,
) -> Tuple[Tensor, ...]:
    """Build the per-tensor boolean selection masks for one feature.

    For tensors listed under ``feature_index`` in
    ``feat_tensor_index_map`` the result is ``mask[i] == feature_index``
    moved to ``device`` with a new leading dimension of size 1. For all
    other tensors the matching ``empty_mask`` placeholder is used.

    Args:
        feature_index: The feature id being selected.
        mask: Per-tensor feature-id masks.
        empty_mask: Per-tensor placeholders for tensors that do not
            contain the feature. Assumed to be all-zero and to already
            carry the leading dimension -- confirm against the caller.
        feat_tensor_index_map: Maps a feature id to the positions of the
            tensors in which it occurs. Positions may be repeated.
        device: Device the returned masks are placed on.

    Returns:
        A tuple with one boolean mask per entry of ``mask``, all on
        ``device``.
    """
    # The map's value list may hold the same position many times, so
    # collapse it once to get constant-time membership tests below.
    touched = frozenset(feat_tensor_index_map[feature_index])

    selected_masks = []
    for idx, tensor_mask in enumerate(mask):
        if idx in touched:
            selected_masks.append(
                (tensor_mask == feature_index).to(device).unsqueeze(0)
            )
        else:
            # Normalise the placeholder so both branches yield the same
            # dtype and device; otherwise concatenating masks from
            # different features mixes bool with the placeholder's dtype.
            # ``.to`` hands back the tensor itself when nothing changes.
            selected_masks.append(
                empty_mask[idx].to(device=device, dtype=torch.bool)
            )
    return tuple(selected_masks)
820+
769821 def _perturbation_generator (
770822 self ,
771823 inputs : Tuple [Tensor , ...],
772824 additional_args : Optional [Tuple [object , ...]],
773825 target : TargetType ,
774826 baselines : Tuple [Tensor , ...],
775- input_masks : TensorOrTupleOfTensorsGeneric ,
827+ input_masks : Tuple [ Tensor , ...] ,
776828 feature_permutation : Sequence [int ],
777829 perturbations_per_eval : int ,
778830 ) -> Iterable [Tuple [Tuple [Tensor , ...], object , TargetType , Tuple [Tensor , ...]]]:
@@ -792,29 +844,45 @@ def _perturbation_generator(
792844 if additional_args is not None
793845 else None
794846 )
847+ feat_tensor_index_map = defaultdict (list )
848+ for i in range (len (input_masks )):
849+ for elem in input_masks [i ].view (- 1 ):
850+ feat_tensor_index_map [elem .item ()].append (i )
851+ empty_masks = tuple (torch .zeros_like (elem ).unsqueeze (0 ) for elem in input_masks )
852+
795853 target_repeated = _expand_target (target , perturbations_per_eval )
796854 for i in range (len (feature_permutation )):
797- current_tensors = tuple (
798- current * (~ (mask == feature_permutation [i ])).to (current .dtype )
799- + input * (mask == feature_permutation [i ]).to (input .dtype )
800- for input , current , mask in zip (inputs , current_tensors , input_masks )
855+ current_tensors = self ._update_current_tensors (
856+ current_tensors = current_tensors ,
857+ input_tensors = inputs ,
858+ feature_index = feature_permutation [i ],
859+ mask = input_masks ,
860+ feat_tensor_index_map = feat_tensor_index_map ,
801861 )
802862 current_tensors_list .append (current_tensors )
803863 current_mask_list .append (
804- tuple (
805- (mask == feature_permutation [i ]).to (inputs [0 ].device )
806- for mask in input_masks
864+ self ._construct_selected_mask (
865+ feature_index = feature_permutation [i ],
866+ mask = input_masks ,
867+ empty_mask = empty_masks ,
868+ feat_tensor_index_map = feat_tensor_index_map ,
869+ device = inputs [0 ].device ,
807870 )
808871 )
872+
809873 if len (current_tensors_list ) == perturbations_per_eval :
810- combined_inputs = tuple (
811- torch .cat (aligned_tensors , dim = 0 )
812- for aligned_tensors in zip (* current_tensors_list )
813- )
814- combined_masks = tuple (
815- torch .stack (aligned_masks , dim = 0 )
816- for aligned_masks in zip (* current_mask_list )
817- )
874+ if len (current_tensors_list ) > 1 :
875+ combined_inputs = tuple (
876+ torch .cat (aligned_tensors , dim = 0 )
877+ for aligned_tensors in zip (* current_tensors_list )
878+ )
879+ combined_masks = tuple (
880+ torch .cat (aligned_masks , dim = 0 )
881+ for aligned_masks in zip (* current_mask_list )
882+ )
883+ else :
884+ combined_inputs = current_tensors_list [0 ]
885+ combined_masks = current_mask_list [0 ]
818886 yield (
819887 combined_inputs ,
820888 additional_args_repeated ,
@@ -840,7 +908,7 @@ def _perturbation_generator(
840908 for aligned_tensors in zip (* current_tensors_list )
841909 )
842910 combined_masks = tuple (
843- torch .stack (aligned_masks , dim = 0 )
911+ torch .cat (aligned_masks , dim = 0 )
844912 for aligned_masks in zip (* current_mask_list )
845913 )
846914 yield (
0 commit comments