
Commit 8ca0b41

Deepak Agrawal authored and facebook-github-bot committed
Throws an error when params in the optimizer are not the same as those of the module in make_private (#439)
Summary:
Pull Request resolved: #439

Compares nn.Module.parameters() with the list of parameters gathered from all param_groups of the optimizer. If they are not all equal, raises the error "Module parameters are different than optimizer Parameters".

Differential Revision: D37163873

fbshipit-source-id: b4e27d4b7879c5804424b2d1b18921409073d267
1 parent d079ffd commit 8ca0b41
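
For context, the failure mode this check targets is building the optimizer before ModuleValidator.fix() replaces unsupported layers, which leaves the optimizer holding stale parameter tensors. A minimal sketch of the pitfall, assuming torch and opacus are installed (the toy model and init below are illustrative, not part of this commit):

import torch
import torch.nn as nn
from opacus.validators.module_validator import ModuleValidator

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
nn.init.uniform_(model[1].weight)  # simulate trained (non-default) BatchNorm weights

# Optimizer built too early: it captures the BatchNorm parameter tensors
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# fix() swaps BatchNorm2d for a DP-compatible GroupNorm whose freshly
# initialized parameters no longer match the tensors the optimizer holds,
# so make_private would now raise
# "Module parameters are different than optimizer Parameters"
model = ModuleValidator.fix(model)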

File tree

2 files changed: +53 −2 lines changed


opacus/privacy_engine.py

Lines changed: 16 additions & 1 deletion
@@ -14,7 +14,7 @@
 # limitations under the License.
 import os
 import warnings
-from typing import IO, Any, BinaryIO, Dict, List, Optional, Tuple, Union
+from typing import Any, BinaryIO, Dict, IO, List, Optional, Tuple, Union
 
 import torch
 from opacus.accountants import create_accountant
@@ -360,6 +360,21 @@ def make_private(
         if noise_generator and self.secure_mode:
             raise ValueError("Passing seed is prohibited in secure mode")
 
+        # compare module parameters with optimizer parameters
+        if not all(
+            torch.eq(i, j).all()
+            for i, j in zip(
+                list(module.parameters()),
+                sum(
+                    [param_group["params"] for param_group in optimizer.param_groups],
+                    [],
+                ),
+            )
+        ):
+            raise ValueError(
+                "Module parameters are different than optimizer Parameters"
+            )
+
         distributed = isinstance(module, (DPDDP, DDP))
 
         module = self._prepare_model(
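
The new check flattens optimizer.param_groups into a single list and compares each tensor elementwise against module.parameters(). A self-contained sketch of the same comparison on a toy model (the nn.Linear and variable names are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Flatten every param_group into one list, as the check in make_private does
optimizer_params = sum(
    [param_group["params"] for param_group in optimizer.param_groups], []
)

# Elementwise equality per tensor; prints True here because the optimizer
# was built from this module's parameters
print(all(
    torch.eq(p, q).all()
    for p, q in zip(model.parameters(), optimizer_params)
))

Note the comparison is by value rather than identity, which is why the test below uses a pretrained model: the fresh parameters of the layers swapped in by ModuleValidator.fix() differ in value from the trained ones.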

opacus/tests/privacy_engine_test.py

Lines changed: 37 additions & 1 deletion
@@ -29,9 +29,10 @@
 from opacus import PrivacyEngine
 from opacus.layers.dp_multihead_attention import DPMultiheadAttention
 from opacus.optimizers.optimizer import _generate_noise
-from opacus.scheduler import StepNoise, _NoiseScheduler
+from opacus.scheduler import _NoiseScheduler, StepNoise
 from opacus.utils.module_utils import are_state_dict_equal
 from opacus.validators.errors import UnsupportedModuleError
+from opacus.validators.module_validator import ModuleValidator
 from torch.utils.data import DataLoader, Dataset, TensorDataset
 from torchvision import models, transforms
 from torchvision.datasets import FakeData
@@ -464,6 +465,41 @@ def test_deterministic_run(self):
             "Model parameters after deterministic run must match",
         )
 
+    def test_param_equal_module_optimizer(self):
+        """Test that the privacy engine raises an error if nn.Module parameters are not equal to optimizer parameters"""
+        model = models.densenet121(pretrained=True)
+        num_ftrs = model.classifier.in_features
+        model.classifier = nn.Sequential(nn.Linear(num_ftrs, 10), nn.Sigmoid())
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.01, momentum=0, weight_decay=0
+        )
+        dl = self._init_data()
+        model = ModuleValidator.fix(model)
+        privacy_engine = PrivacyEngine()
+        with self.assertRaisesRegex(
+            ValueError, "Module parameters are different than optimizer Parameters"
+        ):
+            _, _, _ = privacy_engine.make_private(
+                module=model,
+                optimizer=optimizer,
+                data_loader=dl,
+                noise_multiplier=1.1,
+                max_grad_norm=1.0,
+            )
+
+        # if the optimizer is defined after ModuleValidator.fix(), no error is raised
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.01, momentum=0, weight_decay=0
+        )
+        _, _, _ = privacy_engine.make_private(
+            module=model,
+            optimizer=optimizer,
+            data_loader=dl,
+            noise_multiplier=1.1,
+            max_grad_norm=1.0,
+        )
+        self.assertTrue(1, 1)
+
     @given(noise_scheduler=st.sampled_from([None, StepNoise]))
     @settings(deadline=None)
     def test_checkpoints(self, noise_scheduler: Optional[_NoiseScheduler]):
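
The second half of the test shows the supported ordering: run ModuleValidator.fix() first, then build the optimizer from the fixed module. In sketch form, continuing from the test's setup (hyperparameters are illustrative):

model = ModuleValidator.fix(model)  # replace unsupported layers first
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # then attach the optimizer
model, optimizer, dl = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=dl,
    noise_multiplier=1.1,
    max_grad_norm=1.0,
)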
