
Commit a0a31ba

psolikov authored and facebook-github-bot committed
Per sample grad correctness util (#532)
Summary:

## Types of changes
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [x] Docs change / refactoring / dependency upgrade

## Motivation and Context / Related issue
Implementation of the utility described in #484. Refactored the code to avoid code duplicates.

## How Has This Been Tested (if it applies)
Added the new utility as a test case for existing tests stored in `tests.grad_samples`.

## Checklist
- [x] The documentation is up-to-date with the changes I made.
- [x] I have read the **CONTRIBUTING** document and completed the CLA (see **CONTRIBUTING**).
- [x] All tests passed, and additional code has been covered with new tests.

Pull Request resolved: #532

Reviewed By: karthikprasad

Differential Revision: D40797432

Pulled By: ffuuugor

fbshipit-source-id: 923009d6f7f6d4c34bce9f4af39945fdf9ff9d57
1 parent 1e661e8 commit a0a31ba

11 files changed: +661 −335 lines changed
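The commit moves the per-sample-gradient test helpers out of `tests.grad_samples.common` and into shared utilities in `opacus.utils.per_sample_gradients_utils`. As a rough, hedged sketch of how the new helpers might be called from a test (helper names come from the imports in the diff below; keyword arguments are inferred from the call sites and may not match the exact signatures):

```python
# Hedged sketch, not the library's documented API: compare Opacus per-sample
# gradients against the slow-but-obviously-correct microbatch reference.
import torch
import torch.nn as nn

from opacus.utils.per_sample_gradients_utils import (
    compute_grad_samples_microbatch_and_opacus,
    is_batch_empty,
)

module = nn.Linear(8, 4)
x = torch.randn(16, 8)  # batch of 16 samples, 8 features each

if not is_batch_empty(x):
    # Reference (batch-size-1 loop) and Opacus per-sample gradients, side by side.
    microbatch_grads, opacus_grads = compute_grad_samples_microbatch_and_opacus(
        x,
        module,
        batch_first=True,
        loss_reduction="mean",
        grad_sample_mode="hooks",  # the tests also exercise "functorch" and "ew"
    )
    # Both are expected to be dicts mapping parameter name -> per-sample gradients [B, ...]
    for name, opacus_grad in opacus_grads.items():
        torch.testing.assert_close(opacus_grad, microbatch_grads[name], atol=1e-5, rtol=1e-4)
```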

opacus/tests/grad_samples/common.py

Lines changed: 26 additions & 291 deletions
@@ -13,18 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import io
 import unittest
-from typing import Dict, Iterable, List, Tuple, Union
+from typing import Tuple, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from opacus.grad_sample import wrap_model
-from opacus.utils.module_utils import trainable_parameters
-from opacus.utils.packed_sequences import compute_seq_lengths
-from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
+from opacus.utils.per_sample_gradients_utils import (
+    compute_grad_samples_microbatch_and_opacus,
+    compute_opacus_grad_sample,
+    is_batch_empty,
+)
+from torch.nn.utils.rnn import PackedSequence
 from torch.testing import assert_close
 
 
@@ -36,196 +36,12 @@ def shrinker(x, factor: int = 2):
     return max(1, x // factor)  # if avoid returning 0 for x == 1
 
 
-def is_batch_empty(batch: Union[torch.Tensor, Iterable[torch.Tensor]]):
-    if type(batch) is torch.Tensor:
-        return batch.numel() == 0
-    else:
-        return batch[0].numel() == 0
-
-
-class ModelWithLoss(nn.Module):
-    """
-    To test the gradients of a module, we need to have a loss.
-    This module makes it easy to get a loss from any nn.Module, and automatically generates
-    a target y vector for it in the forward (of all zeros of the correct size).
-    This reduces boilerplate while testing.
-    """
-
-    supported_reductions = ["mean", "sum"]
-
-    def __init__(self, module: nn.Module, loss_reduction: str = "mean"):
-        """
-        Instantiates this module.
-
-        Args:
-            module: The nn.Module you want to test.
-            loss_reduction: What reduction to apply to the loss. Defaults to "mean".
-
-        Raises:
-            ValueError: If ``loss_reduction`` is not among those supported.
-        """
-        super().__init__()
-        self.wrapped_module = module
-
-        if loss_reduction not in self.supported_reductions:
-            raise ValueError(
-                f"Passed loss_reduction={loss_reduction}. Only {self.supported_reductions} supported."
-            )
-        self.criterion = nn.L1Loss(reduction=loss_reduction)
-
-    def forward(self, x):
-        if type(x) is tuple:
-            x = self.wrapped_module(*x)
-        else:
-            x = self.wrapped_module(x)
-        if type(x) is PackedSequence:
-            loss = _compute_loss_packedsequences(self.criterion, x)
-        else:
-            y = torch.zeros_like(x)
-            loss = self.criterion(x, y)
-        return loss
-
-
-def clone_module(module: nn.Module) -> nn.Module:
-    """
-    Handy utility to clone an nn.Module. PyTorch doesn't always support copy.deepcopy(), so it is
-    just easier to serialize the model to a BytesIO and read it from there.
-
-    Args:
-        module: The module to clone
-
-    Returns:
-        The clone of ``module``
-    """
-    with io.BytesIO() as bytesio:
-        torch.save(module, bytesio)
-        bytesio.seek(0)
-        module_copy = torch.load(bytesio)
-    return module_copy
-
-
 class GradSampleHooks_test(unittest.TestCase):
     """
     Set of common testing utils. It is meant to be subclassed by your test.
     See other tests as an example of how this is done.
     """
 
-    def compute_microbatch_grad_sample(
-        self,
-        x: Union[torch.Tensor, List[torch.Tensor]],
-        module: nn.Module,
-        batch_first=True,
-        loss_reduction="mean",
-        chunk_method=iter,
-    ) -> Dict[str, torch.tensor]:
-        """
-        Computes per-sample gradients with the microbatch method, i.e. by computing normal gradients
-        with batch_size set to 1, and manually accumulating them. This is our reference for testing
-        as this method is obviously correct, but slow.
-
-        Args:
-            x: The tensor in input to the ``module``
-            module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
-            batch_first: Whether batch size is the first dimension (as opposed to the second).
-                Defaults to True.
-            loss_reduction: What reduction to apply to the loss. Defaults to "mean".
-            chunk_method: The method to use to split the batch into microbatches. Defaults to ``iter``.
-
-        Returns:
-            Dictionary mapping parameter_name -> per-sample-gradient for that parameter
-        """
-        torch.use_deterministic_algorithms(True)
-        torch.manual_seed(0)
-        np.random.seed(0)
-
-        module = ModelWithLoss(clone_module(module), loss_reduction)
-
-        for _, p in trainable_parameters(module):
-            p.microbatch_grad_sample = []
-
-        if not batch_first and type(x) is not list:
-            # This allows us to iterate with x_i
-            x = x.transpose(0, 1)
-
-        # Invariant: x is [B, T, ...]
-
-        for x_i in chunk_method(x):
-            # x_i is [T, ...]
-            module.zero_grad()
-            if type(x_i) is not tuple:
-                # EmbeddingBag provides tuples
-                x_i = x_i.unsqueeze(
-                    0 if batch_first else 1
-                )  # x_i of size [1, T, ...] if batch_first, else [T, 1, ...]
-            loss_i = module(x_i)
-            loss_i.backward()
-            for p in module.parameters():
-                p.microbatch_grad_sample.append(p.grad.detach().clone())
-
-        for _, p in trainable_parameters(module):
-            if batch_first:
-                p.microbatch_grad_sample = torch.stack(
-                    p.microbatch_grad_sample, dim=0  # [B, T, ...]
-                )
-            else:
-                p.microbatch_grad_sample = torch.stack(
-                    p.microbatch_grad_sample, dim=1  # [T, B, ...]
-                ).transpose(
-                    0, 1
-                )  # Opacus's semantics is that grad_samples are ALWAYS batch_first: [B, T, ...]
-
-        microbatch_grad_samples = {
-            name: p.microbatch_grad_sample
-            for name, p in trainable_parameters(module.wrapped_module)
-        }
-        return microbatch_grad_samples
-
-    def compute_opacus_grad_sample(
-        self,
-        x: Union[torch.Tensor, PackedSequence],
-        module: nn.Module,
-        batch_first=True,
-        loss_reduction="mean",
-        grad_sample_mode="hooks",
-    ) -> Dict[str, torch.tensor]:
-        """
-        Runs Opacus to compute per-sample gradients and return them for testing purposes.
-
-        Args:
-            x: The tensor in input to the ``module``
-            module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
-            batch_first: Whether batch size is the first dimension (as opposed to the second).
-                Defaults to True.
-            loss_reduction: What reduction to apply to the loss. Defaults to "mean".
-
-        Returns:
-            Dictionary mapping parameter_name -> per-sample-gradient for that parameter
-        """
-        torch.use_deterministic_algorithms(True)
-        torch.manual_seed(0)
-        np.random.seed(0)
-
-        gs_module = wrap_model(
-            model=clone_module(module),
-            grad_sample_mode=grad_sample_mode,
-            batch_first=batch_first,
-            loss_reduction=loss_reduction,
-        )
-        grad_sample_module = ModelWithLoss(gs_module, loss_reduction)
-
-        grad_sample_module.zero_grad()
-        loss = grad_sample_module(x)
-        loss.backward()
-
-        opacus_grad_samples = {
-            name: p.grad_sample
-            for name, p in trainable_parameters(
-                grad_sample_module.wrapped_module._module
-            )
-        }
-
-        return opacus_grad_samples
-
     def run_test(
         self,
         x: Union[torch.Tensor, PackedSequence, Tuple],
@@ -237,18 +53,17 @@ def run_test(
         chunk_method=iter,
     ):
         grad_sample_modes = ["hooks", "functorch"]
-        try:
-            import functorch  # noqa
-        except ImportError:
-            grad_sample_modes = ["hooks"]
 
         if type(module) is nn.EmbeddingBag or (
             type(x) is not PackedSequence and is_batch_empty(x)
         ):
             grad_sample_modes = ["hooks"]
 
-        for grad_sample_mode in grad_sample_modes:
-            for loss_reduction in ["sum", "mean"]:
+        if ew_compatible and batch_first and torch.__version__ >= (1, 13):
+            grad_sample_modes += ["ew"]
+
+        for loss_reduction in ["sum", "mean"]:
+            for grad_sample_mode in grad_sample_modes:
                 with self.subTest(
                     grad_sample_mode=grad_sample_mode, loss_reduction=loss_reduction
                 ):
@@ -262,17 +77,6 @@ def run_test(
                         grad_sample_mode=grad_sample_mode,
                         chunk_method=chunk_method,
                     )
-        if ew_compatible and batch_first and torch.__version__ >= (1, 13):
-            self.run_test_with_reduction(
-                x,
-                module,
-                batch_first=batch_first,
-                loss_reduction="sum",
-                atol=atol,
-                rtol=rtol,
-                grad_sample_mode="ew",
-                chunk_method=chunk_method,
-            )
 
     def run_test_with_reduction(
         self,
@@ -285,40 +89,27 @@ def run_test_with_reduction(
         grad_sample_mode="hooks",
         chunk_method=iter,
     ):
-        opacus_grad_samples = self.compute_opacus_grad_sample(
-            x,
-            module,
-            batch_first=batch_first,
-            loss_reduction=loss_reduction,
-            grad_sample_mode=grad_sample_mode,
-        )
-
-        if type(x) is PackedSequence:
-            x_unpacked = _unpack_packedsequences(x)
-            microbatch_grad_samples = self.compute_microbatch_grad_sample(
-                x_unpacked,
-                module,
-                batch_first=batch_first,
-                loss_reduction=loss_reduction,
-            )
-        elif not is_batch_empty(x):
-            microbatch_grad_samples = self.compute_microbatch_grad_sample(
+        if not type(x) is PackedSequence and is_batch_empty(x):
+            _ = compute_opacus_grad_sample(
                 x,
                 module,
                 batch_first=batch_first,
                 loss_reduction=loss_reduction,
-                chunk_method=chunk_method,
+                grad_sample_mode=grad_sample_mode,
             )
-        else:
             # We've checked opacus can handle 0-sized batch. Microbatch doesn't make sense
             return
-
-        if microbatch_grad_samples.keys() != opacus_grad_samples.keys():
-            raise ValueError(
-                "Keys not matching! "
-                f"Keys only in microbatch: {microbatch_grad_samples.keys() - opacus_grad_samples.keys()}; "
-                f"Keys only in Opacus: {opacus_grad_samples.keys() - microbatch_grad_samples.keys()}"
-            )
+        (
+            microbatch_grad_samples,
+            opacus_grad_samples,
+        ) = compute_grad_samples_microbatch_and_opacus(
+            x,
+            module,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+            grad_sample_mode=grad_sample_mode,
+            chunk_method=chunk_method,
+        )
 
         self.check_shapes(microbatch_grad_samples, opacus_grad_samples, loss_reduction)
         self.check_values(
@@ -388,59 +179,3 @@ def check_values(
                 f"A total of {len(failed)} values do not match "
                 f"for loss_reduction={loss_reduction}: \n\t{failed_str}"
             )
-
-
-def _unpack_packedsequences(X: PackedSequence) -> List[torch.Tensor]:
-    r"""
-    Produces a list of tensors from X (PackedSequence) such that this list was used to create X with batch_first=True
-
-    Args:
-        X: A PackedSequence from which the output list of tensors will be produced.
-
-    Returns:
-        unpacked_data: The list of tensors produced from X.
-    """
-
-    X_padded = pad_packed_sequence(X)
-    X_padded = X_padded[0].permute((1, 0, 2))
-
-    if X.sorted_indices is not None:
-        X_padded = X_padded[X.sorted_indices]
-
-    seq_lens = compute_seq_lengths(X.batch_sizes)
-    unpacked_data = [0] * len(seq_lens)
-    for idx, length in enumerate(seq_lens):
-        unpacked_data[idx] = X_padded[idx][:length, :]
-
-    return unpacked_data
-
-
-def _compute_loss_packedsequences(
-    criterion: nn.L1Loss, x: PackedSequence
-) -> torch.Tensor:
-    r"""
-    This function computes the loss in a different way for 'mean' reduced L1 loss while for 'sum' reduced L1 loss,
-    it computes the same way as with non-packed data. For 'mean' reduced L1 loss, it transforms x (PackedSequence)
-    into a list of tensors such that this list of tensors was used to create this PackedSequence in the first
-    place using batch_first=True and then takes the mean of the loss values produced from applying criterion on
-    each sequence sample.
-
-    Args:
-        criterion: An L1 loss function with reduction either set to 'sum' or 'mean'.
-        x: Data in the form of a PackedSequence.
-
-    Returns:
-        A loss variable, reduced either using summation or averaging from L1 errors.
-    """
-
-    if criterion.reduction == "sum":
-        y = torch.zeros_like(x[0])
-        return criterion(x[0], y)
-    elif criterion.reduction == "mean":
-        x = _unpack_packedsequences(x)
-        loss_sum = 0
-        for x_i in x:
-            y_i = torch.zeros_like(x_i)
-            loss_sum += criterion(x_i, y_i)
-        loss_mean = loss_sum / len(x)
-        return loss_mean
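For readers unfamiliar with the removed helpers: the principle behind them, and behind the `per_sample_gradients_utils` module they moved into, is that per-sample gradients produced by Opacus in one batched backward pass should match gradients computed one sample at a time. A simplified, self-contained illustration using the public `GradSampleModule` API (not the new utility itself):

```python
# Simplified illustration of the microbatch-reference check: per-sample
# gradients from Opacus hooks should match gradients computed one sample
# at a time with plain autograd, here with a sum-reduced dummy loss.
import torch
import torch.nn as nn
from opacus.grad_sample import GradSampleModule

torch.manual_seed(0)
model = nn.Linear(5, 3)
x = torch.randn(4, 5)  # batch of 4 samples

# Reference: one backward pass per sample.
reference = {name: [] for name, _ in model.named_parameters()}
for x_i in x:
    model.zero_grad()
    model(x_i.unsqueeze(0)).sum().backward()
    for name, p in model.named_parameters():
        reference[name].append(p.grad.detach().clone())
reference = {name: torch.stack(grads, dim=0) for name, grads in reference.items()}

# Opacus: one backward pass over the whole batch, per-sample grads via hooks.
gs_model = GradSampleModule(model)
gs_model(x).sum().backward()
for name, p in model.named_parameters():
    # p.grad_sample has shape [B, *p.shape] and should match the reference.
    assert torch.allclose(p.grad_sample, reference[name], atol=1e-6)
```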

opacus/tests/grad_samples/dp_multihead_attention_test.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,7 @@ class MultiHeadAttention_test(GradSampleHooks_test):
         add_bias_kv=st.booleans(),
         add_zero_attn=st.booleans(),
         kv_dim=st.booleans(),
+        test_or_check=st.integers(1, 2),
     )
     @settings(deadline=10000)
     def test_multihead_attention(
@@ -65,6 +66,7 @@ def test_multihead_attention(
         add_bias_kv: bool,
         add_zero_attn: bool,
         kv_dim: bool,
+        test_or_check: int,
     ):
         if kv_dim:
             kdim, vdim = D, D
