
Commit 4236431

cyrjano authored and meta-codesync[bot] committed
Make _should_skip_inputs_and_warn a free function (#1667)
Summary:
Pull Request resolved: #1667

This diff makes `_should_skip_inputs_and_warn` in `captum/attr/_core/feature_ablation.py` a free function instead of a method of the `FeatureAblation` class. The function checks two conditions that would cause a feature group to be skipped during attribution computation:

1. `min_examples_per_batch_grouped` is specified and any input tensor in the feature group has a batch size (0th dimension) smaller than this threshold.
2. All input tensors in the feature group are empty.

Reviewed By: sarahtranfb

Differential Revision: D87300652

fbshipit-source-id: 92b77da5e6521657897e54951a9fee910608e23d
1 parent ef19600 commit 4236431
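
For reference, here is a minimal sketch of the first skip condition, mirroring the new unit tests further down. It assumes a captum checkout that includes this commit, so the free function is importable from `captum.attr._core.feature_ablation`:

```python
# Minimal sketch of skip condition 1 (values mirror the new tests below).
import torch
from captum.attr._core.feature_ablation import _should_skip_inputs_and_warn

formatted_inputs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]]),)  # batch size 2
skip = _should_skip_inputs_and_warn(
    current_feature_idxs=[0, 1],
    feature_idx_to_tensor_idx={0: [0], 1: [0]},  # both features map to tensor 0
    formatted_inputs=formatted_inputs,
    min_examples_per_batch_grouped=3,  # 2 < 3, so the group is skipped
)
print(skip)  # True; a warning is emitted via the module's logger
```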

File tree: 2 files changed (+119, -28 lines)

captum/attr/_core/feature_ablation.py

Lines changed: 52 additions & 28 deletions
```diff
@@ -219,6 +219,52 @@ def check_output_shape_valid(
     )
 
 
+def _should_skip_inputs_and_warn(
+    current_feature_idxs: List[int],
+    feature_idx_to_tensor_idx: Dict[int, List[int]],
+    formatted_inputs: Tuple[Tensor, ...],
+    min_examples_per_batch_grouped: Optional[int] = None,
+) -> bool:
+    """
+    Determines whether a feature group should be skipped during attribution computation.
+
+    This method checks two conditions that would cause a feature group to be skipped:
+    1. If min_examples_per_batch_grouped is specified and any input tensor in the
+       feature group has a batch size (0th dimension) smaller than this threshold.
+    2. If all input tensors in the feature group are empty (contain no elements).
+
+    Returns:
+        bool: True if the feature group should be skipped, False otherwise.
+    """
+    should_skip = False
+    all_empty = True
+    tensor_idx_list = []
+    for feature_idx in current_feature_idxs:
+        tensor_idx_list += feature_idx_to_tensor_idx[feature_idx]
+    for tensor_idx in set(tensor_idx_list):
+        if all_empty and torch.numel(formatted_inputs[tensor_idx]) != 0:
+            all_empty = False
+        if min_examples_per_batch_grouped is not None and (
+            formatted_inputs[tensor_idx].shape[0] < min_examples_per_batch_grouped
+        ):
+            should_skip = True
+            break
+    if should_skip:
+        logger.warning(
+            f"Skipping feature group {current_feature_idxs} since it contains "
+            f"at least one input tensor with 0th dim less than "
+            f"{min_examples_per_batch_grouped}"
+        )
+        return True
+    if all_empty:
+        logger.info(
+            f"Skipping feature group {current_feature_idxs} since all "
+            f"input tensors are empty"
+        )
+        return True
+    return False
+
+
 class FeatureAblation(PerturbationAttribution):
     """
     A perturbation based approach to computing attribution, involving
@@ -688,34 +734,12 @@ def _should_skip_inputs_and_warn(
         feature_idx_to_tensor_idx: Dict[int, List[int]],
         formatted_inputs: Tuple[Tensor, ...],
     ) -> bool:
-        should_skip = False
-        all_empty = True
-        tensor_idx_list = []
-        for feature_idx in current_feature_idxs:
-            tensor_idx_list += feature_idx_to_tensor_idx[feature_idx]
-        for tensor_idx in set(tensor_idx_list):
-            if all_empty and torch.numel(formatted_inputs[tensor_idx]) != 0:
-                all_empty = False
-            if self._min_examples_per_batch_grouped is not None and (
-                formatted_inputs[tensor_idx].shape[0]
-                < cast(int, self._min_examples_per_batch_grouped)
-            ):
-                should_skip = True
-                break
-        if should_skip:
-            logger.warning(
-                f"Skipping feature group {current_feature_idxs} since it contains "
-                f"at least one input tensor with 0th dim less than "
-                f"{self._min_examples_per_batch_grouped}"
-            )
-            return True
-        if all_empty:
-            logger.info(
-                f"Skipping feature group {current_feature_idxs} since all "
-                f"input tensors are empty"
-            )
-            return True
-        return False
+        return _should_skip_inputs_and_warn(
+            current_feature_idxs=current_feature_idxs,
+            feature_idx_to_tensor_idx=feature_idx_to_tensor_idx,
+            formatted_inputs=formatted_inputs,
+            min_examples_per_batch_grouped=self._min_examples_per_batch_grouped,
+        )
 
     def _construct_ablated_input_across_tensors(
         self,
```
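
The second skip condition (every tensor in the group is empty) takes the `logger.info` path and does not need `min_examples_per_batch_grouped`. A minimal sketch, again mirroring the tests below and assuming a post-commit checkout:

```python
# Minimal sketch of skip condition 2: all tensors in the group are empty,
# so the function logs at info level and returns True.
import torch
from captum.attr._core.feature_ablation import _should_skip_inputs_and_warn

skip = _should_skip_inputs_and_warn(
    current_feature_idxs=[0],
    feature_idx_to_tensor_idx={0: [0]},
    formatted_inputs=(torch.tensor([]),),  # zero elements
)
print(skip)  # True; an info-level message is logged
```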

tests/attr/test_feature_ablation.py

Lines changed: 67 additions & 0 deletions
```diff
@@ -14,6 +14,7 @@
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
 from captum.attr._core.feature_ablation import (
     _parse_forward_out,
+    _should_skip_inputs_and_warn,
     check_output_shape_valid,
     FeatureAblation,
     format_result,
@@ -1086,5 +1087,71 @@ def test_invalid_batch_size_not_divisible_by_num_examples(self) -> None:
         )
 
 
+class TestShouldSkipInputsAndWarn(BaseTest):
+    def test_skip_when_batch_size_less_than_min_examples(self) -> None:
+        current_feature_idxs = [0, 1]
+        feature_idx_to_tensor_idx = {0: [0], 1: [0]}
+        formatted_inputs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]]),)
+        min_examples_per_batch_grouped = 3
+
+        with unittest.mock.patch(
+            "captum.attr._core.feature_ablation.logger"
+        ) as mock_logger:
+            result = _should_skip_inputs_and_warn(
+                current_feature_idxs,
+                feature_idx_to_tensor_idx,
+                formatted_inputs,
+                min_examples_per_batch_grouped,
+            )
+
+        self.assertTrue(result)
+        mock_logger.warning.assert_called_once()
+
+    def test_no_skip_when_batch_size_equal_to_min_examples(self) -> None:
+        current_feature_idxs = [0, 1]
+        feature_idx_to_tensor_idx = {0: [0], 1: [0]}
+        formatted_inputs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]]),)
+        min_examples_per_batch_grouped = 2
+
+        result = _should_skip_inputs_and_warn(
+            current_feature_idxs,
+            feature_idx_to_tensor_idx,
+            formatted_inputs,
+            min_examples_per_batch_grouped,
+        )
+
+        self.assertFalse(result)
+
+    def test_skip_when_all_tensors_empty(self) -> None:
+        current_feature_idxs = [0]
+        feature_idx_to_tensor_idx = {0: [0]}
+        formatted_inputs = (torch.tensor([]),)
+
+        with unittest.mock.patch(
+            "captum.attr._core.feature_ablation.logger"
+        ) as mock_logger:
+            result = _should_skip_inputs_and_warn(
+                current_feature_idxs,
+                feature_idx_to_tensor_idx,
+                formatted_inputs,
+            )
+
+        self.assertTrue(result)
+        mock_logger.info.assert_called_once()
+
+    def test_no_skip_when_tensors_not_empty(self) -> None:
+        current_feature_idxs = [0, 1]
+        feature_idx_to_tensor_idx = {0: [0], 1: [0]}
+        formatted_inputs = (torch.tensor([[1.0, 2.0], [3.0, 4.0]]),)
+
+        result = _should_skip_inputs_and_warn(
+            current_feature_idxs,
+            feature_idx_to_tensor_idx,
+            formatted_inputs,
+        )
+
+        self.assertFalse(result)
+
+
 if __name__ == "__main__":
     unittest.main()
```
