Skip to content

Commit 2689f21

Browse files
cyrjano and facebook-github-bot
authored and committed
Add support for nested progressbar to tracincp even when tqdm is not available. (#1046)
Summary: This pull request adds support for nested progress bars in `SimpleProgress` (i.e. when tqdm is not installed) by leveraging Python context managers (`with` statements), which tqdm also implements. To keep `SimpleProgress` simple, the behaviour is as follows: - A new line is added for each update of the parent progress bar. - Each update of the parent progress bar is guaranteed to be refreshed. - No refresh is written when the parent progress bar ends (to avoid duplicate lines). This support is now used in the TracInCP and TracInCPFast methods. ![Screen Shot 2022-10-11 at 12 43 23 PM](https://user-images.githubusercontent.com/3238673/195717879-1ffc3e4a-a8d4-4f4f-a661-4c11fd93252c.png) Pull Request resolved: #1046 Reviewed By: aobo-y Differential Revision: D40397776 Pulled By: cyrjano fbshipit-source-id: 26316255296a50fc4c80a22658be386769877c5e
1 parent 5f878af commit 2689f21

File tree

5 files changed

+156
-61
lines changed

5 files changed

+156
-61
lines changed

captum/_utils/progress.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from time import time
66
from typing import cast, Iterable, Sized, TextIO
77

8+
from captum._utils.typing import Literal
9+
810
try:
911
from tqdm.auto import tqdm
1012
except ImportError:
@@ -40,6 +42,38 @@ def flush(self, *args, **kwargs):
4042
return self._wrapped_run(self._wrapped.flush, *args, **kwargs)
4143

4244

45+
class NullProgress:
    """No-op stand-in that exposes the progress-bar API.

    Provides the same entry points as a real progress bar (context-manager
    protocol, iteration, ``update`` and ``close``) but never writes any
    output. It lets calling code use a single code path whether or not
    progress reporting is enabled, most commonly as the disabled variant of
    the outer bar in a nested progress-bar setup.
    """

    def __init__(self, iterable: Iterable = None, *args, **kwargs):
        """Remember ``iterable``; every other argument is accepted and ignored.

        Args:
            iterable: Optional iterable whose items are passed through
                unchanged when this object is iterated. ``None`` (the
                default) means there is nothing to iterate over.
            *args: Ignored. Accepted so progress-bar options can be passed
                without special-casing this class.
            **kwargs: Ignored, for the same reason as ``*args``.
        """
        del args, kwargs
        self.iterable = iterable

    def __enter__(self):
        """Enter the context and return this object; nothing is set up."""
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
        """Leave the context; returns ``False`` so exceptions propagate."""
        return False

    def __iter__(self):
        """Yield the items of the wrapped iterable, or nothing if it is ``None``."""
        # Test for None explicitly instead of relying on truthiness: a
        # non-empty iterable may still be falsy, and some containers raise
        # when their truth value is requested. Either case would otherwise
        # drop items silently or fail.
        if self.iterable is None:
            return
        yield from self.iterable

    def update(self, amount: int = 1):
        """Accept a progress increment and do nothing with it."""
        pass

    def close(self):
        """Do nothing; there is no output to finalize."""
        pass
75+
76+
4377
class SimpleProgress:
4478
def __init__(
4579
self,
@@ -51,10 +85,13 @@ def __init__(
5185
) -> None:
5286
"""
5387
Simple progress output used when tqdm is unavailable.
54-
Same as tqdm, output to stderr channel
88+
Same as tqdm, output to stderr channel.
89+
If you want to do nested Progressbars with simple progress
90+
the parent progress bar should be used as a context
91+
(i.e. with statement) and the nested progress bar should be
92+
created inside this context.
5593
"""
5694
self.cur = 0
57-
5895
self.iterable = iterable
5996
self.total = total
6097
if total is None and hasattr(iterable, "__len__"):
@@ -69,6 +106,16 @@ def __init__(
69106
self.mininterval = mininterval
70107
self.last_print_t = 0.0
71108
self.closed = False
109+
self._is_parent = False
110+
111+
def __enter__(self):
112+
self._is_parent = True
113+
self._refresh()
114+
return self
115+
116+
def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]:
117+
self.close()
118+
return False
72119

73120
def __iter__(self):
74121
if self.closed or not self.iterable:
@@ -87,8 +134,8 @@ def _refresh(self):
87134
else:
88135
# e.g., progress: .....
89136
progress_str += "." * self.cur
90-
91-
print("\r" + progress_str, end="", file=self.file)
137+
end = "\n" if self._is_parent else ""
138+
print("\r" + progress_str, end=end, file=self.file)
92139

93140
def update(self, amount: int = 1):
94141
if self.closed:
@@ -101,7 +148,7 @@ def update(self, amount: int = 1):
101148
self.last_print_t = cur_t
102149

103150
def close(self):
104-
if not self.closed:
151+
if not self.closed and not self._is_parent:
105152
self._refresh()
106153
print(file=self.file) # end with new line
107154
self.closed = True

captum/influence/_core/tracincp.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
_compute_jacobian_wrt_params,
2424
_compute_jacobian_wrt_params_with_sample_wise_trick,
2525
)
26-
from captum._utils.progress import progress
26+
from captum._utils.progress import NullProgress, progress
2727
from captum.influence._core.influence import DataInfluence
2828
from captum.influence._utils.common import (
2929
_format_inputs_dataset,
@@ -1006,13 +1006,6 @@ def _self_influence_by_checkpoints(
10061006
# If `show_progress` is true, create an outer progress bar that keeps track of
10071007
# how many checkpoints have been processed
10081008
if show_progress:
1009-
checkpoints_progress = progress(
1010-
desc=(
1011-
f"Using {self.get_name()} to compute self "
1012-
"influence. Processing checkpoint"
1013-
),
1014-
total=len(self.checkpoints),
1015-
)
10161009
# Try to determine length of inner progress bar if possible, with a default
10171010
# of `None`.
10181011
inputs_dataset_len = None
@@ -1090,17 +1083,29 @@ def get_checkpoint_contribution(checkpoint):
10901083
# We concatenate the contributions from each batch into a single 1D tensor,
10911084
# which represents the contributions for all batches in `inputs_dataset`
10921085

1093-
if show_progress:
1094-
checkpoints_progress.update()
1095-
10961086
return torch.cat(checkpoint_contribution, dim=0)
10971087

1098-
batches_self_tracin_scores = get_checkpoint_contribution(self.checkpoints[0])
1099-
1100-
# The self influence score for all examples is the sum of contributions from
1101-
# each checkpoint
1102-
for checkpoint in self.checkpoints[1:]:
1103-
batches_self_tracin_scores += get_checkpoint_contribution(checkpoint)
1088+
if show_progress:
1089+
checkpoints_progress = progress(
1090+
desc=(
1091+
f"Using {self.get_name()} to compute self "
1092+
"influence. Processing checkpoint"
1093+
),
1094+
total=len(self.checkpoints),
1095+
mininterval=0.0,
1096+
)
1097+
else:
1098+
checkpoints_progress = NullProgress()
1099+
with checkpoints_progress:
1100+
batches_self_tracin_scores = get_checkpoint_contribution(
1101+
self.checkpoints[0]
1102+
)
1103+
checkpoints_progress.update()
1104+
# The self influence score for all examples is the sum of contributions from
1105+
# each checkpoint
1106+
for checkpoint in self.checkpoints[1:]:
1107+
batches_self_tracin_scores += get_checkpoint_contribution(checkpoint)
1108+
checkpoints_progress.update()
11041109

11051110
return batches_self_tracin_scores
11061111

captum/influence/_core/tracincp_fast_rand_proj.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
from captum._utils.common import _format_inputs, _get_module_from_name, _sort_key_list
1010
from captum._utils.gradient import _gather_distributed_tensors
11-
from captum._utils.progress import progress
11+
from captum._utils.progress import NullProgress, progress
1212

1313
from captum.influence._core.tracincp import (
1414
_influence_route_to_helpers,
@@ -556,13 +556,6 @@ def _self_influence_by_checkpoints(
556556
# If `show_progress` is true, create an outer progress bar that keeps track of
557557
# how many checkpoints have been processed
558558
if show_progress:
559-
checkpoints_progress = progress(
560-
desc=(
561-
f"Using {self.get_name()} to compute self "
562-
"influence. Processing checkpoint"
563-
),
564-
total=len(self.checkpoints),
565-
)
566559
# Try to determine length of inner progress bar if possible, with a default
567560
# of `None`.
568561
inputs_dataset_len = None
@@ -621,20 +614,31 @@ def get_checkpoint_contribution(checkpoint):
621614

622615
# We concatenate the contributions from each batch into a single 1D tensor,
623616
# which represents the contributions for all batches in `inputs_dataset`
624-
625-
if show_progress:
626-
checkpoints_progress.update()
627-
628617
return torch.cat(checkpoint_contribution, dim=0)
629618

630-
batches_self_tracin_scores = get_checkpoint_contribution(self.checkpoints[0])
631-
632-
# The self influence score for all examples is the sum of contributions from
633-
# each checkpoint
634-
for checkpoint in self.checkpoints[1:]:
635-
batches_self_tracin_scores += get_checkpoint_contribution(checkpoint)
619+
if show_progress:
620+
checkpoints_progress = progress(
621+
desc=(
622+
f"Using {self.get_name()} to compute self "
623+
"influence. Processing checkpoint"
624+
),
625+
total=len(self.checkpoints),
626+
mininterval=0.0,
627+
)
628+
else:
629+
checkpoints_progress = NullProgress()
636630

637-
return batches_self_tracin_scores
631+
with checkpoints_progress:
632+
batches_self_tracin_scores = get_checkpoint_contribution(
633+
self.checkpoints[0]
634+
)
635+
checkpoints_progress.update()
636+
# The self influence score for all examples is the sum of contributions from
637+
# each checkpoint
638+
for checkpoint in self.checkpoints[1:]:
639+
batches_self_tracin_scores += get_checkpoint_contribution(checkpoint)
640+
checkpoints_progress.update()
641+
return batches_self_tracin_scores
638642

639643
def self_influence(
640644
self,

tests/influence/_core/test_tracin_show_progress.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,15 @@ def _check_error_msg_multiplicity(
4949
output = mock_stderr.getvalue()
5050
actual_msg_multiplicity = output.count(msg)
5151
assert isinstance(actual_msg_multiplicity, int)
52-
error_msg = f"Error in progress of batches with output: {repr(output)}"
52+
error_msg = (
53+
f"Error in progress of batches with output looking for '{msg}'"
54+
f" at least {msg_multiplicity} times"
55+
f"(found {actual_msg_multiplicity}) in {repr(output)}"
56+
)
5357
if greater_than:
54-
self.assertTrue(actual_msg_multiplicity - msg_multiplicity >= 0, error_msg)
58+
self.assertGreaterEqual(
59+
actual_msg_multiplicity, msg_multiplicity, error_msg
60+
)
5561
else:
5662
self.assertEqual(
5763
actual_msg_multiplicity,
@@ -124,23 +130,6 @@ def test_tracin_show_progress(
124130
# `outer_loop_by_checkpoints` is True. In this case, we should see a
125131
# single outer progress bar over checkpoints, and for every
126132
# checkpoints, a separate progress bar over batches
127-
128-
# In this case, displaying progress involves nested progress
129-
# bars, which are not currently supported by the backup
130-
# `SimpleProgress` that is used if `tqdm` is not installed.
131-
# Therefore, we skip the test in this case.
132-
# TODO: support nested progress bars for `SimpleProgress`
133-
try:
134-
import tqdm # noqa
135-
except ModuleNotFoundError:
136-
raise unittest.SkipTest(
137-
(
138-
"Skipping self influence progress bar tests for "
139-
f"{tracin.get_name()}, because proper displaying "
140-
"requires the tqdm module, which is not installed."
141-
)
142-
)
143-
144133
tracin.self_influence(
145134
DataLoader(train_dataset, batch_size=batch_size),
146135
show_progress=True,

tests/utils/test_progress.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,61 @@
44
import unittest
55
import unittest.mock
66

7-
from captum._utils.progress import progress
7+
from captum._utils.progress import NullProgress, progress
88
from tests.helpers.basic import BaseTest
99

1010

1111
class Test(BaseTest):
12+
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
13+
def test_nullprogress(self, mock_stderr) -> None:
14+
count = 0
15+
with NullProgress(["x", "y", "z"]) as np:
16+
for _ in np:
17+
for _ in NullProgress([1, 2, 3]):
18+
count += 1
19+
20+
self.assertEqual(count, 9)
21+
output = mock_stderr.getvalue()
22+
self.assertEqual(output, "")
23+
24+
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
25+
def test_nested_progress_tqdm(self, mock_stderr) -> None:
26+
try:
27+
import tqdm # noqa: F401
28+
except ImportError:
29+
raise unittest.SkipTest("Skipping tqdm test, tqdm not available.")
30+
31+
parent_data = ["x", "y", "z"]
32+
test_data = [1, 2, 3]
33+
with progress(parent_data, desc="parent progress") as parent:
34+
for item in parent:
35+
for _ in progress(test_data, desc=f"test progress {item}"):
36+
pass
37+
output = mock_stderr.getvalue()
38+
self.assertIn("parent progress:", output)
39+
for item in parent_data:
40+
self.assertIn(f"test progress {item}:", output)
41+
42+
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
43+
def test_nested_simple_progress(self, mock_stderr) -> None:
44+
parent_data = ["x", "y", "z"]
45+
test_data = [1, 2, 3]
46+
with progress(
47+
parent_data, desc="parent progress", use_tqdm=False, mininterval=0.0
48+
) as parent:
49+
for item in parent:
50+
for _ in progress(
51+
test_data, desc=f"test progress {item}", use_tqdm=False
52+
):
53+
pass
54+
55+
output = mock_stderr.getvalue()
56+
self.assertEqual(
57+
output.count("parent progress:"), 5, "5 'parent' progress bar expected"
58+
)
59+
for item in parent_data:
60+
self.assertIn(f"test progress {item}:", output)
61+
1262
@unittest.mock.patch("sys.stderr", new_callable=io.StringIO)
1363
def test_progress_tqdm(self, mock_stderr) -> None:
1464
try:

0 commit comments

Comments
 (0)