 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+import itertools
 from opacus import PrivacyEngine
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
 from opacus.optimizers.ddp_perlayeroptimizer import (
     DistributedPerLayerOptimizer,
     SimpleDistributedPerLayerOptimizer,
 )
 from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
+from opacus.utils.fast_gradient_clipping_utils import double_backward
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -84,8 +89,10 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
 
     dataset = TensorDataset(data, labels)
 
-    loss_fn = nn.MSELoss()
-    if dp and clipping == "flat":
+    reduction = "none" if dp and clipping == "ghost" else "mean"
+    loss_fn = nn.CrossEntropyLoss(reduction=reduction)
+
+    if dp and clipping in ["flat", "ghost"]:
         ddp_model = DPDDP(model)
     else:
         ddp_model = DDP(model, device_ids=[rank])
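Ghost clipping operates on per-sample losses, which is why this hunk switches the criterion to reduction="none" for that mode: the training step needs one loss value per example so each example's contribution can be clipped individually. A small standalone sketch (not part of the PR) of the shape difference:

import torch
import torch.nn as nn

logits = torch.randn(32, 5)            # batch of 32 examples, 5 classes
targets = torch.randint(0, 5, (32,))

per_sample = nn.CrossEntropyLoss(reduction="none")(logits, targets)
averaged = nn.CrossEntropyLoss(reduction="mean")(logits, targets)

print(per_sample.shape)  # torch.Size([32]) -- one loss per example, as ghost clipping needs
print(averaged.shape)    # torch.Size([])   -- a single scalar; per-sample information is gone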
@@ -115,15 +122,24 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
             optimizer,
             (DistributedPerLayerOptimizer, SimpleDistributedPerLayerOptimizer),
         )
+    elif clipping == "ghost":
+        assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
     else:
         assert isinstance(optimizer, DistributedDPOptimizer)
 
     for x, y in data_loader:
-        outputs = ddp_model(x.to(rank))
-        loss = loss_fn(outputs, y)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        if dp and clipping == "ghost":
+            ddp_model.enable_hooks()
+            outputs = ddp_model(x.to(rank))
+            loss_per_sample = loss_fn(outputs, y)
+            double_backward(ddp_model, optimizer, loss_per_sample)
+            optimizer.step()
+        else:
+            outputs = ddp_model(x.to(rank))
+            loss = loss_fn(outputs, y)
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
         break
 
     weight.copy_(model.net1.weight.data.cpu())
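In the ghost-clipping branch, the usual zero_grad()/backward() pair is replaced by double_backward, which, in the fast gradient clipping scheme, runs two backward passes: the first computes per-sample gradient norms without materializing per-sample gradients, the second backpropagates the loss rescaled by the resulting clipping coefficients; optimizer.step() then handles noise addition and the parameter update. A reusable sketch of that step, using only the calls that appear in the diff (the helper function itself is illustrative, not an Opacus API):

from opacus.utils.fast_gradient_clipping_utils import double_backward

def ghost_clipping_step(ddp_model, optimizer, loss_fn, x, y, rank):
    # Re-enable the per-sample hooks for this step, as in the test above.
    ddp_model.enable_hooks()
    outputs = ddp_model(x.to(rank))
    # loss_fn must be constructed with reduction="none" so this is a vector
    # of per-sample losses rather than a scalar.
    loss_per_sample = loss_fn(outputs, y)
    # Two backward passes: per-sample gradient norms first, then the rescaled loss.
    double_backward(ddp_model, optimizer, loss_per_sample)
    # Noise addition and the parameter update happen inside the DP optimizer.
    optimizer.step()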
@@ -141,33 +157,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):
 
 class GradientComputationTest(unittest.TestCase):
     def test_gradient_correct(self) -> None:
-        # Tests that gradient is the same with DP or with DDP
+        # Tests that the gradient is the same with DP (DPDDP) and without DP (DDP)
         n_gpus = torch.cuda.device_count()
         self.assertTrue(
             n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
         )
 
-        for clipping in ["flat", "per_layer"]:
-            for grad_sample_mode in ["hooks", "ew"]:
-                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
-
-                run_demo(
-                    demo_basic,
-                    weight_dp,
-                    2,
-                    dp=True,
-                    clipping=clipping,
-                    grad_sample_mode=grad_sample_mode,
-                )
-                run_demo(
-                    demo_basic,
-                    weight_nodp,
-                    2,
-                    dp=False,
-                    clipping=None,
-                    grad_sample_mode=None,
-                )
-
-                self.assertTrue(
-                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
-                )
+        clipping_grad_sample_pairs = list(
+            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
+        )
+        clipping_grad_sample_pairs.append(("ghost", "ghost"))
+
+        for clipping, grad_sample_mode in clipping_grad_sample_pairs:
+
+            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
+
+            run_demo(
+                demo_basic,
+                weight_dp,
+                2,
+                dp=True,
+                clipping=clipping,
+                grad_sample_mode=grad_sample_mode,
+            )
+            run_demo(
+                demo_basic,
+                weight_nodp,
+                2,
+                dp=False,
+                clipping=None,
+                grad_sample_mode=None,
+            )
+
+            self.assertTrue(
+                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
+            )
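run_demo is defined earlier in this file and is unchanged by the PR, so it does not appear in these hunks. Assuming it follows the standard torch.multiprocessing pattern for multi-GPU tests, it presumably looks roughly like the sketch below, spawning one worker per GPU; each worker copies its trained weight into the pre-allocated tensor the test compares afterwards:

import torch.multiprocessing as mp

def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):
    # mp.spawn calls demo_fn(rank, *args) once per process, so demo_basic
    # receives its rank as the first argument.
    mp.spawn(
        demo_fn,
        args=(weight, world_size, dp, clipping, grad_sample_mode),
        nprocs=world_size,
        join=True,
    )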