Skip to content

Commit 784fc48

Browse files
Vivek Miglani, facebook-github-bot
authored and committed
Improve Shapley Value Perturbation Construction Performance (#1635)
Summary: Pull Request resolved: #1635 Improve perturbation generation in Shapley Value to avoid reconstructing tensors that match previous iteration Reviewed By: jjuncho Differential Revision: D80168924 fbshipit-source-id: df1cbd5423f5b64e0824e587377b2d0d29f60926
1 parent 3ec6da4 commit 784fc48

File tree

1 file changed

+86
-18
lines changed

1 file changed

+86
-18
lines changed

captum/attr/_core/shapley_value.py

Lines changed: 86 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,19 @@
55
import itertools
66
import math
77
import warnings
8-
from typing import Any, Callable, cast, Iterable, List, Optional, Sequence, Tuple, Union
8+
from collections import defaultdict
9+
from typing import (
10+
Any,
11+
Callable,
12+
cast,
13+
Dict,
14+
Iterable,
15+
List,
16+
Optional,
17+
Sequence,
18+
Tuple,
19+
Union,
20+
)
921

1022
import torch
1123
from captum._utils.common import (
@@ -766,13 +778,53 @@ def _prevResultTupleToFormattedAttr(
766778
formatted_attr = _format_output(is_inputs_tuple, attrib)
767779
return formatted_attr
768780

781+
def _update_current_tensors(
782+
self,
783+
current_tensors: Tuple[Tensor, ...],
784+
input_tensors: Tuple[Tensor, ...],
785+
feature_index: int,
786+
mask: Tuple[Tensor, ...],
787+
feat_tensor_index_map: Dict[int, List[int]],
788+
) -> Tuple[Tensor, ...]:
789+
feat_list = feat_tensor_index_map[feature_index]
790+
output_tensors = []
791+
for i in range(len(current_tensors)):
792+
if i in feat_list:
793+
output_tensors.append(
794+
current_tensors[i]
795+
* (~(mask[i] == feature_index)).to(current_tensors[i].dtype)
796+
+ input_tensors[i]
797+
* (mask[i] == feature_index).to(input_tensors[i].dtype)
798+
)
799+
800+
else:
801+
output_tensors.append(current_tensors[i])
802+
return tuple(output_tensors)
803+
804+
def _construct_selected_mask(
805+
self,
806+
feature_index: int,
807+
mask: Tuple[Tensor, ...],
808+
empty_mask: Tuple[Tensor, ...],
809+
feat_tensor_index_map: Dict[int, List[int]],
810+
device: torch.device,
811+
) -> Tuple[Tensor, ...]:
812+
feat_list = feat_tensor_index_map[feature_index]
813+
output_mask = []
814+
for i in range(len(mask)):
815+
if i in feat_list:
816+
output_mask.append((mask[i] == feature_index).to(device).unsqueeze(0))
817+
else:
818+
output_mask.append(empty_mask[i])
819+
return tuple(output_mask)
820+
769821
def _perturbation_generator(
770822
self,
771823
inputs: Tuple[Tensor, ...],
772824
additional_args: Optional[Tuple[object, ...]],
773825
target: TargetType,
774826
baselines: Tuple[Tensor, ...],
775-
input_masks: TensorOrTupleOfTensorsGeneric,
827+
input_masks: Tuple[Tensor, ...],
776828
feature_permutation: Sequence[int],
777829
perturbations_per_eval: int,
778830
) -> Iterable[Tuple[Tuple[Tensor, ...], object, TargetType, Tuple[Tensor, ...]]]:
@@ -792,29 +844,45 @@ def _perturbation_generator(
792844
if additional_args is not None
793845
else None
794846
)
847+
feat_tensor_index_map = defaultdict(list)
848+
for i in range(len(input_masks)):
849+
for elem in input_masks[i].view(-1):
850+
feat_tensor_index_map[elem.item()].append(i)
851+
empty_masks = tuple(torch.zeros_like(elem).unsqueeze(0) for elem in input_masks)
852+
795853
target_repeated = _expand_target(target, perturbations_per_eval)
796854
for i in range(len(feature_permutation)):
797-
current_tensors = tuple(
798-
current * (~(mask == feature_permutation[i])).to(current.dtype)
799-
+ input * (mask == feature_permutation[i]).to(input.dtype)
800-
for input, current, mask in zip(inputs, current_tensors, input_masks)
855+
current_tensors = self._update_current_tensors(
856+
current_tensors=current_tensors,
857+
input_tensors=inputs,
858+
feature_index=feature_permutation[i],
859+
mask=input_masks,
860+
feat_tensor_index_map=feat_tensor_index_map,
801861
)
802862
current_tensors_list.append(current_tensors)
803863
current_mask_list.append(
804-
tuple(
805-
(mask == feature_permutation[i]).to(inputs[0].device)
806-
for mask in input_masks
864+
self._construct_selected_mask(
865+
feature_index=feature_permutation[i],
866+
mask=input_masks,
867+
empty_mask=empty_masks,
868+
feat_tensor_index_map=feat_tensor_index_map,
869+
device=inputs[0].device,
807870
)
808871
)
872+
809873
if len(current_tensors_list) == perturbations_per_eval:
810-
combined_inputs = tuple(
811-
torch.cat(aligned_tensors, dim=0)
812-
for aligned_tensors in zip(*current_tensors_list)
813-
)
814-
combined_masks = tuple(
815-
torch.stack(aligned_masks, dim=0)
816-
for aligned_masks in zip(*current_mask_list)
817-
)
874+
if len(current_tensors_list) > 1:
875+
combined_inputs = tuple(
876+
torch.cat(aligned_tensors, dim=0)
877+
for aligned_tensors in zip(*current_tensors_list)
878+
)
879+
combined_masks = tuple(
880+
torch.cat(aligned_masks, dim=0)
881+
for aligned_masks in zip(*current_mask_list)
882+
)
883+
else:
884+
combined_inputs = current_tensors_list[0]
885+
combined_masks = current_mask_list[0]
818886
yield (
819887
combined_inputs,
820888
additional_args_repeated,
@@ -840,7 +908,7 @@ def _perturbation_generator(
840908
for aligned_tensors in zip(*current_tensors_list)
841909
)
842910
combined_masks = tuple(
843-
torch.stack(aligned_masks, dim=0)
911+
torch.cat(aligned_masks, dim=0)
844912
for aligned_masks in zip(*current_mask_list)
845913
)
846914
yield (

0 commit comments

Comments
 (0)