
Commit d0a98cb

EnayatUllah authored and facebook-github-bot committed
Fix Fast Gradient Clipping bias gradient calculation for three dim data (meta-pytorch#751)
Summary: The bias gradient calculation for three-dimensional data was incorrect. Let `G = g g^T`, where `g`, of dimensions `T x d`, is the per-sample activation gradient, with `T` the number of tokens and `d` the dimension. The squared per-sample gradient norm with respect to the bias is `vec(G)^T vec(1)` (the sum of all entries of `G`), not the erroneous `vec(G)^T vec(G)` computed before. This diff fixes it.

Reviewed By: HuanyuZhang

Differential Revision: D70823094
1 parent 8cbf8e0 commit d0a98cb
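A minimal numeric sketch of the summary's claim (not part of the commit; the shapes `n`, `T`, `d` are arbitrary): the corrected quantity `vec(G)^T vec(1)` equals the squared norm of the actual per-sample bias gradient `sum_t g_t`, whereas the old expression `vec(G)^T vec(G)` is the squared Frobenius norm of `G`.

```python
import torch

# Illustrative check only; not part of the commit.
n, T, d = 4, 3, 5                         # batch size, tokens, dimension (arbitrary)
backprops = torch.randn(n, T, d)          # stand-in for per-sample activation gradients g

ggT = torch.einsum("nik,njk->nij", backprops, backprops)          # batchwise G = g g^T
fixed = torch.sqrt(torch.einsum("n...i->n", ggT).clamp(min=0))    # sqrt(vec(G)^T vec(1)), the fix
old = torch.sqrt(torch.einsum("n...i,n...i->n", ggT, ggT))        # sqrt(vec(G)^T vec(G)), the bug

# Reference: norm of the true per-sample bias gradient, i.e. g summed over tokens.
reference = backprops.sum(dim=1).norm(dim=1)

print(torch.allclose(fixed, reference, atol=1e-5))  # True: the fix matches the reference
print(torch.allclose(old, reference, atol=1e-5))    # generally False: ||G||_F is a different quantity
```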

File tree: 2 files changed (+23, −26 lines)


opacus/grad_sample/linear.py

Lines changed: 2 additions & 3 deletions
@@ -83,7 +83,6 @@ def compute_linear_norm_sample(
 
         ret[layer.weight] = torch.sqrt(ga)
         if layer.bias is not None and layer.bias.requires_grad:
-            ggT = torch.einsum("nik,njk->nij", backprops, backprops)
-            gg = torch.einsum("n...i,n...i->n", ggT, ggT).clamp(min=0)
-            ret[layer.bias] = torch.sqrt(gg)
+            ggT = torch.einsum("nik,njk->nij", backprops, backprops)  # batchwise g g^T
+            ret[layer.bias] = torch.sqrt(torch.einsum("n...i->n", ggT).clamp(min=0))
     return ret
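As a cross-check of the fixed hunk above, the sketch below (illustrative only; the layer, loss, and shapes are assumptions) computes per-sample bias gradients for an `nn.Linear` on three-dimensional input with an explicit per-sample autograd loop and compares their norms against the `ggT`-based formula.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, T, d_in, d_out = 4, 6, 3, 5            # arbitrary shapes for the check
layer = nn.Linear(d_in, d_out)
x = torch.randn(n, T, d_in)

bias_norms, backprops = [], []
for i in range(n):
    layer.zero_grad()
    out = layer(x[i : i + 1])             # (1, T, d_out)
    out.retain_grad()                     # keep the gradient of the layer output (the "backprop" g)
    loss = (out ** 2).sum()               # any scalar loss works for this check
    loss.backward()
    bias_norms.append(layer.bias.grad.norm())
    backprops.append(out.grad.squeeze(0)) # g for sample i, shape (T, d_out)

bias_norms = torch.stack(bias_norms)      # per-sample ||grad_bias|| from autograd
g = torch.stack(backprops)                # (n, T, d_out)
ggT = torch.einsum("nik,njk->nij", g, g)
ghost = torch.sqrt(torch.einsum("n...i->n", ggT).clamp(min=0))

print(torch.allclose(bias_norms, ghost, atol=1e-4))  # True: the formula matches autograd
```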

opacus/tests/grad_sample_module_fast_gradient_clipping_test.py

Lines changed: 21 additions & 23 deletions
@@ -15,6 +15,7 @@
 
 import logging
 import unittest
+import copy
 
 import hypothesis.strategies as st
 import torch
@@ -27,6 +28,7 @@
 from opacus.utils.per_sample_gradients_utils import clone_module
 from torch.utils.data import DataLoader, Dataset
 
+
 from .grad_sample_module_test import GradSampleModuleTest, SampleConvNet
 
 
@@ -54,8 +56,8 @@ def __init__(self):
         super(SampleModule, self).__init__()
         self.fc1 = nn.Linear(2, 2)
         self.fc3 = nn.Linear(2, 1024)
-        self.fc4 = nn.Linear(1024, 1024)
-        self.fc5 = nn.Linear(1024, 1)
+        self.fc4 = nn.Linear(1024, 10)
+        self.fc5 = nn.Linear(10, 1)
         self.layer_norm = nn.LayerNorm(2)
 
     def forward(self, x):
@@ -119,7 +121,7 @@ def setUp_data_sequantial(self, size, length, dim):
 
     @given(
         size=st.sampled_from([10]),
-        length=st.sampled_from([1]),
+        length=st.sampled_from([1, 10]),
         dim=st.sampled_from([2]),
     )
     @settings(deadline=1000000)
@@ -131,7 +133,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         self.size = size
         self.dim = dim
 
-        self.criterion = torch.nn.CrossEntropyLoss(reduction="none")
+        self.criterion = torch.nn.CrossEntropyLoss(reduction="mean")
         self.setUp_data_sequantial(self.size, self.length, self.dim)
         noise_multiplier = 0.0
         batch_size = self.size
@@ -150,19 +152,21 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
             clone_module(sample_module),
             max_grad_norm=max_grad_norm,
             use_ghost_clipping=True,
+            loss_reduction="mean",
         )
         optimizer_gc = torch.optim.SGD(self.grad_sample_module.parameters(), lr=1)
         optimizer_gc = DPOptimizerFastGradientClipping(
             optimizer_gc,
             noise_multiplier=noise_multiplier,
             max_grad_norm=max_grad_norm,
             expected_batch_size=batch_size,
+            loss_reduction="mean",
         )
 
         (input_data, target_data) = list(self.dl)[0]
         optimizer_normal.zero_grad()
         output_normal = self.model_normal(input_data)
-        loss_normal = torch.mean(self.criterion(output_normal, target_data), dim=0)
+        loss_normal = self.criterion(output_normal, target_data)
         loss_normal.backward()
         all_norms_normal = torch.stack(
             [
@@ -173,19 +177,13 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         )
         flat_norms_normal = torch.cat([p.flatten() for p in all_norms_normal])
 
-        self.grad_sample_module.enable_hooks()
-        output_gc = self.grad_sample_module(input_data)
-
-        first_loss_per_sample = self.criterion(output_gc, target_data)
-        first_loss = torch.mean(first_loss_per_sample)
-        first_loss.backward(retain_graph=True)
-
         optimizer_gc.zero_grad()
-        coeff = self.grad_sample_module.get_clipping_coef()
-        second_loss_per_sample = coeff * first_loss_per_sample
-        second_loss = torch.sum(second_loss_per_sample)
-        self.grad_sample_module.disable_hooks()
-        second_loss.backward()
+        criterion_gc = DPLossFastGradientClipping(
+            self.grad_sample_module, optimizer_gc, copy.deepcopy(self.criterion)
+        )
+        output_gc = self.grad_sample_module(input_data)
+        loss_gc = criterion_gc(output_gc, target_data)
+        loss_gc.backward()
 
         all_norms_gc = [
             param._norm_sample for param in self.grad_sample_module.parameters()
@@ -194,13 +192,13 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
 
         diff = flat_norms_normal - flat_norms_gc
 
-        logging.info(f"Diff = {diff}"),
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
 
     @given(
         size=st.sampled_from([10]),
-        length=st.sampled_from([1, 5]),
+        length=st.sampled_from([1, 10]),
         dim=st.sampled_from([2]),
     )
     @settings(deadline=1000000)
@@ -243,7 +241,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         )
 
         criterion_gc = DPLossFastGradientClipping(
-            self.grad_sample_module, optimizer_gc, self.criterion
+            self.grad_sample_module, optimizer_gc, copy.deepcopy(self.criterion)
        )
 
         (input_data, target_data) = list(self.dl)[0]
@@ -273,7 +271,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
                 for (g_gc, g_normal) in zip(flat_grads_gc, flat_grads_normal)
             ]
         )
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
 
@@ -350,7 +348,7 @@ def test_norm_calculation(self):
 
         diff = flat_norms_normal - flat_norms_gc
 
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
 
@@ -421,6 +419,6 @@ def test_gradient_calculation(self):
             ]
         )
 
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
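The updated test replaces the hand-rolled two-pass ghost-clipping loop with the `DPLossFastGradientClipping` wrapper and a single `backward()` call, and adds `length=10` so three-dimensional (sequence) data exercises the fixed bias path. Below is a condensed sketch of that training-step pattern; the toy model, data shapes, and import paths are assumptions (paths may differ across Opacus versions), while the constructor arguments mirror those used in the test.

```python
import torch
import torch.nn as nn

# Import paths are assumptions based on recent Opacus releases; adjust as needed.
from opacus.grad_sample import GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping


class ToySeqClassifier(nn.Module):
    # Hypothetical stand-in for the test's SampleModule.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 16)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x):                 # x: (batch, tokens, 2)
        h = torch.relu(self.fc1(x))
        return self.fc2(h).mean(dim=1)    # (batch, 4) logits


batch_size, length = 10, 10               # three-dim data, as in the updated test
x = torch.randn(batch_size, length, 2)
y = torch.randint(0, 4, (batch_size,))

model = GradSampleModuleFastGradientClipping(
    ToySeqClassifier(), max_grad_norm=1.0, use_ghost_clipping=True, loss_reduction="mean"
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(model.parameters(), lr=1.0),
    noise_multiplier=0.0,
    max_grad_norm=1.0,
    expected_batch_size=batch_size,
    loss_reduction="mean",
)
criterion = DPLossFastGradientClipping(
    model, optimizer, nn.CrossEntropyLoss(reduction="mean")
)

optimizer.zero_grad()
output = model(x)
loss = criterion(output, y)   # wrapper handles per-sample norms and clipping internally
loss.backward()               # single backward call replaces the old two-pass loop
optimizer.step()
```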
