 # See the License for the specific language governing permissions and
 # limitations under the License.

+import itertools
 import os
 import sys
 import unittest
 import torch.optim as optim
 from opacus import PrivacyEngine
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
+from opacus.grad_sample import GradSampleModuleFastGradientClipping
 from opacus.optimizers.ddp_perlayeroptimizer import (
     DistributedPerLayerOptimizer,
     SimpleDistributedPerLayerOptimizer,
 )
 from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
+from opacus.utils.fast_gradient_clipping_utils import double_backward
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -69,6 +75,45 @@ def forward(self, x):
         return self.net2(self.relu(self.net1(x)))


+def run_ghost_clipping_test(
+    model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+):
+    ddp_model = DPDDP(model)
+    ddp_model = GradSampleModuleFastGradientClipping(
+        ddp_model,
+        max_grad_norm=max_grad_norm,
+        use_ghost_clipping=True,
+    )
+    optimizer = DistributedDPOptimizerFastGradientClipping(
+        optimizer,
+        noise_multiplier=0,
+        max_grad_norm=max_grad_norm,
+        expected_batch_size=batch_size,
+    )
+
+    assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
+
+    loss_fn = nn.CrossEntropyLoss(reduction="none")
+
+    for x, y in data_loader:
+        ddp_model.enable_hooks()
+        outputs = ddp_model(x.to(rank))
+        loss_per_sample = loss_fn(outputs, y)
+        # first backward pass: ghost clipping hooks compute per-sample gradient norms
+        torch.mean(loss_per_sample).backward(retain_graph=True)
+        optimizer.zero_grad()
+        # second backward pass: rescale per-sample losses by the clipping coefficients
+        # and accumulate the clipped gradients with the hooks disabled
+        rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample
+        rescaled_loss = torch.sum(rescaled_loss_per_sample)
+        ddp_model.disable_hooks()
+        rescaled_loss.backward()
+        ddp_model.enable_hooks()
+        optimizer.step()
+        break
+
+    weight.copy_(model.net1.weight.data.cpu())
+    cleanup()
+
+
 def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     torch.manual_seed(world_size)
     batch_size = 32
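The explicit sequence in `run_ghost_clipping_test` above (first backward with hooks enabled, rescale the per-sample losses by the clipping coefficients, second backward with hooks disabled) is the double-backward pattern that the `double_backward` helper imported at the top of the file packages up. A minimal sketch of an equivalent step, reusing the names defined in that function; the helper's exact signature is an assumption here, not taken from this diff:

    # sketch only: delegate both backward passes and the rescaling to double_backward
    for x, y in data_loader:
        outputs = ddp_model(x.to(rank))
        loss_per_sample = loss_fn(outputs, y)  # per-sample losses (reduction="none")
        double_backward(ddp_model, optimizer, loss_per_sample)
        optimizer.step()
        break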
@@ -79,12 +124,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     model.net1.weight.data.zero_()
     optimizer = optim.SGD(model.parameters(), lr=1)

+    # create dataset
     labels = torch.randn(2 * batch_size, 5).to(rank)
     data = torch.randn(2 * batch_size, 10)
-
     dataset = TensorDataset(data, labels)

-    loss_fn = nn.MSELoss()
+    loss_fn = nn.CrossEntropyLoss()
+
+    max_grad_norm = 1e8
+
     if dp and clipping == "flat":
         ddp_model = DPDDP(model)
     else:
@@ -96,8 +144,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
         dataset, num_replicas=world_size, rank=rank, shuffle=False
     )
     data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
+
+    # use a separate function for ghost clipping since the procedure has a different structure
+    if dp and clipping == "ghost":
+        run_ghost_clipping_test(
+            model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+        )
+        return
+
     if dp:
-        max_grad_norm = 1e8
         if clipping == "per_layer":
             max_grad_norm = [max_grad_norm for _ in model.parameters()]
         ddp_model, optimizer, data_loader = privacy_engine.make_private(
@@ -141,33 +196,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):

 class GradientComputationTest(unittest.TestCase):
     def test_gradient_correct(self) -> None:
-        # Tests that gradient is the same with DP or with DDP
+        # Tests that the gradient is the same with and without DP
         n_gpus = torch.cuda.device_count()
         self.assertTrue(
             n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
         )

-        for clipping in ["flat", "per_layer"]:
-            for grad_sample_mode in ["hooks", "ew"]:
-                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
-
-                run_demo(
-                    demo_basic,
-                    weight_dp,
-                    2,
-                    dp=True,
-                    clipping=clipping,
-                    grad_sample_mode=grad_sample_mode,
-                )
-                run_demo(
-                    demo_basic,
-                    weight_nodp,
-                    2,
-                    dp=False,
-                    clipping=None,
-                    grad_sample_mode=None,
-                )
-
-                self.assertTrue(
-                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
-                )
+        clipping_grad_sample_pairs = list(
+            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
+        )
+        clipping_grad_sample_pairs.append(("ghost", "ghost"))
+
+        for clipping, grad_sample_mode in clipping_grad_sample_pairs:
+            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
+
+            run_demo(
+                demo_basic,
+                weight_dp,
+                2,
+                dp=True,
+                clipping=clipping,
+                grad_sample_mode=grad_sample_mode,
+            )
+            run_demo(
+                demo_basic,
+                weight_nodp,
+                2,
+                dp=False,
+                clipping=None,
+                grad_sample_mode=None,
+            )
+
+            self.assertTrue(
+                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
+            )
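The test calls `run_demo` and `cleanup`, which are defined earlier in this file and not shown in the diff. For orientation, a minimal sketch of what such helpers (plus the process-group setup they imply) usually look like in a multi-GPU torch.distributed test; the bodies below are illustrative assumptions, not the file's actual code:

    import torch.distributed as dist
    import torch.multiprocessing as mp

    def setup(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        # one process per GPU, joined into a single process group
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

    def cleanup():
        dist.destroy_process_group()

    def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):
        # mp.spawn calls demo_fn(rank, *args) in each of the world_size processes
        mp.spawn(
            demo_fn,
            args=(weight, world_size, dp, clipping, grad_sample_mode),
            nprocs=world_size,
            join=True,
        )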