diff --git a/opacus/grad_sample/functorch.py b/opacus/grad_sample/functorch.py
index 97779506..ade37a1c 100644
--- a/opacus/grad_sample/functorch.py
+++ b/opacus/grad_sample/functorch.py
@@ -48,7 +48,7 @@ def ft_compute_per_sample_gradient(layer, activations, backprops):
         activations: the input to the layer
         backprops: the gradient of the loss w.r.t. outputs of the layer
     """
-    parameters = list(layer.parameters())
+    parameters = list(layer.parameters(recurse=True))
     if not hasattr(layer, "ft_compute_sample_grad"):
         prepare_layer(layer)

diff --git a/opacus/grad_sample/grad_sample_module.py b/opacus/grad_sample/grad_sample_module.py
index d2fb0987..3b2a226e 100644
--- a/opacus/grad_sample/grad_sample_module.py
+++ b/opacus/grad_sample/grad_sample_module.py
@@ -18,7 +18,7 @@
 import logging
 import warnings
 from functools import partial
-from typing import List, Tuple
+from typing import Iterable, List, Tuple

 import torch
 import torch.nn as nn
@@ -26,6 +26,7 @@
 from opacus.grad_sample.gsm_base import AbstractGradSampleModule
 from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear
 from opacus.utils.module_utils import (
+    has_trainable_params,
     requires_grad,
     trainable_modules,
     trainable_parameters,
@@ -146,6 +147,21 @@ def __init__(
     def forward(self, *args, **kwargs):
         return self._module(*args, **kwargs)

+    def iterate_submodules(self, module: nn.Module) -> Iterable[nn.Module]:
+        if has_trainable_params(module):
+            yield module
+
+        # Don't recurse if module is handled by functorch
+        if (
+            has_trainable_params(module)
+            and type(module) not in self.GRAD_SAMPLERS
+            and type(module) not in [DPRNN, DPLSTM, DPGRU]
+        ):
+            return
+
+        for m in module.children():
+            yield from self.iterate_submodules(m)
+
     def add_hooks(
         self,
         *,
@@ -177,7 +193,7 @@ def add_hooks(
             self._module.autograd_grad_sample_hooks = []
             self.autograd_grad_sample_hooks = self._module.autograd_grad_sample_hooks

-        for _module_name, module in trainable_modules(self._module):
+        for module in self.iterate_submodules(self._module):
             # Do not add hooks to DPRNN, DPLSTM or DPGRU as the hooks are handled by the `RNNLinear`
             if type(module) in [DPRNN, DPLSTM, DPGRU]:
                 continue
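For orientation between the hunks above and the test changes below: a standalone sketch of the traversal that the new `iterate_submodules` implements. `REGISTERED` stands in for the real `GRAD_SAMPLERS` registry and the DPRNN/DPLSTM/DPGRU special case is omitted; the point is that a module which directly owns trainable parameters, but has no registered grad sampler for its type, is yielded as a single unit and its children are not visited, leaving functorch to treat it as one functional block:

```python
import torch
import torch.nn as nn

REGISTERED = {nn.Linear}  # stand-in for the real GRAD_SAMPLERS registry keys


def has_trainable_params(module: nn.Module) -> bool:
    # a module counts only if it directly owns a trainable parameter
    return any(p.requires_grad for p in module.parameters(recurse=False))


def iterate_submodules(module: nn.Module):
    if has_trainable_params(module):
        yield module
    # don't recurse further if functorch will handle this module as a whole
    if has_trainable_params(module) and type(module) not in REGISTERED:
        return
    for child in module.children():
        yield from iterate_submodules(child)


class Custom(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(3))  # direct parameter
        self.fc = nn.Linear(3, 3)                 # child that will be skipped

    def forward(self, x):
        return self.fc(x) * self.scale


model = nn.Sequential(nn.Linear(3, 3), Custom())
print([type(m).__name__ for m in iterate_submodules(model)])
# ['Linear', 'Custom'] -- Custom is one functorch unit; Custom.fc is not yielded
```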
diff --git a/opacus/tests/privacy_engine_test.py b/opacus/tests/privacy_engine_test.py
index 90af717a..ad2eb6f3 100644
--- a/opacus/tests/privacy_engine_test.py
+++ b/opacus/tests/privacy_engine_test.py
@@ -40,6 +40,18 @@
 from torchvision import models, transforms
 from torchvision.datasets import FakeData

+from .utils import CustomLinearModule, LinearWithExtraParam
+
+
+def _is_functorch_available():
+    try:
+        # flake8: noqa F401
+        import functorch
+
+        return True
+    except ImportError:
+        return False
+

 def get_grad_sample_aggregated(tensor: torch.Tensor, loss_type: str = "mean"):
     if tensor.grad_sample is None:
@@ -246,7 +258,7 @@ def _compare_to_vanilla(
                 # vanilla gradient is nearly zero: will match even with clipping
                 continue

-            atol = 1e-7 if max_steps == 1 else 1e-5
+            atol = 1e-7 if max_steps == 1 else 1e-4
             self.assertEqual(
                 torch.allclose(vp, pp, atol=atol, rtol=1e-3),
                 expected_match,
@@ -265,10 +277,6 @@ def _compare_to_vanilla(
         do_noise=st.booleans(),
         use_closure=st.booleans(),
         max_steps=st.sampled_from([1, 4]),
-        # do_clip=st.just(False),
-        # do_noise=st.just(False),
-        # use_closure=st.just(False),
-        # max_steps=st.sampled_from([4]),
     )
     @settings(deadline=None)
     def test_compare_to_vanilla(
@@ -799,9 +807,7 @@ def _init_data(self):
         )
         return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

-    def _init_model(
-        self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
-    ):
+    def _init_model(self):
         return SampleConvNet()


@@ -817,9 +823,7 @@ def _init_data(self):
         )
         return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

-    def _init_model(
-        self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
-    ):
+    def _init_model(self):
         m = SampleConvNet()
         for p in itertools.chain(m.conv1.parameters(), m.gnorm1.parameters()):
             p.requires_grad = False
@@ -827,6 +831,13 @@ def _init_model(
         return m


+@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
+class PrivacyEngineConvNetFrozenTestFunctorch(PrivacyEngineConvNetFrozenTest):
+    def setUp(self):
+        super().setUp()
+        self.GRAD_SAMPLE_MODE = "functorch"
+
+
 @unittest.skipIf(
     torch.__version__ < API_CUTOFF_VERSION, "not supported in this torch version"
 )
@@ -840,6 +851,13 @@ def test_sample_grad_aggregation(self):
         pass


+@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
+class PrivacyEngineConvNetTestFunctorch(PrivacyEngineConvNetTest):
+    def setUp(self):
+        super().setUp()
+        self.GRAD_SAMPLE_MODE = "functorch"
+
+
 class SampleAttnNet(nn.Module):
     def __init__(self):
         super().__init__()
@@ -919,6 +937,13 @@ def _init_model(
         return SampleAttnNet()


+@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
+class PrivacyEngineTextTestFunctorch(PrivacyEngineTextTest):
+    def setUp(self):
+        super().setUp()
+        self.GRAD_SAMPLE_MODE = "functorch"
+
+
 class SampleTiedWeights(nn.Module):
     def __init__(self):
         super().__init__()
@@ -958,7 +983,39 @@ def _init_data(self):
         )
         return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)

-    def _init_model(
-        self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
-    ):
+    def _init_model(self):
         return SampleTiedWeights(tie=True)
+
+
+@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
+class PrivacyEngineTiedWeightsTestFunctorch(PrivacyEngineTiedWeightsTest):
+    def setUp(self):
+        super().setUp()
+        self.GRAD_SAMPLE_MODE = "functorch"
+
+
+class ModelWithCustomLinear(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = CustomLinearModule(4, 8)
+        self.fc2 = LinearWithExtraParam(8, 4)
+        self.extra_param = nn.Parameter(torch.randn(4, 4))
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.fc2(x)
+        x = x.matmul(self.extra_param)
+        return x
+
+
+@unittest.skipIf(not _is_functorch_available(), "not supported in this torch version")
+class PrivacyEngineCustomLayerTest(BasePrivacyEngineTest, unittest.TestCase):
+    def _init_data(self):
+        ds = TensorDataset(
+            torch.randn(self.DATA_SIZE, 4),
+            torch.randint(low=0, high=3, size=(self.DATA_SIZE,)),
+        )
+        return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=False)
+
+    def _init_model(self):
+        return ModelWithCustomLinear()
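The new `*Functorch` test classes above only flip `GRAD_SAMPLE_MODE`; `BasePrivacyEngineTest` plumbs that value through when it builds the private model. A minimal standalone sketch of the equivalent direct call, assuming `make_private` accepts a `grad_sample_mode` keyword in this Opacus version (the name mirrors the test attribute; verify against your installed release). A model like `ModelWithCustomLinear` above would exercise the same functorch path, since its submodules have no registered grad samplers:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from opacus import PrivacyEngine

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(
    TensorDataset(torch.randn(16, 4), torch.randint(0, 3, (16,))), batch_size=4
)

engine = PrivacyEngine()
model, optimizer, loader = engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
    grad_sample_mode="functorch",  # assumption: kwarg mirrors GRAD_SAMPLE_MODE above
)
```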
diff --git a/opacus/tests/privacy_engine_validation_test.py b/opacus/tests/privacy_engine_validation_test.py
index 8548f73f..0ba061d8 100644
--- a/opacus/tests/privacy_engine_validation_test.py
+++ b/opacus/tests/privacy_engine_validation_test.py
@@ -1,58 +1,16 @@
 import unittest

 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from opacus import PrivacyEngine
 from opacus.grad_sample.gsm_exp_weights import API_CUTOFF_VERSION
 from torch.utils.data import DataLoader

-
-class BasicSupportedModule(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=2)
-        self.gn = nn.GroupNorm(num_groups=2, num_channels=8)
-        self.fc = nn.Linear(in_features=4, out_features=8)
-        self.ln = nn.LayerNorm([8, 8])
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.gn(x)
-        x = self.fc(x)
-        x = self.ln(x)
-        return x
-
-
-class CustomLinearModule(nn.Module):
-    def __init__(self, in_features, out_features):
-        super().__init__()
-        self._weight = nn.Parameter(torch.randn(out_features, in_features))
-        self._bias = nn.Parameter(torch.randn(out_features))
-
-    def forward(self, x):
-        return F.linear(x, self._weight, self._bias)
-
-
-class MatmulModule(nn.Module):
-    def __init__(self, input_features, output_features):
-        super().__init__()
-        self.weight = nn.Parameter(torch.randn(input_features, output_features))
-
-    def forward(self, x):
-        return torch.matmul(x, self.weight)
-
-
-class LinearWithExtraParam(nn.Module):
-    def __init__(self, in_features, out_features):
-        super().__init__()
-        self.fc = nn.Linear(in_features, out_features)
-        self.extra_param = nn.Parameter(torch.randn(out_features, 2))
-
-    def forward(self, x):
-        x = self.fc(x)
-        x = x.matmul(self.extra_param)
-        return x
+from .utils import (
+    BasicSupportedModule,
+    CustomLinearModule,
+    LinearWithExtraParam,
+    MatmulModule,
+)


 class PrivacyEngineValidationTest(unittest.TestCase):
diff --git a/opacus/tests/utils.py b/opacus/tests/utils.py
new file mode 100644
index 00000000..36833977
--- /dev/null
+++ b/opacus/tests/utils.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicSupportedModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=2)
+        self.gn = nn.GroupNorm(num_groups=2, num_channels=8)
+        self.fc = nn.Linear(in_features=4, out_features=8)
+        self.ln = nn.LayerNorm([8, 8])
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.gn(x)
+        x = self.fc(x)
+        x = self.ln(x)
+        return x
+
+
+class CustomLinearModule(nn.Module):
+    def __init__(self, in_features: int, out_features: int):
+        super().__init__()
+        self._weight = nn.Parameter(torch.randn(out_features, in_features))
+        self._bias = nn.Parameter(torch.randn(out_features))
+
+    def forward(self, x):
+        return F.linear(x, self._weight, self._bias)
+
+
+class MatmulModule(nn.Module):
+    def __init__(self, input_features: int, output_features: int):
+        super().__init__()
+        self.weight = nn.Parameter(torch.randn(input_features, output_features))
+
+    def forward(self, x):
+        return torch.matmul(x, self.weight)
+
+
+class LinearWithExtraParam(nn.Module):
+    def __init__(self, in_features: int, out_features: int, hidden_dim: int = 8):
+        super().__init__()
+        self.fc = nn.Linear(in_features, hidden_dim)
+        self.extra_param = nn.Parameter(torch.randn(hidden_dim, out_features))
+
+    def forward(self, x):
+        x = self.fc(x)
+        x = x.matmul(self.extra_param)
+        return x
diff --git a/opacus/utils/module_utils.py b/opacus/utils/module_utils.py
index da2f6c9a..28146cef 100644
--- a/opacus/utils/module_utils.py
+++ b/opacus/utils/module_utils.py
@@ -31,7 +31,11 @@
 logger.setLevel(level=logging.INFO)


-def parametrized_modules(module: nn.Module) -> Iterable[nn.Module]:
+def has_trainable_params(module: nn.Module) -> bool:
+    return any(p.requires_grad for p in module.parameters(recurse=False))
+
+
+def parametrized_modules(module: nn.Module) -> Iterable[Tuple[str, nn.Module]]:
     """
     Recursively iterates over all submodules, returning those that
     have parameters (as opposed to "wrapper modules" that just organize modules).
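A closing note on the `module_utils.py` change: `recurse=False` is the load-bearing detail in `has_trainable_params`. A pure container reports no parameters of its own, so only modules that directly own trainable parameters are treated as per-sample-gradient leaves by the traversal added above. A quick standalone check:

```python
import torch.nn as nn


def has_trainable_params(module: nn.Module) -> bool:
    # mirrors the helper added above: only inspect the module's own parameters
    return any(p.requires_grad for p in module.parameters(recurse=False))


seq = nn.Sequential(nn.Linear(2, 2))
assert not has_trainable_params(seq)   # container: no direct parameters
assert has_trainable_params(seq[0])    # the Linear owns weight and bias

for p in seq[0].parameters():
    p.requires_grad = False
assert not has_trainable_params(seq[0])  # frozen modules are skipped too
```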