File tree Expand file tree Collapse file tree 2 files changed +11
-9
lines changed Expand file tree Collapse file tree 2 files changed +11
-9
lines changed Original file line number Diff line number Diff line change 2424
2525logger = logging .getLogger (__name__ )
2626
27- OPACUS_PARAM_MONKEYPATCH_ATTRS = ["_forward_counter" , "_current_grad_sample" ]
27+ OPACUS_PARAM_MONKEYPATCH_ATTRS = [
28+ "grad_sample" ,
29+ "_forward_counter" ,
30+ "_current_grad_sample" ,
31+ "_norm_sample" ,
32+ ]
2833
2934
3035class AbstractGradSampleModule (nn .Module , ABC ):
@@ -131,18 +136,15 @@ def to_standard_module(self) -> nn.Module:
131136 return self ._module
132137
133138 def _close (self ):
134- self .del_grad_sample ()
135- self ._clean_up_attributes ()
136-
137- def __repr__ (self ):
138- return f"{ type (self ).__name__ } ({ self ._module .__repr__ ()} )"
139-
140- def _clean_up_attributes (self ):
139+ # Clean up attributes
141140 for attr in OPACUS_PARAM_MONKEYPATCH_ATTRS :
142141 for p in self .parameters ():
143142 if hasattr (p , attr ):
144143 delattr (p , attr )
145144
145+ def __repr__ (self ):
146+ return f"{ type (self ).__name__ } ({ self ._module .__repr__ ()} )"
147+
146148 def forbid_grad_accumulation (self ):
147149 """
148150 Sets a flag to detect gradient accumulation (multiple forward/backward passes
Original file line number Diff line number Diff line change @@ -130,7 +130,7 @@ def test_to_standard_module(self):
130130 self .original_model .state_dict (),
131131 strict = True ,
132132 )
133- new_grad_sample_module = GradSampleModule (
133+ new_grad_sample_module = self . CLS (
134134 copy_of_original_model , batch_first = True , loss_reduction = "mean"
135135 )
136136
You can’t perform that action at this time.
0 commit comments