Fix "to_standard_module" for Ghost Clipping (#754)
Summary:
Pull Request resolved: #754

Issue #749

Under FGC (Fast Gradient Clipping), we use ``del p.grad_sample`` to release per-sample gradients once they have been consumed. Without the changes in this diff, calling ``to_standard_module`` afterwards raises an AttributeError, because the cleanup code tries to delete attributes that are no longer present.
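To make the failure mode concrete, here is a minimal, self-contained sketch of the pattern (plain PyTorch; the attribute name mirrors Opacus, but nothing below is the library's actual code):

```python
import torch
import torch.nn as nn

module = nn.Linear(4, 2)

# Opacus monkeypatches per-sample gradients onto parameters.
for p in module.parameters():
    p.grad_sample = torch.zeros(8, *p.shape)

# Ghost Clipping releases them as soon as they have been used:
for p in module.parameters():
    del p.grad_sample

# Unguarded cleanup (the old behavior) now fails:
try:
    for p in module.parameters():
        delattr(p, "grad_sample")
except AttributeError as err:
    print(f"cleanup without a guard raises: {err}")

# Guarded cleanup (the fix) is a safe no-op:
for p in module.parameters():
    if hasattr(p, "grad_sample"):
        delattr(p, "grad_sample")
```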

Reviewed By: iden-kalemaj

Differential Revision: D74019695
HuanyuZhang authored and facebook-github-bot committed May 6, 2025
commit c6d65701668f337aa6fe275e5a7a955b93ae9139
opacus/grad_sample/gsm_base.py: 10 additions & 8 deletions

@@ -24,7 +24,12 @@
 
 logger = logging.getLogger(__name__)
 
-OPACUS_PARAM_MONKEYPATCH_ATTRS = ["_forward_counter", "_current_grad_sample"]
+OPACUS_PARAM_MONKEYPATCH_ATTRS = [
+    "grad_sample",
+    "_forward_counter",
+    "_current_grad_sample",
+    "_norm_sample",
+]
 
 
 class AbstractGradSampleModule(nn.Module, ABC):
@@ -131,18 +136,15 @@ def to_standard_module(self) -> nn.Module:
         return self._module
 
     def _close(self):
-        self.del_grad_sample()
         self._clean_up_attributes()
 
-    def __repr__(self):
-        return f"{type(self).__name__}({self._module.__repr__()})"
-
     def _clean_up_attributes(self):
+        # Clean up attributes
         for attr in OPACUS_PARAM_MONKEYPATCH_ATTRS:
             for p in self.parameters():
                 if hasattr(p, attr):
                     delattr(p, attr)
 
+    def __repr__(self):
+        return f"{type(self).__name__}({self._module.__repr__()})"
+
     def forbid_grad_accumulation(self):
         """
         Sets a flag to detect gradient accumulation (multiple forward/backward passes
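In short, the fix consolidates every attribute Opacus monkeypatches onto parameters (now including ``grad_sample`` and ``_norm_sample``) into ``OPACUS_PARAM_MONKEYPATCH_ATTRS`` and guards each deletion with ``hasattr``, which makes the cleanup idempotent: it no longer matters whether Ghost Clipping already ran ``del p.grad_sample`` before ``to_standard_module`` was called.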
opacus/tests/grad_sample_module_test.py: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ def test_to_standard_module(self):
             self.original_model.state_dict(),
             strict=True,
         )
-        new_grad_sample_module = GradSampleModule(
+        new_grad_sample_module = self.CLS(
            copy_of_original_model, batch_first=True, loss_reduction="mean"
         )
 
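With ``self.CLS`` in place of the hard-coded ``GradSampleModule``, the test now exercises the unwrap round trip for every wrapper class, including the Ghost Clipping one. For reference, that round trip looks roughly like the sketch below; the constructor arguments mirror the test, but the snippet is illustrative and assumes ``GradSampleModuleFastGradientClipping``'s defaults for its remaining parameters.

```python
import torch.nn as nn
from opacus.grad_sample import GradSampleModuleFastGradientClipping

model = nn.Linear(4, 2)
gsm = GradSampleModuleFastGradientClipping(
    model, batch_first=True, loss_reduction="mean"
)

# ... train with ghost clipping; per-sample gradients are released
# via `del p.grad_sample` as soon as they have been used ...

# Previously this unwrap could raise AttributeError for attributes the
# training loop had already released; with the hasattr guard it is clean.
unwrapped = gsm.to_standard_module()
assert unwrapped is model
```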