Commit d4bc8c7
HuanyuZhang authored and facebook-github-bot committed
Fix gradient shape error for DPMultiheadAttention (issue 650) (#651)
Summary: When batch_first = True, the activations and partial gradients of each linear layer in DPMultiheadAttention still carry batch_size in the second dimension, which causes a gradient-shape error in optimizer.step(). Details in #650.

Reviewed By: EnayatUllah

Differential Revision: D57446245
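For context on why the layout matters: Opacus derives per-sample gradients for a linear layer from the activations and output gradients captured in hooks, and it treats dim 0 of those tensors as the per-sample dimension. The snippet below is a simplified sketch of that contraction with made-up sizes, not the exact grad-sampler code:

import torch

# Shapes a linear layer inside DPMultiheadAttention sees with batch_first=True
# once the projections consume the (N, T, E) input directly (N = batch size).
N, T, in_f, out_f = 4, 7, 16, 16
activations = torch.randn(N, T, in_f)  # saved layer input, batch in dim 0
backprops = torch.randn(N, T, out_f)   # gradient w.r.t. the layer output

# One weight gradient per sample, shape (N, out_f, in_f). If batch_size sat in
# dim 1 instead (the pre-fix layout), the leading dimension would be T and
# optimizer.step() would see gradients of the wrong shape.
per_sample_grad = torch.einsum("n...i,n...j->nij", backprops, activations)
print(per_sample_grad.shape)  # torch.Size([4, 16, 16])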
1 parent 7d65ddf commit d4bc8c7


opacus/layers/dp_multihead_attention.py

Lines changed: 14 additions & 15 deletions
@@ -203,17 +203,12 @@ def forward(
         r"""
         Using the same logic with ``nn.MultiheadAttention`` (https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html).
         """
-        if self.batch_first:
-            if key is value:
-                if query is key:
-                    query = key = value = query.transpose(1, 0)
-                else:
-                    query, key = [x.transpose(1, 0) for x in (query, key)]
-                    value = key
-            else:
-                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
 
-        tgt_len, bsz, embed_dim = query.size()
+        if not self.batch_first:
+            tgt_len, bsz, embed_dim = query.size()
+        else:
+            bsz, tgt_len, embed_dim = query.size()
+
         if embed_dim != self.embed_dim:
             raise ValueError(
                 f"query has as size of {embed_dim} while the embedding"
@@ -234,6 +229,9 @@ def forward(
 
         q = q * scaling
 
+        if self.batch_first:
+            q, k, v = [x.transpose(0, 1) for x in (q, k, v)]
+
         if attn_mask is not None:
             if attn_mask.dtype not in (
                 torch.float32,
@@ -352,13 +350,14 @@ def forward(
 
         attn_output = torch.bmm(attn_output_weights, v)
         assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
-        attn_output = (
-            attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
-        )
-        attn_output = self.out_proj(attn_output)
 
         if self.batch_first:
-            attn_output = attn_output.transpose(1, 0)
+            attn_output = attn_output.contiguous().view(bsz, tgt_len, embed_dim)
+        else:
+            attn_output = (
+                attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+            )
+        attn_output = self.out_proj(attn_output)
 
         if need_weights:
             # average attention weights over heads
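After this change, one way to sanity-check the behavior described in #650 is to wrap the layer in GradSampleModule and confirm that every per-sample gradient is indexed by the batch size in dim 0. The sizes below are illustrative and the snippet assumes an Opacus build that includes this commit; it is a sketch, not part of the patch:

import torch
from opacus.grad_sample import GradSampleModule
from opacus.layers import DPMultiheadAttention

batch_size, seq_len, embed_dim, num_heads = 4, 7, 16, 2

# batch_first=True inputs are (N, T, E); after the fix the module keeps that
# layout through its linear projections, so the grad sampler sees batch in dim 0.
attn = DPMultiheadAttention(embed_dim, num_heads, batch_first=True)
model = GradSampleModule(attn)

x = torch.randn(batch_size, seq_len, embed_dim)
attn_output, _ = model(x, x, x)
attn_output.sum().backward()

for name, p in model.named_parameters():
    gs = getattr(p, "grad_sample", None)
    if gs is not None:
        # Per issue #650, the leading dimension here was tgt_len before the fix.
        assert gs.shape[0] == batch_size, (name, gs.shape)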
