Skip to content

Commit 56abc96

Browse files
cyrjanofacebook-github-bot
authored andcommitted
Fixes mypy errors. (#1058)
Summary: - Fixes issues in visualization related to original_image being optional. - Ignore typing error in test. - Move mypy run for tests to the end as they are lower priority. Pull Request resolved: #1058 Test Plan: run ./scripts/run_mypy.sh Comes back clean. Reviewed By: NarineK Differential Revision: D41037090 Pulled By: cyrjano fbshipit-source-id: 896160e32e203439b024ad9122db7a991c4f4136
1 parent 5543b4a commit 56abc96

File tree

3 files changed

+15
-10
lines changed

3 files changed

+15
-10
lines changed

captum/attr/_utils/visualization.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,11 @@ def visualize_image_attr(
236236
if original_image is not None:
237237
if np.max(original_image) <= 1.0:
238238
original_image = _prepare_image(original_image * 255)
239-
else:
240-
assert (
241-
ImageVisualizationMethod[method] == ImageVisualizationMethod.heat_map
242-
), "Original Image must be provided for any visualization other than heatmap."
239+
elif ImageVisualizationMethod[method] != ImageVisualizationMethod.heat_map:
240+
raise ValueError(
241+
"Original Image must be provided for"
242+
"any visualization other than heatmap."
243+
)
243244

244245
# Remove ticks and tick labels from plot.
245246
plt_axis.xaxis.set_ticks_position("none")
@@ -251,6 +252,9 @@ def visualize_image_attr(
251252
heat_map = None
252253
# Show original image
253254
if ImageVisualizationMethod[method] == ImageVisualizationMethod.original_image:
255+
assert (
256+
original_image is not None
257+
), "Original image expected for original_image method."
254258
if len(original_image.shape) > 2 and original_image.shape[2] == 1:
255259
original_image = np.squeeze(original_image, axis=2)
256260
plt_axis.imshow(original_image)
@@ -284,6 +288,9 @@ def visualize_image_attr(
284288
ImageVisualizationMethod[method]
285289
== ImageVisualizationMethod.blended_heat_map
286290
):
291+
assert (
292+
original_image is not None
293+
), "Original Image expected for blended_heat_map method."
287294
plt_axis.imshow(np.mean(original_image, axis=2), cmap="gray")
288295
heat_map = plt_axis.imshow(
289296
norm_attr, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha_overlay
@@ -684,10 +691,7 @@ def _plot_attrs_as_axvspan(attr_vals, x_vals, ax):
684691
plt_axis[0].set_prop_cycle(cycler)
685692

686693
for chan in range(num_channels):
687-
if channel_labels is not None:
688-
label = channel_labels[chan]
689-
else:
690-
label = None
694+
label = channel_labels[chan] if channel_labels else None
691695
plt_axis[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)
692696

693697
_plot_attrs_as_axvspan(norm_attr, x_values, plt_axis[0])

scripts/run_mypy.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ mypy -p captum.metrics --ignore-missing-imports --allow-redefinition
1010
mypy -p captum.robust --ignore-missing-imports --allow-redefinition
1111
mypy -p captum.concept --ignore-missing-imports --allow-redefinition
1212
mypy -p captum.influence --ignore-missing-imports --allow-redefinition
13-
mypy -p tests --ignore-missing-imports --allow-redefinition
1413
mypy -p captum._utils --ignore-missing-imports --allow-redefinition
14+
mypy -p tests --ignore-missing-imports --allow-redefinition

tests/attr/test_gradient_shap.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ def generate_baselines_returns_array() -> ndarray:
156156
_assert_attribution_delta(self, (inputs,), (attributions,), n_samples, delta)
157157

158158
with self.assertRaises(AssertionError):
159-
attributions, delta = gradient_shap.attribute(
159+
attributions, delta = gradient_shap.attribute( # type: ignore
160160
inputs,
161+
# Intentionally passing wrong type.
161162
baselines=generate_baselines_returns_array,
162163
target=torch.tensor(1),
163164
n_samples=n_samples,

0 commit comments

Comments
 (0)