Add embeddingbag grad sampler
Alex Sablayrolles committed Oct 31, 2022
commit 62a80fee308ab20d8d6bf6693b31b31e1d7dded8
22 changes: 22 additions & 0 deletions opacus/grad_sample/embedding.py
@@ -54,3 +54,25 @@ def compute_embedding_grad_sample(
    torch.backends.cudnn.deterministic = saved
    ret[layer.weight] = grad_sample
    return ret


@register_grad_sampler(nn.EmbeddingBag)
def compute_embeddingbag_gradsampler(layer, inputs, backprops):
Comment on lines +59 to +60
Contributor

Can we maybe have a test for it, like we have for other supported layers?

    index, offset = inputs
    batch_size = offset.shape[0]
    # One per-sample gradient of shape (num_embeddings, embedding_dim) per bag.
    gsm = torch.zeros(
        batch_size,
        layer.num_embeddings,
        layer.embedding_dim,
        device=layer.weight.device,
        dtype=layer.weight.dtype,
    )

    for i in range(batch_size):
        begin = offset[i]
        if i < batch_size - 1:
            end = offset[i + 1]
        else:
            # The last bag runs to the end of the flattened index tensor.
            end = index.shape[0]

        if layer.mode == "sum":
            # Accumulate rather than assign, so an index repeated within a
            # bag contributes its gradient once per occurrence.
            gsm[i].index_add_(
                0, index[begin:end], backprops[i].expand(int(end - begin), -1)
            )

    ret = {layer.weight: gsm}

    return ret
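The test the reviewer asks for could be sketched as follows (a hypothetical sketch, not part of this commit): it reproduces the sampler's sum-mode logic on a tiny `nn.EmbeddingBag` and compares the result against ground-truth per-sample gradients obtained by back-propagating each bag through the layer individually.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.EmbeddingBag(num_embeddings=5, embedding_dim=3, mode="sum")
index = torch.tensor([0, 1, 2, 3, 4, 1])  # flattened bag contents
offset = torch.tensor([0, 2, 4])  # three bags: [0,1], [2,3], [4,1]
batch_size = offset.shape[0]

# Ground truth: run each bag through the layer on its own and record
# the resulting weight gradient (loss = sum of the bag's output).
expected = []
for i in range(batch_size):
    begin = offset[i]
    end = offset[i + 1] if i < batch_size - 1 else index.shape[0]
    layer.zero_grad()
    out = layer(index[begin:end].unsqueeze(0))  # a single 2D bag
    out.sum().backward()
    expected.append(layer.weight.grad.clone())
expected = torch.stack(expected)

# The grad-sample logic from the diff, with backprops of ones
# (matching the sum-of-outputs loss used above).
backprops = torch.ones(batch_size, layer.embedding_dim)
gsm = torch.zeros(batch_size, layer.num_embeddings, layer.embedding_dim)
for i in range(batch_size):
    begin = offset[i]
    end = offset[i + 1] if i < batch_size - 1 else index.shape[0]
    gsm[i].index_add_(0, index[begin:end], backprops[i].expand(int(end - begin), -1))

assert torch.allclose(gsm, expected)
```

An Opacus test would presumably route this through the test harness used for the other supported layers rather than hand-rolling the comparison, but the check itself is the same: per-sample gradients must match per-example autograd.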