Skip to content

Commit 1cdebca

Browse files
HuanyuZhang authored and facebook-github-bot committed
Fix "to_standard_module" for Ghost Clipping (meta-pytorch#754)
Summary: Fixes a bug in removing the attributes associated with Ghost Clipping (GC). Differential Revision: D74019695
1 parent cbc12de commit 1cdebca

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

opacus/grad_sample/gsm_base.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424

2525
logger = logging.getLogger(__name__)
2626

27-
OPACUS_PARAM_MONKEYPATCH_ATTRS = ["_forward_counter", "_current_grad_sample"]
27+
OPACUS_PARAM_MONKEYPATCH_ATTRS = [
28+
"grad_sample",
29+
"_forward_counter",
30+
"_current_grad_sample",
31+
"_norm_sample",
32+
]
2833

2934

3035
class AbstractGradSampleModule(nn.Module, ABC):
@@ -131,18 +136,15 @@ def to_standard_module(self) -> nn.Module:
131136
return self._module
132137

133138
def _close(self):
134-
self.del_grad_sample()
135-
self._clean_up_attributes()
136-
137-
def __repr__(self):
138-
return f"{type(self).__name__}({self._module.__repr__()})"
139-
140-
def _clean_up_attributes(self):
139+
# Clean up attributes
141140
for attr in OPACUS_PARAM_MONKEYPATCH_ATTRS:
142141
for p in self.parameters():
143142
if hasattr(p, attr):
144143
delattr(p, attr)
145144

145+
def __repr__(self):
146+
return f"{type(self).__name__}({self._module.__repr__()})"
147+
146148
def forbid_grad_accumulation(self):
147149
"""
148150
Sets a flag to detect gradient accumulation (multiple forward/backward passes

opacus/tests/grad_sample_module_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_to_standard_module(self):
130130
self.original_model.state_dict(),
131131
strict=True,
132132
)
133-
new_grad_sample_module = GradSampleModule(
133+
new_grad_sample_module = self.CLS(
134134
copy_of_original_model, batch_first=True, loss_reduction="mean"
135135
)
136136

0 commit comments

Comments
 (0)