
Commit 8fb03aa

Merge branch 'main' into benchmarks-ci
2 parents 82474f1 + 12cf9ed commit 8fb03aa

26 files changed (+798, -672 lines)

.circleci/config.yml

Lines changed: 1 addition & 1 deletion

@@ -204,7 +204,7 @@ commands:
           echo "Using $(python -V) ($(which python))"
           echo "Using $(pip -V) ($(which pip))"
           pip install --user datasets transformers
-          python examples/imdb.py --lr 0.02 --sigma 0.56 -c 1.0 --batch-size 32 --max-sequence-length 256 --epochs 1 --data-root runs/imdb/data --device <<parameters.device>>
+          python examples/imdb.py --lr 0.02 --sigma 1.0 -c 1.0 --batch-size 64 --max-sequence-length 256 --epochs 2 --data-root runs/imdb/data --device <<parameters.device>>
           python -c "import torch; accuracy = torch.load('run_results_imdb_classification.pt'); exit(0) if (accuracy>0.54 and accuracy<0.66) else exit(1)"
       when: always
       - store_test_results:
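For reference, the accuracy gate invoked by this step boils down to the following check (a minimal sketch; the output file name and the 0.54-0.66 band come straight from the config above, everything else is illustrative):

```
import sys

import torch

# examples/imdb.py saves its final test accuracy; the CI job fails
# if the value falls outside the expected band for the new hyperparameters.
accuracy = torch.load("run_results_imdb_classification.pt")
sys.exit(0 if 0.54 < accuracy < 0.66 else 1)
```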

CONTRIBUTING.md

Lines changed: 26 additions & 0 deletions

@@ -95,6 +95,32 @@ Run following command from `website` folder. It will build the docs and serve th
     ./scripts/build_website.sh
     ```

+ You can also run spell checks on the documentation automatically (besides IDEs) using [```sphinxcontrib-spelling```](https://sphinxcontrib-spelling.readthedocs.io/en/latest/install.html).
+ Note that you will also need [```PyEnchant```](https://pyenchant.github.io/pyenchant/) to run ```sphinxcontrib-spelling```, and thus the Enchant C library; see the PyEnchant guide linked in step 3 below.
+
+ Steps:
+ 1. Install the extension with pip: ```pip install sphinxcontrib-spelling```
+ 2. Add ```sphinxcontrib.spelling``` to the extensions list in ```conf.py```.
+ 3. Install ```PyEnchant``` by following the [installation guide](https://pyenchant.github.io/pyenchant/install.html). Note that Apple Silicon may require a workaround; see the section "Apple Silicon related errors".
+ 4. Make sure you have a ```source``` and a ```build``` folder, then pass "spelling" as the builder argument to ```sphinx-build```:
+ ```
+ cd website/sphinx
+ mkdir build  # if you do not already have one
+ sphinx-build -b spelling source build
+ ```
+ 5. Look for files with spelling errors in ```build``` (remember to check each folder). A file is generated for each source file that contains spelling errors. Example:
+    * File name: ```batch_memory_manager.spelling```
+    * File content:
+ ```
+ ../../opacus/utils/batch_memory_manager.py:docstring of opacus.utils.batch_memory_manager.BatchMemoryManager:5: (occasinal) safeguarding against occasinal large batches produced by
+ ../../opacus/utils/batch_memory_manager.py:docstring of opacus.utils.batch_memory_manager.BatchMemoryManager:13: (optimzer) On every step optimzer will check if the batch was the last physical batch comprising
+ ../../opacus/utils/batch_memory_manager.py:docstring of opacus.utils.batch_memory_manager.BatchMemoryManager:14: (behaviour) a logical one, and will change behaviour accordignly.
+ ../../opacus/utils/batch_memory_manager.py:docstring of opacus.utils.batch_memory_manager.BatchMemoryManager:14: (accordignly) a logical one, and will change behaviour accordignly.
+ ../../opacus/utils/batch_memory_manager.py:docstring of opacus.utils.batch_memory_manager.BatchSplittingSampler:4: (physocal) Used to split large logical batches into physocal batches of a smaller size,
+ ```
+ 6. Manually review the spelling files and fix the source files accordingly. Some detections are false positives; for example, "nn" (from torch.nn) can be flagged as a spelling error.
+
 ## Pull Requests
 We actively welcome your pull requests.
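As a rough illustration of step 2 above, the ```conf.py``` change could look like this (a sketch; the pre-existing extensions list is assumed, and the optional word list file is an illustrative addition supported by sphinxcontrib-spelling):

```
# conf.py (illustrative excerpt)
extensions = [
    "sphinx.ext.autodoc",       # assumed to be present already
    "sphinxcontrib.spelling",   # enables `sphinx-build -b spelling`
]

# Optional: whitelist known false positives such as "nn".
spelling_word_list_filename = "spelling_wordlist.txt"
```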

dev_requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-torch==1.8.1
+torch
 torchvision>=0.9.1
 tqdm>=4.40
 requests>=2.25.1

opacus/accountants/accountant.py

Lines changed: 2 additions & 2 deletions

@@ -72,7 +72,7 @@ def get_optimizer_hook_fn(
         """
         Returns a callback function which can be used to attach to DPOptimizer
         Args:
-            sample_rate: Expected samping rate used for accounting
+            sample_rate: Expected sampling rate used for accounting
         """

         def hook_fn(optim: DPOptimizer):
@@ -88,7 +88,7 @@ def hook_fn(optim: DPOptimizer):

     def state_dict(self, destination: T_state_dict = None) -> T_state_dict:
         """
-        Retruns a dictionary containing the state of the accountant.
+        Returns a dictionary containing the state of the accountant.
         Args:
             destination: a mappable object to populate the current state_dict into.
                 If this arg is None, an OrderedDict is created and populated.
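For context, the two docstrings touched here are typically exercised as follows (a minimal sketch assuming the ```RDPAccountant``` class and ```DPOptimizer.attach_step_hook``` from the same codebase; ```dp_optimizer``` stands for an already-constructed DPOptimizer and the ```sample_rate``` value is illustrative):

```
from opacus.accountants import RDPAccountant

accountant = RDPAccountant()

# Record privacy spending on every optimizer step at the expected sampling rate.
dp_optimizer.attach_step_hook(accountant.get_optimizer_hook_fn(sample_rate=0.01))

# Save and restore the accountant together with a checkpoint.
state = accountant.state_dict()
accountant.load_state_dict(state)
```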

opacus/grad_sample/functorch.py

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+from opacus.layers.dp_rnn import RNNLinear
+
+
+def prepare_layer(layer, batch_first=True):
+    """
+    Prepare a layer to compute grad samples using functorch.
+    The grad samples are computed by redoing the forward and
+    backward passes on the functional version of the module.
+
+    Args:
+        layer: the layer to prepare
+        batch_first: whether the input is batch_first or not
+    """
+    from functorch import grad, make_functional, vmap
+
+    if len(list(layer.buffers())) > 0:
+        raise NotImplementedError(
+            "This layer has buffers and is not supported by Opacus"
+        )
+    flayer, _ = make_functional(layer)
+
+    def compute_loss_stateless_model(params, activations, backprops):
+        if batch_first or type(layer) is RNNLinear:
+            batched_activations = activations.unsqueeze(0)
+            batched_backprops = backprops.unsqueeze(0)
+        else:
+            # If batch_first is False, the batch dimension is the second dimension
+            batched_activations = activations.unsqueeze(1)
+            batched_backprops = backprops.unsqueeze(1)
+
+        output = flayer(params, batched_activations)
+        loss = (output * batched_backprops).sum()
+
+        return loss
+
+    ft_compute_grad = grad(compute_loss_stateless_model)
+    # Note that the vmap is done on the first dimension, regardless of batch_first
+    # This is because the activations and backprops given by the GradSampleModule
+    # are always batch_first=True
+    layer.ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
+
+
+def ft_compute_per_sample_gradient(layer, activations, backprops):
+    """
+    Compute the per-sample gradient of the layer.
+    Args:
+        layer: the layer on which to compute the gradient
+        activations: the input to the layer
+        backprops: the gradient of the loss w.r.t. outputs of the layer
+    """
+    parameters = list(layer.parameters())
+    if not hasattr(layer, "ft_compute_sample_grad"):
+        prepare_layer(layer)
+
+    per_sample_grads = layer.ft_compute_sample_grad(parameters, activations, backprops)
+
+    ret = {}
+    for i_p, p in enumerate(parameters):
+        ret[p] = per_sample_grads[i_p]
+
+    return ret
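These helpers can be exercised directly on a standalone layer, which is roughly what GradSampleModule does through its hooks (a sketch; the ```nn.Linear``` layer, the shapes, and the random tensors are illustrative, and functorch must be installed):

```
import torch
import torch.nn as nn
from opacus.grad_sample.functorch import ft_compute_per_sample_gradient, prepare_layer

layer = nn.Linear(4, 2)
prepare_layer(layer, batch_first=True)

batch_size = 8
activations = torch.randn(batch_size, 4)  # per-sample inputs to the layer
backprops = torch.randn(batch_size, 2)    # per-sample grads w.r.t. the layer's outputs

# Returns a dict mapping each parameter to a (batch_size, *param.shape) tensor.
per_sample_grads = ft_compute_per_sample_gradient(layer, activations, backprops)
print(per_sample_grads[layer.weight].shape)  # torch.Size([8, 2, 4])
```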

opacus/grad_sample/grad_sample_module.py

Lines changed: 59 additions & 18 deletions

@@ -16,13 +16,15 @@
 from __future__ import annotations

 import logging
+import warnings
 from functools import partial
 from typing import List, Tuple

 import torch
 import torch.nn as nn
+from opacus.grad_sample.functorch import ft_compute_per_sample_gradient, prepare_layer
 from opacus.grad_sample.gsm_base import AbstractGradSampleModule
-from opacus.layers.dp_rnn import DPRNNBase, DPRNNCellBase, RNNLinear
+from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear
 from opacus.utils.module_utils import (
     requires_grad,
     trainable_modules,
@@ -89,6 +91,7 @@ def __init__(
         batch_first=True,
         loss_reduction="mean",
         strict: bool = True,
+        force_functorch=False,
     ):
         """

@@ -108,6 +111,9 @@ def __init__(
                 possible and set to None otherwise. This is not recommended, because
                 some unsupported modules (e.g. BatchNorm) affect other parameters and
                 invalidate the concept of per sample gradients for the entire model.
+            force_functorch: If set to ``True``, will use functorch to compute
+                all per sample gradients. Otherwise, functorch will be used only
+                for layers without registered grad sampler methods.

         Raises:
             NotImplementedError
@@ -128,13 +134,24 @@ def __init__(
         )

         self.hooks_enabled = False
-        self.add_hooks(loss_reduction=loss_reduction, batch_first=batch_first)
+        self.batch_first = batch_first
+        self.loss_reduction = loss_reduction
+        self.force_functorch = force_functorch
+        self.add_hooks(
+            loss_reduction=loss_reduction,
+            batch_first=batch_first,
+            force_functorch=force_functorch,
+        )

     def forward(self, *args, **kwargs):
         return self._module(*args, **kwargs)

     def add_hooks(
-        self, *, loss_reduction: str = "mean", batch_first: bool = True
+        self,
+        *,
+        loss_reduction: str = "mean",
+        batch_first: bool = True,
+        force_functorch: bool = False,
     ) -> None:
         """
         Adds hooks to model to save activations and backprop values.
@@ -151,6 +168,8 @@ def add_hooks(
                 ``[K, batch_size, ...]``
             loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
                 is a sum or a mean operation. Can take values "sum" or "mean"
+            force_functorch: If set to ``True``, will use functorch to compute all per sample gradients.
+                Otherwise, functorch will be used only for layers without registered grad sampler methods.
         """
         if hasattr(self._module, "autograd_grad_sample_hooks"):
             raise ValueError("Trying to add hooks twice to the same model")
@@ -159,20 +178,27 @@ def add_hooks(
             self.autograd_grad_sample_hooks = self._module.autograd_grad_sample_hooks

         for _module_name, module in trainable_modules(self._module):
-            if type(module) in self.GRAD_SAMPLERS:
-                self.autograd_grad_sample_hooks.append(
-                    module.register_forward_hook(self.capture_activations_hook)
-                )
+            # Do not add hooks to DPRNN, DPLSTM or DPGRU as the hooks are handled by the `RNNLinear`
+            if type(module) in [DPRNN, DPLSTM, DPGRU]:
+                continue
+
+            if force_functorch or not type(module) in self.GRAD_SAMPLERS:
+                prepare_layer(module, batch_first=batch_first)
+
+            self.autograd_grad_sample_hooks.append(
+                module.register_forward_hook(self.capture_activations_hook)
+            )

-                self.autograd_grad_sample_hooks.append(
-                    module.register_backward_hook(
-                        partial(
-                            self.capture_backprops_hook,
-                            loss_reduction=loss_reduction,
-                            batch_first=batch_first,
-                        )
+            self.autograd_grad_sample_hooks.append(
+                module.register_backward_hook(
+                    partial(
+                        self.capture_backprops_hook,
+                        loss_reduction=loss_reduction,
+                        batch_first=batch_first,
                     )
                 )
+            )
+
         self.enable_hooks()

     def remove_hooks(self) -> None:
@@ -197,6 +223,11 @@ def remove_hooks(self) -> None:
             delattr(self, "autograd_grad_sample_hooks")
             delattr(self._module, "autograd_grad_sample_hooks")

+        # Remove functorch hooks
+        for _module_name, module in trainable_modules(self._module):
+            if hasattr(module, "ft_compute_sample_grad"):
+                delattr(module, "ft_compute_sample_grad")
+
     def disable_hooks(self) -> None:
         r"""
         Globally disable all hooks installed by this library.
@@ -282,7 +313,11 @@ def capture_backprops_hook(
             loss_reduction=loss_reduction,
             batch_first=batch_first,
         )
-        grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
+        if not self.force_functorch and type(module) in self.GRAD_SAMPLERS:
+            grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
+        else:
+            grad_sampler_fn = ft_compute_per_sample_gradient
+
         grad_samples = grad_sampler_fn(module, activations, backprops)
         for param, gs in grad_samples.items():
             create_or_accumulate_grad_sample(
@@ -374,10 +409,13 @@ def is_supported(cls, module: nn.Module) -> bool:
         Returns:
             ``True`` if grad sampler is found, ``False`` otherwise
         """
-        return type(module) in cls.GRAD_SAMPLERS or isinstance(
-            module, (DPRNNBase, DPRNNCellBase)
+        warnings.warn(
+            "GradSampleModule.is_supported is deprecated, as all layers can now be used with functorch.",
+            DeprecationWarning,
         )

+        return True
+
     @classmethod
     def validate(
         cls, module: nn.Module, *, strict: bool = False
@@ -409,7 +447,10 @@ def validate(
                     f"(See opacus.grad_sample.utils.register_grad_sampler)"
                 )
                 for m_name, m in trainable_modules(module)
-                if not cls.is_supported(m)
+                # With functorch, all modules are trainable
+                # We still want to avoid module that have buffers (e.g. BatchNorm)
+                # as the buffers are not private
+                if len(list(m.buffers())) > 0
             ]
         )
         # raise or return errors as needed
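With these hooks in place, the new ```force_functorch``` switch can be exercised end to end like this (a minimal sketch; the toy model, batch size, and input shapes are illustrative, and functorch must be installed):

```
import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModule

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))

# Route every layer through functorch instead of the registered grad samplers.
gsm = GradSampleModule(model, batch_first=True, force_functorch=True)

x = torch.randn(4, 8)
gsm(x).sum().backward()

# Each trainable parameter now carries per-sample gradients of shape (batch, *param.shape).
print(model[0].weight.grad_sample.shape)  # torch.Size([4, 16, 8])
```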

opacus/grad_sample/utils.py

Lines changed: 3 additions & 1 deletion

@@ -52,6 +52,8 @@ def decorator(f):

 def wrap_model(model: nn.Module, grad_sample_mode: str, *args, **kwargs):
     cls = get_gsm_class(grad_sample_mode)
+    if grad_sample_mode == "functorch":
+        kwargs["force_functorch"] = True
     return cls(model, *args, **kwargs)


@@ -63,7 +65,7 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
     :param grad_sample_mode:
     :return:
     """
-    if grad_sample_mode == "hooks":
+    if grad_sample_mode in ["hooks", "functorch"]:
         return GradSampleModule
     elif grad_sample_mode == "ew":
         return GradSampleModuleExpandedWeights
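The dispatch above means the new mode can also be selected by name through ```wrap_model``` (a sketch; ```model``` stands for any ```nn.Module``` you want to wrap):

```
from opacus.grad_sample.utils import wrap_model

# "hooks" and "functorch" both resolve to GradSampleModule; "functorch"
# additionally sets force_functorch=True, while "ew" resolves to
# GradSampleModuleExpandedWeights.
gsm = wrap_model(model, grad_sample_mode="functorch", batch_first=True)
```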

opacus/optimizers/ddp_perlayeroptimizer.py

Lines changed: 1 addition & 1 deletion

@@ -67,7 +67,7 @@ def __init__(
 class DistributedPerLayerOptimizer(DPOptimizer):
     """
     :class:`~opacus.optimizers.optimizer.DPOptimizer` that implements
-    per layer clipping strategy and is compatible with distibured data parallel
+    per layer clipping strategy and is compatible with distributed data parallel
     """

     def __init__(

opacus/optimizers/optimizer.py

Lines changed: 2 additions & 2 deletions

@@ -113,7 +113,7 @@ def _generate_noise(
         reference: The reference Tensor to get the appropriate shape and device
             for generating the noise
         generator: The PyTorch noise generator
-        secure_mode: boolean showing if "secure" noise need to be generate
+        secure_mode: boolean showing if "secure" noise need to be generated
             (see the notes)

     Notes:
@@ -186,7 +186,7 @@ class DPOptimizer(Optimizer):
     Examples:
         >>> module = MyCustomModel()
         >>> optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
-        >>> dp_optimzer = DPOptimizer(
+        >>> dp_optimizer = DPOptimizer(
         ...     optimizer=optimizer,
         ...     noise_multiplier=1.0,
         ...     max_grad_norm=1.0,
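The corrected docstring example corresponds to roughly this usage (a sketch; ```MyCustomModel``` is a placeholder from the docstring, and ```expected_batch_size``` is an illustrative addition needed when ```loss_reduction="mean"```):

```
import torch
from opacus.optimizers import DPOptimizer

module = MyCustomModel()
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
dp_optimizer = DPOptimizer(
    optimizer=optimizer,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    expected_batch_size=32,  # required for "mean" loss reduction
)
```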

opacus/tests/grad_sample_module_test.py

Lines changed: 10 additions & 6 deletions

@@ -212,11 +212,18 @@ def __init__(self, in_f, out_f):
             def forward(self, x: torch.Tensor):
                 return F.linear(x, self.p)

-        with self.assertRaises(NotImplementedError):
-            GradSampleModule(SimpleLinear(4, 2))
+        # Should be handled by functorch
+        try:
+            gsm = GradSampleModule(SimpleLinear(4, 2))
+            self.assertTrue(hasattr(gsm._module, "ft_compute_sample_grad"))
+        except ImportError:
+            print("Test could not be run because functorch is not available")

         # Should not raise exception if strict=False
-        GradSampleModule(SimpleLinear(4, 2), strict=False)
+        try:
+            GradSampleModule(SimpleLinear(4, 2), strict=False)
+        except ImportError:
+            print("Test could not be run because functorch is not available")

         # Should not fail after relevant grad sampler has been registered
         register_grad_sampler(SimpleLinear)(compute_linear_grad_sample)
@@ -226,9 +233,6 @@ def test_custom_module_validation(self):
         with self.assertRaises(NotImplementedError):
             GradSampleModule(mobilenet_v3_small())

-        # Should not raise exception if strict=False
-        GradSampleModule(mobilenet_v3_small(), strict=False)
-
     def test_submodule_access(self):
         _ = self.grad_sample_module.fc1
         _ = self.grad_sample_module.fc2
