 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import logging
 import unittest
 
@@ -54,8 +55,8 @@ def __init__(self):
         super(SampleModule, self).__init__()
         self.fc1 = nn.Linear(2, 2)
         self.fc3 = nn.Linear(2, 1024)
-        self.fc4 = nn.Linear(1024, 1024)
-        self.fc5 = nn.Linear(1024, 1)
+        self.fc4 = nn.Linear(1024, 10)
+        self.fc5 = nn.Linear(10, 1)
         self.layer_norm = nn.LayerNorm(2)
 
     def forward(self, x):
@@ -119,7 +120,7 @@ def setUp_data_sequantial(self, size, length, dim):
 
     @given(
         size=st.sampled_from([10]),
-        length=st.sampled_from([1]),
+        length=st.sampled_from([1, 10]),
         dim=st.sampled_from([2]),
     )
     @settings(deadline=1000000)
@@ -131,7 +132,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         self.size = size
         self.dim = dim
 
-        self.criterion = torch.nn.CrossEntropyLoss(reduction="none")
+        self.criterion = torch.nn.CrossEntropyLoss(reduction="mean")
         self.setUp_data_sequantial(self.size, self.length, self.dim)
         noise_multiplier = 0.0
         batch_size = self.size
@@ -150,19 +151,21 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
             clone_module(sample_module),
             max_grad_norm=max_grad_norm,
             use_ghost_clipping=True,
+            loss_reduction="mean",
         )
         optimizer_gc = torch.optim.SGD(self.grad_sample_module.parameters(), lr=1)
         optimizer_gc = DPOptimizerFastGradientClipping(
             optimizer_gc,
             noise_multiplier=noise_multiplier,
             max_grad_norm=max_grad_norm,
             expected_batch_size=batch_size,
+            loss_reduction="mean",
         )
 
         (input_data, target_data) = list(self.dl)[0]
         optimizer_normal.zero_grad()
         output_normal = self.model_normal(input_data)
-        loss_normal = torch.mean(self.criterion(output_normal, target_data), dim=0)
+        loss_normal = self.criterion(output_normal, target_data)
         loss_normal.backward()
         all_norms_normal = torch.stack(
             [
@@ -173,19 +176,13 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         )
         flat_norms_normal = torch.cat([p.flatten() for p in all_norms_normal])
 
-        self.grad_sample_module.enable_hooks()
-        output_gc = self.grad_sample_module(input_data)
-
-        first_loss_per_sample = self.criterion(output_gc, target_data)
-        first_loss = torch.mean(first_loss_per_sample)
-        first_loss.backward(retain_graph=True)
-
         optimizer_gc.zero_grad()
-        coeff = self.grad_sample_module.get_clipping_coef()
-        second_loss_per_sample = coeff * first_loss_per_sample
-        second_loss = torch.sum(second_loss_per_sample)
-        self.grad_sample_module.disable_hooks()
-        second_loss.backward()
+        criterion_gc = DPLossFastGradientClipping(
+            self.grad_sample_module, optimizer_gc, copy.deepcopy(self.criterion)
+        )
+        output_gc = self.grad_sample_module(input_data)
+        loss_gc = criterion_gc(output_gc, target_data)
+        loss_gc.backward()
 
         all_norms_gc = [
             param._norm_sample for param in self.grad_sample_module.parameters()
@@ -194,13 +191,13 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
 
         diff = flat_norms_normal - flat_norms_gc
 
-        logging.info(f"Diff = {diff}"),
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
 
     @given(
         size=st.sampled_from([10]),
-        length=st.sampled_from([1, 5]),
+        length=st.sampled_from([1, 10]),
         dim=st.sampled_from([2]),
     )
     @settings(deadline=1000000)
@@ -243,7 +240,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         )
 
         criterion_gc = DPLossFastGradientClipping(
-            self.grad_sample_module, optimizer_gc, self.criterion
+            self.grad_sample_module, optimizer_gc, copy.deepcopy(self.criterion)
         )
 
         (input_data, target_data) = list(self.dl)[0]
@@ -273,7 +270,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
                 for (g_gc, g_normal) in zip(flat_grads_gc, flat_grads_normal)
             ]
         )
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
 
@@ -350,7 +347,7 @@ def test_norm_calculation(self):
 
         diff = flat_norms_normal - flat_norms_gc
 
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
 
@@ -421,6 +418,6 @@ def test_gradient_calculation(self):
             ]
         )
 
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
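For context, the change above swaps the manual two-pass ghost-clipping sequence (enable_hooks, a first backward over the per-sample loss, get_clipping_coef, rescaling, disable_hooks, a second backward) for the DPLossFastGradientClipping wrapper, which drives both passes from a single backward() call. Below is a minimal sketch of that usage pattern outside the test harness; the toy model, data, and hyperparameter values are illustrative placeholders rather than values from the test, and the import paths are assumed to match the ones this test file already uses.

```python
import copy

import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping

# Toy setup (illustrative only).
model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 4))
criterion = nn.CrossEntropyLoss(reduction="mean")
max_grad_norm = 1.0
batch_size = 10

# Wrap the model so per-sample gradient norms are obtained via ghost clipping.
gs_module = GradSampleModuleFastGradientClipping(
    model,
    max_grad_norm=max_grad_norm,
    use_ghost_clipping=True,
    loss_reduction="mean",
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(gs_module.parameters(), lr=0.1),
    noise_multiplier=0.0,  # clipping only, no noise, as in the norm test
    max_grad_norm=max_grad_norm,
    expected_batch_size=batch_size,
    loss_reduction="mean",
)

# The loss wrapper hides the norm pass and the rescaled clipping pass behind
# one backward() call on the object it returns.
criterion_gc = DPLossFastGradientClipping(gs_module, optimizer, copy.deepcopy(criterion))

x = torch.randn(batch_size, 2)
y = torch.randint(0, 4, (batch_size,))

optimizer.zero_grad()
loss = criterion_gc(gs_module(x), y)
loss.backward()
optimizer.step()
```

The criterion is deep-copied, as in the updated tests, so that the wrapper's internal switching of the reduction mode does not affect the reference loss used elsewhere.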