Changes from 1 commit
GradSampleModuleNoOp
Alex Sablayrolles committed Sep 9, 2022
commit 8eec239385047efb6f6105c349618a766f43cbc0
46 changes: 46 additions & 0 deletions opacus/grad_sample/gsm_no_op.py
@@ -0,0 +1,46 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from opacus.grad_sample.gsm_base import AbstractGradSampleModule


class GradSampleModuleNoOp(AbstractGradSampleModule):
    """
    No-op implementation of AbstractGradSampleModule.

    Only wraps the module and does not attach any per-sample gradient machinery;
    its purpose is to expose the same API as the other GradSampleModule variants.
    See README.md for more details.
    """

Contributor (review comment): Is it relevant?

    def __init__(
        self,
        m: nn.Module,
        *,
        batch_first=True,
        loss_reduction="mean",
    ):
        if not batch_first:
            raise NotImplementedError

        super().__init__(
            m,
            batch_first=batch_first,
            loss_reduction=loss_reduction,
        )

    def forward(self, x: torch.Tensor, *args, **kwargs):
        return self._module.forward(x, *args, **kwargs)
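For orientation, here is a minimal sketch of how the new wrapper might be used. The toy nn.Linear model and tensor shapes are made up for illustration; the constructor arguments are the ones shown in the diff above.

import torch
import torch.nn as nn
from opacus.grad_sample.gsm_no_op import GradSampleModuleNoOp

model = nn.Linear(4, 2)  # hypothetical toy module
wrapped = GradSampleModuleNoOp(model, batch_first=True, loss_reduction="mean")

# forward() simply delegates to the wrapped module; no per-sample
# gradient hooks are attached.
out = wrapped(torch.randn(8, 4))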
3 changes: 3 additions & 0 deletions opacus/grad_sample/utils.py
@@ -20,6 +20,7 @@
from .grad_sample_module import GradSampleModule
from .gsm_base import AbstractGradSampleModule
from .gsm_exp_weights import GradSampleModuleExpandedWeights
from .gsm_no_op import GradSampleModuleNoOp


def register_grad_sampler(
@@ -69,6 +70,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
        return GradSampleModule
    elif grad_sample_mode == "ew":
        return GradSampleModuleExpandedWeights
    elif grad_sample_mode == "no_op":
        return GradSampleModuleNoOp
    else:
        raise ValueError(
            f"Unexpected grad_sample_mode: {grad_sample_mode}. "
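As a quick illustration (a sketch, not part of the diff; `model` is assumed to be any nn.Module, such as the toy one above), the new mode is selected through the same dispatcher as the existing ones:

from opacus.grad_sample.utils import get_gsm_class

cls = get_gsm_class("no_op")  # resolves to GradSampleModuleNoOp after this change
wrapped = cls(model)          # wraps `model` without attaching per-sample hooks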
2 changes: 1 addition & 1 deletion opacus/optimizers/optimizer.py
@@ -395,7 +395,7 @@ def clip_and_accumulate(self):
"""

per_param_norms = [
g.norm(2, dim=tuple(range(1, g.ndim))) for g in self.grad_samples
g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
]
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
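The rewritten norm computation flattens each per-sample gradient before taking its L2 norm, which is equivalent to the previous reduction over all non-batch dimensions. A small self-contained check (with a made-up gradient shape) illustrating that equivalence:

import torch

g = torch.randn(8, 4, 5)  # e.g. 8 per-sample gradients for a parameter of shape (4, 5)
old = g.norm(2, dim=tuple(range(1, g.ndim)))  # previous form: norm over all non-batch dims
new = g.view(len(g), -1).norm(2, dim=-1)      # new form: flatten each sample, then norm
assert torch.allclose(old, new)               # one L2 norm per sample, identical values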