Skip to content

Commit 15de443

Browse files
EnayatUllahfacebook-github-bot
authored andcommitted
Fixed Opacus's Runtime error with an empty batch (issue 612) (#631)
Summary: In case of an empty batch, in the ```clip_and_accumulate``` function, the ```per_sample_clip_factor``` variable is set to a tensor of size 0. However, the device was not specified, which throws a runtime error. Added it. Differential Revision: D53733081
1 parent 1ba4113 commit 15de443

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

opacus/optimizers/optimizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,9 @@ def clip_and_accumulate(self):
396396

397397
if len(self.grad_samples[0]) == 0:
398398
# Empty batch
399-
per_sample_clip_factor = torch.zeros((0,))
399+
per_sample_clip_factor = torch.zeros(
400+
(0,),device=self.grad_samples[0].device
401+
)
400402
else:
401403
per_param_norms = [
402404
g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples

0 commit comments

Comments
 (0)