 # See the License for the specific language governing permissions and
 # limitations under the License.

-import io
 import unittest
-from typing import Dict, Iterable, List, Tuple, Union
+from typing import Tuple, Union

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from opacus.grad_sample import wrap_model
-from opacus.utils.module_utils import trainable_parameters
-from opacus.utils.packed_sequences import compute_seq_lengths
-from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
+from opacus.utils.per_sample_gradients_utils import (
+    compute_grad_samples_microbatch_and_opacus,
+    compute_opacus_grad_sample,
+    is_batch_empty,
+)
+from torch.nn.utils.rnn import PackedSequence
 from torch.testing import assert_close


@@ -36,196 +36,12 @@ def shrinker(x, factor: int = 2):
     return max(1, x // factor)  # avoid returning 0 for x == 1


-def is_batch_empty(batch: Union[torch.Tensor, Iterable[torch.Tensor]]):
-    if type(batch) is torch.Tensor:
-        return batch.numel() == 0
-    else:
-        return batch[0].numel() == 0
-
-
-class ModelWithLoss(nn.Module):
-    """
-    To test the gradients of a module, we need to have a loss.
-    This module makes it easy to get a loss from any nn.Module, and automatically generates
-    a target y vector for it in the forward (of all zeros of the correct size).
-    This reduces boilerplate while testing.
-    """
-
-    supported_reductions = ["mean", "sum"]
-
-    def __init__(self, module: nn.Module, loss_reduction: str = "mean"):
-        """
-        Instantiates this module.
-
-        Args:
-            module: The nn.Module you want to test.
-            loss_reduction: What reduction to apply to the loss. Defaults to "mean".
-
-        Raises:
-            ValueError: If ``loss_reduction`` is not among those supported.
-        """
-        super().__init__()
-        self.wrapped_module = module
-
-        if loss_reduction not in self.supported_reductions:
-            raise ValueError(
-                f"Passed loss_reduction={loss_reduction}. Only {self.supported_reductions} supported."
-            )
-        self.criterion = nn.L1Loss(reduction=loss_reduction)
-
-    def forward(self, x):
-        if type(x) is tuple:
-            x = self.wrapped_module(*x)
-        else:
-            x = self.wrapped_module(x)
-        if type(x) is PackedSequence:
-            loss = _compute_loss_packedsequences(self.criterion, x)
-        else:
-            y = torch.zeros_like(x)
-            loss = self.criterion(x, y)
-        return loss
-
-
-def clone_module(module: nn.Module) -> nn.Module:
-    """
-    Handy utility to clone an nn.Module. PyTorch doesn't always support copy.deepcopy(), so it is
-    just easier to serialize the model to a BytesIO and read it from there.
-
-    Args:
-        module: The module to clone
-
-    Returns:
-        The clone of ``module``
-    """
-    with io.BytesIO() as bytesio:
-        torch.save(module, bytesio)
-        bytesio.seek(0)
-        module_copy = torch.load(bytesio)
-    return module_copy
-
-
 class GradSampleHooks_test(unittest.TestCase):
     """
     Set of common testing utils. It is meant to be subclassed by your test.
     See other tests as an example of how this is done.
     """

-    def compute_microbatch_grad_sample(
-        self,
-        x: Union[torch.Tensor, List[torch.Tensor]],
-        module: nn.Module,
-        batch_first=True,
-        loss_reduction="mean",
-        chunk_method=iter,
-    ) -> Dict[str, torch.tensor]:
-        """
-        Computes per-sample gradients with the microbatch method, i.e. by computing normal gradients
-        with batch_size set to 1, and manually accumulating them. This is our reference for testing
-        as this method is obviously correct, but slow.
-
-        Args:
-            x: The tensor in input to the ``module``
-            module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
-            batch_first: Whether batch size is the first dimension (as opposed to the second).
-                Defaults to True.
-            loss_reduction: What reduction to apply to the loss. Defaults to "mean".
-            chunk_method: The method to use to split the batch into microbatches. Defaults to ``iter``.
-
-        Returns:
-            Dictionary mapping parameter_name -> per-sample-gradient for that parameter
-        """
-        torch.use_deterministic_algorithms(True)
-        torch.manual_seed(0)
-        np.random.seed(0)
-
-        module = ModelWithLoss(clone_module(module), loss_reduction)
-
-        for _, p in trainable_parameters(module):
-            p.microbatch_grad_sample = []
-
-        if not batch_first and type(x) is not list:
-            # This allows us to iterate with x_i
-            x = x.transpose(0, 1)
-
-        # Invariant: x is [B, T, ...]
-
-        for x_i in chunk_method(x):
-            # x_i is [T, ...]
-            module.zero_grad()
-            if type(x_i) is not tuple:
-                # EmbeddingBag provides tuples
-                x_i = x_i.unsqueeze(
-                    0 if batch_first else 1
-                )  # x_i of size [1, T, ...] if batch_first, else [T, 1, ...]
-            loss_i = module(x_i)
-            loss_i.backward()
-            for p in module.parameters():
-                p.microbatch_grad_sample.append(p.grad.detach().clone())
-
-        for _, p in trainable_parameters(module):
-            if batch_first:
-                p.microbatch_grad_sample = torch.stack(
-                    p.microbatch_grad_sample, dim=0  # [B, T, ...]
-                )
-            else:
-                p.microbatch_grad_sample = torch.stack(
-                    p.microbatch_grad_sample, dim=1  # [T, B, ...]
-                ).transpose(
-                    0, 1
-                )  # Opacus's semantics is that grad_samples are ALWAYS batch_first: [B, T, ...]
-
-        microbatch_grad_samples = {
-            name: p.microbatch_grad_sample
-            for name, p in trainable_parameters(module.wrapped_module)
-        }
-        return microbatch_grad_samples
-
-    def compute_opacus_grad_sample(
-        self,
-        x: Union[torch.Tensor, PackedSequence],
-        module: nn.Module,
-        batch_first=True,
-        loss_reduction="mean",
-        grad_sample_mode="hooks",
-    ) -> Dict[str, torch.tensor]:
-        """
-        Runs Opacus to compute per-sample gradients and return them for testing purposes.
-
-        Args:
-            x: The tensor in input to the ``module``
-            module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
-            batch_first: Whether batch size is the first dimension (as opposed to the second).
-                Defaults to True.
-            loss_reduction: What reduction to apply to the loss. Defaults to "mean".
-
-        Returns:
-            Dictionary mapping parameter_name -> per-sample-gradient for that parameter
-        """
-        torch.use_deterministic_algorithms(True)
-        torch.manual_seed(0)
-        np.random.seed(0)
-
-        gs_module = wrap_model(
-            model=clone_module(module),
-            grad_sample_mode=grad_sample_mode,
-            batch_first=batch_first,
-            loss_reduction=loss_reduction,
-        )
-        grad_sample_module = ModelWithLoss(gs_module, loss_reduction)
-
-        grad_sample_module.zero_grad()
-        loss = grad_sample_module(x)
-        loss.backward()
-
-        opacus_grad_samples = {
-            name: p.grad_sample
-            for name, p in trainable_parameters(
-                grad_sample_module.wrapped_module._module
-            )
-        }
-
-        return opacus_grad_samples
-
     def run_test(
         self,
         x: Union[torch.Tensor, PackedSequence, Tuple],
@@ -237,18 +53,17 @@ def run_test(
         chunk_method=iter,
     ):
         grad_sample_modes = ["hooks", "functorch"]
-        try:
-            import functorch  # noqa
-        except ImportError:
-            grad_sample_modes = ["hooks"]

         if type(module) is nn.EmbeddingBag or (
             type(x) is not PackedSequence and is_batch_empty(x)
         ):
             grad_sample_modes = ["hooks"]

-        for grad_sample_mode in grad_sample_modes:
-            for loss_reduction in ["sum", "mean"]:
+        if ew_compatible and batch_first and torch.__version__ >= (1, 13):
+            grad_sample_modes += ["ew"]
+
+        for loss_reduction in ["sum", "mean"]:
+            for grad_sample_mode in grad_sample_modes:
                 with self.subTest(
                     grad_sample_mode=grad_sample_mode, loss_reduction=loss_reduction
                 ):
@@ -262,17 +77,6 @@ def run_test(
                         grad_sample_mode=grad_sample_mode,
                         chunk_method=chunk_method,
                     )
-            if ew_compatible and batch_first and torch.__version__ >= (1, 13):
-                self.run_test_with_reduction(
-                    x,
-                    module,
-                    batch_first=batch_first,
-                    loss_reduction="sum",
-                    atol=atol,
-                    rtol=rtol,
-                    grad_sample_mode="ew",
-                    chunk_method=chunk_method,
-                )

     def run_test_with_reduction(
         self,
@@ -285,40 +89,27 @@ def run_test_with_reduction(
         grad_sample_mode="hooks",
         chunk_method=iter,
     ):
-        opacus_grad_samples = self.compute_opacus_grad_sample(
-            x,
-            module,
-            batch_first=batch_first,
-            loss_reduction=loss_reduction,
-            grad_sample_mode=grad_sample_mode,
-        )
-
-        if type(x) is PackedSequence:
-            x_unpacked = _unpack_packedsequences(x)
-            microbatch_grad_samples = self.compute_microbatch_grad_sample(
-                x_unpacked,
-                module,
-                batch_first=batch_first,
-                loss_reduction=loss_reduction,
-            )
-        elif not is_batch_empty(x):
-            microbatch_grad_samples = self.compute_microbatch_grad_sample(
+        if not type(x) is PackedSequence and is_batch_empty(x):
+            _ = compute_opacus_grad_sample(
                 x,
                 module,
                 batch_first=batch_first,
                 loss_reduction=loss_reduction,
-                chunk_method=chunk_method,
+                grad_sample_mode=grad_sample_mode,
             )
-        else:
             # We've checked opacus can handle 0-sized batch. Microbatch doesn't make sense
             return
-
-        if microbatch_grad_samples.keys() != opacus_grad_samples.keys():
-            raise ValueError(
-                "Keys not matching! "
-                f"Keys only in microbatch: {microbatch_grad_samples.keys() - opacus_grad_samples.keys()}; "
-                f"Keys only in Opacus: {opacus_grad_samples.keys() - microbatch_grad_samples.keys()}"
-            )
+        (
+            microbatch_grad_samples,
+            opacus_grad_samples,
+        ) = compute_grad_samples_microbatch_and_opacus(
+            x,
+            module,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+            grad_sample_mode=grad_sample_mode,
+            chunk_method=chunk_method,
+        )

         self.check_shapes(microbatch_grad_samples, opacus_grad_samples, loss_reduction)
         self.check_values(
@@ -388,59 +179,3 @@ def check_values(
388179 f"A total of { len (failed )} values do not match "
389180 f"for loss_reduction={ loss_reduction } : \n \t { failed_str } "
390181 )
391-
392-
393- def _unpack_packedsequences (X : PackedSequence ) -> List [torch .Tensor ]:
394- r"""
395- Produces a list of tensors from X (PackedSequence) such that this list was used to create X with batch_first=True
396-
397- Args:
398- X: A PackedSequence from which the output list of tensors will be produced.
399-
400- Returns:
401- unpacked_data: The list of tensors produced from X.
402- """
403-
404- X_padded = pad_packed_sequence (X )
405- X_padded = X_padded [0 ].permute ((1 , 0 , 2 ))
406-
407- if X .sorted_indices is not None :
408- X_padded = X_padded [X .sorted_indices ]
409-
410- seq_lens = compute_seq_lengths (X .batch_sizes )
411- unpacked_data = [0 ] * len (seq_lens )
412- for idx , length in enumerate (seq_lens ):
413- unpacked_data [idx ] = X_padded [idx ][:length , :]
414-
415- return unpacked_data
416-
417-
418- def _compute_loss_packedsequences (
419- criterion : nn .L1Loss , x : PackedSequence
420- ) -> torch .Tensor :
421- r"""
422- This function computes the loss in a different way for 'mean' reduced L1 loss while for 'sum' reduced L1 loss,
423- it computes the same way as with non-packed data. For 'mean' reduced L1 loss, it transforms x (PackedSequence)
424- into a list of tensors such that this list of tensors was used to create this PackedSequence in the first
425- place using batch_first=True and then takes the mean of the loss values produced from applying criterion on
426- each sequence sample.
427-
428- Args:
429- criterion: An L1 loss function with reduction either set to 'sum' or 'mean'.
430- x: Data in the form of a PackedSequence.
431-
432- Returns:
433- A loss variable, reduced either using summation or averaging from L1 errors.
434- """
435-
436- if criterion .reduction == "sum" :
437- y = torch .zeros_like (x [0 ])
438- return criterion (x [0 ], y )
439- elif criterion .reduction == "mean" :
440- x = _unpack_packedsequences (x )
441- loss_sum = 0
442- for x_i in x :
443- y_i = torch .zeros_like (x_i )
444- loss_sum += criterion (x_i , y_i )
445- loss_mean = loss_sum / len (x )
446- return loss_mean
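
For reference, the shared helpers this test file now imports can also be exercised on their own. The sketch below is illustrative only and is not part of the diff: it reuses the call signatures visible above (compute_grad_samples_microbatch_and_opacus, compute_opacus_grad_sample, is_batch_empty), while the toy nn.Linear module, the input shape, and the choice of "hooks" mode are assumptions.

# Illustrative sketch; helper signatures mirror the calls in the diff above,
# the toy module and input shape are assumed for demonstration.
import torch
import torch.nn as nn

from opacus.utils.per_sample_gradients_utils import (
    compute_grad_samples_microbatch_and_opacus,
    is_batch_empty,
)

module = nn.Linear(8, 4)  # assumed toy module under test
x = torch.randn(16, 8)    # assumed non-empty batch of 16 samples

# The microbatch reference only makes sense for non-empty batches,
# mirroring the early return in run_test_with_reduction above.
assert not is_batch_empty(x)

microbatch_gs, opacus_gs = compute_grad_samples_microbatch_and_opacus(
    x,
    module,
    batch_first=True,
    loss_reduction="mean",
    grad_sample_mode="hooks",
    chunk_method=iter,
)

# Both dicts map parameter name -> per-sample gradients with the batch dimension first.
for name, opacus_grad in opacus_gs.items():
    assert opacus_grad.shape == microbatch_gs[name].shape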