
Commit a83f1d2

iden-kalemaj authored and facebook-github-bot committed
Add multi_gpu test for ghost clipping (meta-pytorch#665)
Summary:
Pull Request resolved: meta-pytorch#665

Modify the existing `multigpu_gradcheck.py` test to check gradient correctness for ghost clipping in a distributed setting.

Differential Revision: D60840755
1 parent f2a591a commit a83f1d2
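
For orientation, the `run_ghost_clipping_test` function added below drives the ghost clipping step manually instead of going through `PrivacyEngine.make_private`: a first backward pass on the mean per-sample loss (hooks enabled) lets the wrapper compute per-sample clipping coefficients, and a second backward pass on the coefficient-rescaled loss (hooks disabled) accumulates the clipped gradients. The sketch below mirrors that two-pass pattern from the diff in a minimal, non-distributed form; the toy model, data, and plain SGD step are illustrative assumptions rather than code from this commit.

```python
# Minimal, non-distributed sketch of the double-backward ghost clipping step
# exercised by run_ghost_clipping_test. The toy model/data and the plain SGD
# step are assumptions for illustration; the real test wraps the model in
# DPDDP and steps a DistributedDPOptimizerFastGradientClipping instead.
import torch
import torch.nn as nn
import torch.optim as optim
from opacus.grad_sample import GradSampleModuleFastGradientClipping

model = GradSampleModuleFastGradientClipping(
    nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)),
    max_grad_norm=1.0,
    use_ghost_clipping=True,
)
optimizer = optim.SGD(model.parameters(), lr=1.0)
loss_fn = nn.CrossEntropyLoss(reduction="none")  # keep per-sample losses

x, y = torch.randn(32, 10), torch.randint(0, 5, (32,))

# Pass 1: backward on the mean loss with hooks enabled, so the wrapper can
# compute per-sample gradient norms and clipping coefficients.
model.enable_hooks()
loss_per_sample = loss_fn(model(x), y)
torch.mean(loss_per_sample).backward(retain_graph=True)
optimizer.zero_grad()

# Pass 2: rescale each per-sample loss by its clipping coefficient and
# backpropagate again with hooks disabled; .grad now holds clipped gradients.
rescaled_loss = torch.sum(model.get_coeff() * loss_per_sample)
model.disable_hooks()
rescaled_loss.backward()
model.enable_hooks()
optimizer.step()
```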

File tree: 1 file changed (+88 −28 lines)

opacus/tests/multigpu_gradcheck.py

Lines changed: 88 additions & 28 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import itertools
 import os
 import sys
 import unittest
@@ -24,11 +25,16 @@
 import torch.optim as optim
 from opacus import PrivacyEngine
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
+from opacus.grad_sample import GradSampleModuleFastGradientClipping
 from opacus.optimizers.ddp_perlayeroptimizer import (
     DistributedPerLayerOptimizer,
     SimpleDistributedPerLayerOptimizer,
 )
 from opacus.optimizers.ddpoptimizer import DistributedDPOptimizer
+from opacus.optimizers.ddpoptimizer_fast_gradient_clipping import (
+    DistributedDPOptimizerFastGradientClipping,
+)
+from opacus.utils.fast_gradient_clipping_utils import double_backward
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler
@@ -69,6 +75,45 @@ def forward(self, x):
         return self.net2(self.relu(self.net1(x)))


+def run_ghost_clipping_test(
+    model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+):
+
+    ddp_model = DPDDP(model)
+    ddp_model = GradSampleModuleFastGradientClipping(
+        ddp_model,
+        max_grad_norm=max_grad_norm,
+        use_ghost_clipping=True,
+    )
+    optimizer = DistributedDPOptimizerFastGradientClipping(
+        optimizer,
+        noise_multiplier=0,
+        max_grad_norm=max_grad_norm,
+        expected_batch_size=batch_size,
+    )
+
+    assert isinstance(optimizer, DistributedDPOptimizerFastGradientClipping)
+
+    loss_fn = nn.CrossEntropyLoss(reduction="none")
+
+    for x, y in data_loader:
+        ddp_model.enable_hooks()
+        outputs = ddp_model(x.to(rank))
+        loss_per_sample = loss_fn(outputs, y)
+        torch.mean(loss_per_sample).backward(retain_graph=True)
+        optimizer.zero_grad()
+        rescaled_loss_per_sample = ddp_model.get_coeff() * loss_per_sample
+        rescaled_loss = torch.sum(rescaled_loss_per_sample)
+        ddp_model.disable_hooks()
+        rescaled_loss.backward()
+        ddp_model.enable_hooks()
+        optimizer.step()
+        break
+
+    weight.copy_(model.net1.weight.data.cpu())
+    cleanup()
+
+
 def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     torch.manual_seed(world_size)
     batch_size = 32
@@ -79,12 +124,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
     model.net1.weight.data.zero_()
     optimizer = optim.SGD(model.parameters(), lr=1)

+    # create dataset
     labels = torch.randn(2 * batch_size, 5).to(rank)
     data = torch.randn(2 * batch_size, 10)
-
     dataset = TensorDataset(data, labels)

-    loss_fn = nn.MSELoss()
+    loss_fn = nn.CrossEntropyLoss()
+
+    max_grad_norm = 1e8
+
     if dp and clipping == "flat":
         ddp_model = DPDDP(model)
     else:
@@ -96,8 +144,15 @@ def demo_basic(rank, weight, world_size, dp, clipping, grad_sample_mode):
         dataset, num_replicas=world_size, rank=rank, shuffle=False
     )
     data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
+
+    # use a separate function for ghost clipping since the procedure has a different structure
+    if dp and clipping == "ghost":
+        run_ghost_clipping_test(
+            model, optimizer, data_loader, batch_size, max_grad_norm, weight, rank
+        )
+        return
+
     if dp:
-        max_grad_norm = 1e8
         if clipping == "per_layer":
             max_grad_norm = [max_grad_norm for _ in model.parameters()]
         ddp_model, optimizer, data_loader = privacy_engine.make_private(
@@ -141,33 +196,38 @@ def run_demo(demo_fn, weight, world_size, dp, clipping, grad_sample_mode):

 class GradientComputationTest(unittest.TestCase):
     def test_gradient_correct(self) -> None:
-        # Tests that gradient is the same with DP or with DDP
+        # Tests that gradient is the same with DP or without DDP
         n_gpus = torch.cuda.device_count()
         self.assertTrue(
             n_gpus >= 2, f"Need at least 2 gpus but was provided only {n_gpus}."
         )

-        for clipping in ["flat", "per_layer"]:
-            for grad_sample_mode in ["hooks", "ew"]:
-                weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
-
-                run_demo(
-                    demo_basic,
-                    weight_dp,
-                    2,
-                    dp=True,
-                    clipping=clipping,
-                    grad_sample_mode=grad_sample_mode,
-                )
-                run_demo(
-                    demo_basic,
-                    weight_nodp,
-                    2,
-                    dp=False,
-                    clipping=None,
-                    grad_sample_mode=None,
-                )
-
-                self.assertTrue(
-                    torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
-                )
+        clipping_grad_sample_pairs = list(
+            itertools.product(["flat", "per_layer"], ["hooks", "ew"])
+        )
+        clipping_grad_sample_pairs.append(("ghost", "ghost"))
+
+        for clipping, grad_sample_mode in clipping_grad_sample_pairs:
+
+            weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
+
+            run_demo(
+                demo_basic,
+                weight_dp,
+                2,
+                dp=True,
+                clipping=clipping,
+                grad_sample_mode=grad_sample_mode,
+            )
+            run_demo(
+                demo_basic,
+                weight_nodp,
+                2,
+                dp=False,
+                clipping=None,
+                grad_sample_mode=None,
+            )
+
+            self.assertTrue(
+                torch.allclose(weight_dp, weight_nodp, atol=1e-5, rtol=1e-3)
+            )
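
A brief note on why the final `torch.allclose(weight_dp, weight_nodp, ...)` check is expected to pass (background reasoning, not text from the commit): the DP runs use `noise_multiplier=0`, so no noise is added, and `max_grad_norm=1e8`, so the standard per-sample clipping coefficient `min(1, max_grad_norm / per_sample_norm)` is 1 for any gradient this toy model produces; DP-SGD then coincides with plain distributed SGD up to the stated tolerances. A tiny self-contained illustration of that coefficient, with made-up example norms:

```python
import torch

# Background, not part of the diff: with max_grad_norm = 1e8, the standard
# clipping coefficient min(1, max_grad_norm / per_sample_norm) is 1 for any
# realistically sized per-sample gradient, so clipping is a no-op and, with
# noise_multiplier = 0, the DP and non-DP runs should yield matching weights.
max_grad_norm = 1e8
per_sample_norms = torch.tensor([0.3, 7.5, 120.0])  # made-up example norms
coeffs = torch.clamp(max_grad_norm / per_sample_norms, max=1.0)
print(coeffs)  # tensor([1., 1., 1.])
```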
