Commit 7264cd7

EnayatUllah authored and facebook-github-bot committed
Fix Fast Gradient Clipping bias gradient calculation for three dim data (#751)
Summary:
Pull Request resolved: #751

The bias grad calculation for three-dim data was incorrect. Let `G = g g^T`, where `g`, of dimension `T x d`, is the per-sample activation gradient, with `T` the number of tokens and `d` the feature dimension. The squared per-sample gradient norm with respect to the bias is `vec(G)^T vec(1)` (the sum of the entries of `G`), not the erroneous `vec(G)^T vec(G)` computed before. This diff fixes it.

Reviewed By: aparna-aketi, HuanyuZhang

Differential Revision: D70823094

fbshipit-source-id: c1fe1dd7f5834a8ad9632c172f93d9c9bececb71
1 parent 8cbf8e0 commit 7264cd7
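
As a quick sanity check on the identity in the summary (a standalone sketch using only PyTorch, not code from this commit), the corrected quantity `vec(G)^T vec(1)` equals the squared norm of the per-sample bias gradient, while the old `vec(G)^T vec(G)` does not:

# Standalone sketch: check the bias-norm identity for one per-sample
# activation gradient g of shape (T, d).
import torch

T, d = 4, 3
g = torch.randn(T, d)                    # per-sample activation gradient

# Direct computation: the per-sample bias gradient is the sum of g over tokens.
bias_grad = g.sum(dim=0)                 # shape (d,)
direct = bias_grad.pow(2).sum()          # squared norm ||sum_t g_t||^2

# Corrected formula: G = g g^T (T x T); squared norm = sum of all entries of G.
G = g @ g.T
fixed = G.sum()                          # vec(G)^T vec(1)

# Old, erroneous formula: vec(G)^T vec(G) = squared Frobenius norm of G.
old = (G * G).sum()

assert torch.allclose(direct, fixed)
print(direct.item(), fixed.item(), old.item())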

File tree

2 files changed: +22 / -26 lines

opacus/grad_sample/linear.py

Lines changed: 2 additions & 3 deletions
@@ -83,7 +83,6 @@ def compute_linear_norm_sample(
 
         ret[layer.weight] = torch.sqrt(ga)
         if layer.bias is not None and layer.bias.requires_grad:
-            ggT = torch.einsum("nik,njk->nij", backprops, backprops)
-            gg = torch.einsum("n...i,n...i->n", ggT, ggT).clamp(min=0)
-            ret[layer.bias] = torch.sqrt(gg)
+            ggT = torch.einsum("nik,njk->nij", backprops, backprops)  # batchwise g g^T
+            ret[layer.bias] = torch.sqrt(torch.einsum("n...i->n", ggT).clamp(min=0))
     return ret
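
The same check at the einsum level, batched over samples as in the fixed branch above (a standalone sketch; `backprops` here is just a random stand-in of shape (N, T, d) for the per-sample activation gradients):

# Standalone sketch: the new einsum matches per-sample bias gradient norms
# computed directly; the old one does not in general.
import torch

N, T, d = 5, 4, 3
backprops = torch.randn(N, T, d)

ggT = torch.einsum("nik,njk->nij", backprops, backprops)  # batchwise g g^T

new_norms = torch.sqrt(torch.einsum("n...i->n", ggT).clamp(min=0))
old_norms = torch.sqrt(torch.einsum("n...i,n...i->n", ggT, ggT).clamp(min=0))

# Direct reference: the per-sample bias gradient is the sum over tokens.
ref_norms = backprops.sum(dim=1).norm(dim=1)

print(torch.allclose(new_norms, ref_norms))  # True
print(torch.allclose(old_norms, ref_norms))  # False in general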

opacus/tests/grad_sample_module_fast_gradient_clipping_test.py

Lines changed: 20 additions & 23 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import logging
 import unittest

@@ -54,8 +55,8 @@ def __init__(self):
         super(SampleModule, self).__init__()
         self.fc1 = nn.Linear(2, 2)
         self.fc3 = nn.Linear(2, 1024)
-        self.fc4 = nn.Linear(1024, 1024)
-        self.fc5 = nn.Linear(1024, 1)
+        self.fc4 = nn.Linear(1024, 10)
+        self.fc5 = nn.Linear(10, 1)
         self.layer_norm = nn.LayerNorm(2)
 
     def forward(self, x):

@@ -119,7 +120,7 @@ def setUp_data_sequantial(self, size, length, dim):
 
     @given(
         size=st.sampled_from([10]),
-        length=st.sampled_from([1]),
+        length=st.sampled_from([1, 10]),
         dim=st.sampled_from([2]),
     )
     @settings(deadline=1000000)

@@ -131,7 +132,7 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         self.size = size
         self.dim = dim
 
-        self.criterion = torch.nn.CrossEntropyLoss(reduction="none")
+        self.criterion = torch.nn.CrossEntropyLoss(reduction="mean")
         self.setUp_data_sequantial(self.size, self.length, self.dim)
         noise_multiplier = 0.0
         batch_size = self.size

@@ -150,19 +151,21 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
             clone_module(sample_module),
             max_grad_norm=max_grad_norm,
             use_ghost_clipping=True,
+            loss_reduction="mean",
         )
         optimizer_gc = torch.optim.SGD(self.grad_sample_module.parameters(), lr=1)
         optimizer_gc = DPOptimizerFastGradientClipping(
             optimizer_gc,
             noise_multiplier=noise_multiplier,
             max_grad_norm=max_grad_norm,
             expected_batch_size=batch_size,
+            loss_reduction="mean",
         )
 
         (input_data, target_data) = list(self.dl)[0]
         optimizer_normal.zero_grad()
         output_normal = self.model_normal(input_data)
-        loss_normal = torch.mean(self.criterion(output_normal, target_data), dim=0)
+        loss_normal = self.criterion(output_normal, target_data)
         loss_normal.backward()
         all_norms_normal = torch.stack(
             [

@@ -173,19 +176,13 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
         )
         flat_norms_normal = torch.cat([p.flatten() for p in all_norms_normal])
 
-        self.grad_sample_module.enable_hooks()
-        output_gc = self.grad_sample_module(input_data)
-
-        first_loss_per_sample = self.criterion(output_gc, target_data)
-        first_loss = torch.mean(first_loss_per_sample)
-        first_loss.backward(retain_graph=True)
-
         optimizer_gc.zero_grad()
-        coeff = self.grad_sample_module.get_clipping_coef()
-        second_loss_per_sample = coeff * first_loss_per_sample
-        second_loss = torch.sum(second_loss_per_sample)
-        self.grad_sample_module.disable_hooks()
-        second_loss.backward()
+        criterion_gc = DPLossFastGradientClipping(
+            self.grad_sample_module, optimizer_gc, copy.deepcopy(self.criterion)
+        )
+        output_gc = self.grad_sample_module(input_data)
+        loss_gc = criterion_gc(output_gc, target_data)
+        loss_gc.backward()
 
         all_norms_gc = [
             param._norm_sample for param in self.grad_sample_module.parameters()

@@ -194,13 +191,13 @@ def test_norm_calculation_fast_gradient_clipping(self, size, length, dim):
 
         diff = flat_norms_normal - flat_norms_gc
 
-        logging.info(f"Diff = {diff}"),
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
 
     @given(
         size=st.sampled_from([10]),
-        length=st.sampled_from([1, 5]),
+        length=st.sampled_from([1, 10]),
         dim=st.sampled_from([2]),
     )
     @settings(deadline=1000000)

@@ -243,7 +240,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
         )
 
         criterion_gc = DPLossFastGradientClipping(
-            self.grad_sample_module, optimizer_gc, self.criterion
+            self.grad_sample_module, optimizer_gc, copy.deepcopy(self.criterion)
         )
 
         (input_data, target_data) = list(self.dl)[0]

@@ -273,7 +270,7 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim):
                 for (g_gc, g_normal) in zip(flat_grads_gc, flat_grads_normal)
             ]
         )
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg

@@ -350,7 +347,7 @@ def test_norm_calculation(self):
 
         diff = flat_norms_normal - flat_norms_gc
 
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg

@@ -421,6 +418,6 @@ def test_gradient_calculation(self):
             ]
         )
 
-        logging.info(f"Diff = {diff}")
+        logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
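
For context, the updated tests exercise the public fast-gradient-clipping wrappers end to end. Below is a minimal training-step sketch in that style; the import paths, model, and data are assumptions for illustration, while the wrapper names and constructor arguments mirror the diff above:

# Sketch of one ghost-clipping training step in the style of the updated test.
import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModuleFastGradientClipping  # assumed import path
from opacus.optimizers import DPOptimizerFastGradientClipping  # assumed import path
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping  # assumed import path

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))  # placeholder model
criterion = nn.CrossEntropyLoss(reduction="mean")

gs_module = GradSampleModuleFastGradientClipping(
    model, max_grad_norm=1.0, use_ghost_clipping=True, loss_reduction="mean"
)
optimizer = DPOptimizerFastGradientClipping(
    torch.optim.SGD(gs_module.parameters(), lr=0.1),
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    expected_batch_size=8,
    loss_reduction="mean",
)
criterion_dp = DPLossFastGradientClipping(gs_module, optimizer, criterion)

x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))  # placeholder batch
optimizer.zero_grad()
loss = criterion_dp(gs_module(x), y)
loss.backward()  # per-sample norm computation and clipping are handled by the wrappers
optimizer.step()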
