 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import os
 import sys
 import unittest
 import torch.optim as optim
 from opacus import PrivacyEngine
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
+from opacus.grad_sample import GradSampleModuleFastGradientClipping
 from opacus.optimizers.ddp_perlayeroptimizer import (
     DistributedPerLayerOptimizer,
     SimpleDistributedPerLayerOptimizer,
 )
 from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
+from opacus.utils.fast_gradient_clipping_utils import double_backward
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -69,6 +75,39 @@ def forward(self, x):
         return self.net2(self.relu(self.net1(x)))
 
 
+def run_ghost_clipping_test(
+    model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+):
+
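+    # Wrap the model in DPDDP for cross-worker gradient synchronization, then in
+    # GradSampleModuleFastGradientClipping with ghost clipping enabled.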
+    ddp_model = DPDDP(model)
+    ddp_model = GradSampleModuleFastGradientClipping(
+        ddp_model,
+        max_grad_norm=max_grad_norm,
+        use_ghost_clipping=True,
+    )
+    optimizer = DistributedDPOptimizerFastGradientClipping(
+        optimizer,
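+        # no noise, so the DP gradients can be compared against the non-DP run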
+        noise_multiplier=0,
+        max_grad_norm=max_grad_norm,
+        expected_batch_size=batch_size,
+    )
+
+    assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
+
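+    # ghost clipping operates on per-sample losses, hence reduction="none"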
+    loss_fn = nn.CrossEntropyLoss(reduction="none")
+
+    for x, y in data_loader:
+        ddp_model.enable_hooks()
+        outputs = ddp_model(x.to(rank))
+        loss_per_sample = loss_fn(outputs, y)
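+        # double_backward performs two backward passes: the first computes per-sample
+        # gradient norms through the clipping hooks, the second accumulates the clipped gradients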
+        double_backward(ddp_model, optimizer, loss_per_sample)
+        optimizer.step()
+        break
+
+    weight.copy_(model.net1.weight.data.cpu())
+    cleanup()
+
+
 def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     torch.manual_seed(world_size)
     batch_size = 32
@@ -79,12 +118,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     model.net1.weight.data.zero_()
     optimizer = optim.SGD(model.parameters(), lr=1)
 
+    # create dataset
     labels = torch.randn(2 * batch_size, 5).to(rank)
     data = torch.randn(2 * batch_size, 10)
-
     dataset = TensorDataset(data, labels)
 
-    loss_fn = nn.MSELoss()
+    loss_fn = nn.CrossEntropyLoss()
+
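+    # a very large bound makes clipping a no-op, so DP and non-DP weights should match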
+    max_grad_norm = 1e8
+
     if dp and clipping == "flat":
         ddp_model = DPDDP(model)
     else:
@@ -96,8 +138,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
         dataset, num_replicas=world_size, rank=rank, shuffle=False
     )
     data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
+
+    # use a separate function for ghost clipping since the procedure has a different structure
+    if dp and clipping == "ghost":
+        run_ghost_clipping_test(
+            model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+        )
+        return
+
     if dp:
-        max_grad_norm = 1e8
         if clipping == "per_layer":
             max_grad_norm = [max_grad_norm for _ in model.parameters()]
         ddp_model, optimizer, data_loader = privacy_engine.make_private(
@@ -141,33 +190,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):
 
 class GradientComputationTest(unittest.TestCase):
     def test_gradient_correct(self) -> None:
-        # Tests that gradient is the same with DP or with DDP
+        # Tests that the gradient is the same with DP (DPDDP) and without DP (plain DDP)
         n_gpus = torch.cuda.device_count()
         self.assertTrue(
             n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
         )
 
-        for clipping in ["flat", "per_layer"]:
-            for grad_sample_mode in ["hooks", "ew"]:
-                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
-
-                run_demo(
-                    demo_basic,
-                    weight_dp,
-                    2,
-                    dp=True,
-                    clipping=clipping,
-                    grad_sample_mode=grad_sample_mode,
-                )
-                run_demo(
-                    demo_basic,
-                    weight_nodp,
-                    2,
-                    dp=False,
-                    clipping=None,
-                    grad_sample_mode=None,
-                )
-
-                self.assertTrue(
-                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
-                )
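+        # test every (clipping, grad_sample_mode) combination; ghost clipping uses its own mode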
+        clipping_grad_sample_pairs = list(
+            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
+        )
+        clipping_grad_sample_pairs.append(("ghost", "ghost"))
+
+        for clipping, grad_sample_mode in clipping_grad_sample_pairs:
+
+            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
+
+            run_demo(
+                demo_basic,
+                weight_dp,
+                2,
+                dp=True,
+                clipping=clipping,
+                grad_sample_mode=grad_sample_mode,
+            )
+            run_demo(
+                demo_basic,
+                weight_nodp,
+                2,
+                dp=False,
+                clipping=None,
+                grad_sample_mode=None,
+            )
+
+            self.assertTrue(
+                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
+            )