@@ -15,6 +15,7 @@
 
 import logging
 import unittest
+import copy
 
 import hypothesis.strategies as st
 import torch
@@ -27,6 +28,7 @@
 from opacus.utils.per_sample_gradients_utils import clone_module
 from torch.utils.data import DataLoader, Dataset
 
+
 from .grad_sample_module_test import GradSampleModuleTest, SampleConvNet
 
 
@@ -54,8 +56,8 @@ def __init__(self):
         super(SampleModule, self).__init__()
         self.fc1 = nn.Linear(2, 2)
         self.fc3 = nn.Linear(2, 1024)
-        self.fc4 = nn.Linear(1024, 1024)
-        self.fc5 = nn.Linear(1024, 1)
+        self.fc4 = nn.Linear(1024, 10)
+        self.fc5 = nn.Linear(10, 1)
         self.layer_norm = nn.LayerNorm(2)
 
     def forward(self, x):
@@ -119,7 +121,7 @@ def setUp_data_sequantial(self, size, length, dim):
 
     @given(
         size=st.sampled_from([10]),
-        length=st.sampled_from([1]),
+        length=st.sampled_from([1, 10]),
         dim=st.sampled_from([2]),
     )
     @settings(deadline=1000000)
@@ -131,7 +133,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         self.size = size
         self.dim = dim
 
-        self.criterion = torch.nn.CrossEntropyLoss(reduction="none")
+        self.criterion = torch.nn.CrossEntropyLoss(reduction="mean")
         self.setUp_data_sequantial(self.size, self.length, self.dim)
         noise_multiplier = 0.0
         batch_size = self.size
@@ -150,19 +152,21 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
             clone_module(sample_module),
             max_grad_norm=max_grad_norm,
             use_ghost_clipping=True,
+            loss_reduction="mean",
         )
         optimizer_gc = torch.optim.SGD(self.grad_sample_module.parameters(), lr=1)
         optimizer_gc = DPOptimizerFastGradientClipping(
             optimizer_gc,
             noise_multiplier=noise_multiplier,
             max_grad_norm=max_grad_norm,
             expected_batch_size=batch_size,
+            loss_reduction="mean",
         )
 
         (input_data, target_data) = list(self.dl)[0]
         optimizer_normal.zero_grad()
         output_normal = self.model_normal(input_data)
-        loss_normal = torch.mean(self.criterion(output_normal, target_data), dim=0)
+        loss_normal = self.criterion(output_normal, target_data)
         loss_normal.backward()
         all_norms_normal = torch.stack(
             [
@@ -173,19 +177,13 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         )
         flat_norms_normal = torch.cat([p.flatten() for p in all_norms_normal])
 
-        self.grad_sample_module.enable_hooks()
-        output_gc = self.grad_sample_module(input_data)
-
-        first_loss_per_sample = self.criterion(output_gc, target_data)
-        first_loss = torch.mean(first_loss_per_sample)
-        first_loss.backward(retain_graph=True)
-
         optimizer_gc.zero_grad()
-        coeff = self.grad_sample_module.get_clipping_coef()
-        second_loss_per_sample = coeff * first_loss_per_sample
-        second_loss = torch.sum(second_loss_per_sample)
-        self.grad_sample_module.disable_hooks()
-        second_loss.backward()
+        criterion_gc = DPLossFastGradientClipping(
+            self.grad_sample_module, optimizer_gc, copy.deepcopy(self.criterion)
+        )
+        output_gc = self.grad_sample_module(input_data)
+        loss_gc = criterion_gc(output_gc, target_data)
+        loss_gc.backward()
 
         all_norms_gc = [
             param._norm_sample for param in self.grad_sample_module.parameters()
@@ -194,13 +192,13 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
 
         diff = flat_norms_normal - flat_norms_gc
 
-        logging.info(f"Diff = {diff}"),
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
 
     @given(
         size=st.sampled_from([10]),
-        length=st.sampled_from([1, 5]),
+        length=st.sampled_from([1, 10]),
         dim=st.sampled_from([2]),
     )
     @settings(deadline=1000000)
@@ -243,7 +241,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         )
 
         criterion_gc = DPLossFastGradientClipping(
-            self.grad_sample_module, optimizer_gc, self.criterion
+            self.grad_sample_module, optimizer_gc, copy.deepcopy(self.criterion)
         )
 
         (input_data, target_data) = list(self.dl)[0]
@@ -273,7 +271,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
                 for (g_gc, g_normal) in zip(flat_grads_gc, flat_grads_normal)
             ]
         )
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
 
@@ -350,7 +348,7 @@ def test_norm_calculation(self):
 
         diff = flat_norms_normal - flat_norms_gc
 
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
 
@@ -421,6 +419,6 @@ def test_gradient_calculation(self):
             ]
         )
 
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
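
Note: the edits above replace the manual two-pass ghost-clipping sequence (enable_hooks, a first backward with retain_graph=True, rescaling by get_clipping_coef, a second backward) with the DPLossFastGradientClipping wrapper. The sketch below mirrors how the updated test wires the pieces together; it is a minimal illustration, not part of the diff, and the import paths, the toy model, and the random data are assumptions.

# Minimal sketch (not part of the diff) of the ghost-clipping path the test
# exercises. Assumed for illustration: the import paths, the toy nn.Sequential
# model, and the random data; the test itself wraps SampleModule via
# clone_module and compares the result against a vanilla Opacus run.
import torch
import torch.nn as nn

from opacus.grad_sample import GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping

model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))
criterion = nn.CrossEntropyLoss(reduction="mean")

# Wrap the model so ghost clipping computes per-sample gradient norms without
# materializing per-sample gradients.
gs_module = GradSampleModuleFastGradientClipping(
    model,
    max_grad_norm=1.0,
    use_ghost_clipping=True,
    loss_reduction="mean",
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(gs_module.parameters(), lr=1.0),
    noise_multiplier=0.0,
    max_grad_norm=1.0,
    expected_batch_size=10,
    loss_reduction="mean",
)

# DPLossFastGradientClipping bundles the two backward passes the removed test
# code performed by hand: a first backward to obtain per-sample norms, a
# rescaling by the clipping coefficients, then a second backward that leaves
# the clipped gradients on the parameters.
dp_criterion = DPLossFastGradientClipping(gs_module, optimizer, criterion)

x = torch.randn(10, 2)
y = torch.randint(0, 2, (10,))

optimizer.zero_grad()
loss = dp_criterion(gs_module(x), y)
loss.backward()

# Per-sample gradient norms land on each parameter as `_norm_sample`; the test
# flattens these and compares them to norms from a vanilla DP-SGD run.
flat_norms = torch.cat([p._norm_sample.flatten() for p in gs_module.parameters()])
optimizer.step()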