Skip to content

Commit f86ddf4

Browse files
iden-kalemajfacebook-github-bot
authored andcommitted
Separate function for preparing criterion in PrivacyEngine (#703)
Summary: Pull Request resolved: #703 Having a separate function for preparing the criterion makes it easy to build custom extensions of PrivacyEnginge for methods that require a different DPLoss class, e.g., adaptive clipping. Reviewed By: EnayatUllah Differential Revision: D67458234 fbshipit-source-id: 9fca64fcde7714708ac1cb9a35a991099606f449
1 parent 144bd2a commit f86ddf4

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

opacus/privacy_engine.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,26 @@ def _prepare_model(
212212
loss_reduction=loss_reduction,
213213
)
214214

215+
def _prepare_criterion(
216+
self,
217+
*,
218+
module: GradSampleModule,
219+
optimizer: DPOptimizer,
220+
criterion=nn.CrossEntropyLoss(),
221+
loss_reduction: str = "mean",
222+
**kwargs,
223+
) -> DPLossFastGradientClipping:
224+
"""
225+
Args:
226+
module: GradSampleModule used for training,
227+
optimizer: DPOptimizer used for training,
228+
criterion: Loss function used for training,
229+
loss_reduction: "mean" or "sum", indicates if the loss reduction (for aggregating the gradients)
230+
231+
Prepare the DP loss class, which packages the two backward passes for fast gradient clipping.
232+
"""
233+
return DPLossFastGradientClipping(module, optimizer, criterion, loss_reduction)
234+
215235
def is_compatible(
216236
self,
217237
*,
@@ -403,9 +423,14 @@ def make_private(
403423
self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate)
404424
)
405425
if grad_sample_mode == "ghost":
406-
criterion = DPLossFastGradientClipping(
407-
module, optimizer, criterion, loss_reduction
426+
criterion = self._prepare_criterion(
427+
module=module,
428+
optimizer=optimizer,
429+
criterion=criterion,
430+
loss_reduction=loss_reduction,
431+
**kwargs,
408432
)
433+
409434
return module, optimizer, criterion, data_loader
410435

411436
return module, optimizer, data_loader

0 commit comments

Comments
 (0)