Changes from 1 commit (28 commits in this pull request):
188bb42 add state_dict to privacy accountant (karthikprasad, May 17, 2022)
fe891c9 minor fixes and docstrings in accountant (karthikprasad, May 17, 2022)
6083f99 some more minor fixes in accountant (karthikprasad, May 17, 2022)
26348bc add state dict support for GradSampleModule and save/load checkpoint … (karthikprasad, May 17, 2022)
01c5c4c import typevar (karthikprasad, May 17, 2022)
13c272e accountant unit test (karthikprasad, May 18, 2022)
f253a5e lint fix in test (karthikprasad, May 18, 2022)
6d11754 fix typo (karthikprasad, May 18, 2022)
29e2c28 fix var name in test (karthikprasad, May 18, 2022)
66e8094 fix num steps in test (karthikprasad, May 18, 2022)
5887f78 fix lint again (karthikprasad, May 18, 2022)
64a1632 add-ons to GradSampleModule state_dict (karthikprasad, May 22, 2022)
345d1d7 fixes to GS and test (karthikprasad, May 22, 2022)
ec2fd92 test privacy engine checkpointing (karthikprasad, May 23, 2022)
eb3224b remove debug comments (karthikprasad, May 23, 2022)
b6d5a86 fix lint (karthikprasad, May 23, 2022)
29ff3a8 fix lint again (karthikprasad, May 23, 2022)
f69819b Minor fixex in FAQ (#430) (Kevin-Abd, May 20, 2022)
86a8e0d disable poisson sampling in checkpoints test (karthikprasad, May 23, 2022)
d3591b0 rebase (karthikprasad, May 23, 2022)
8ad0545 fix sort order (karthikprasad, May 23, 2022)
f14e810 fix black (karthikprasad, May 23, 2022)
2416bb1 some more lints (karthikprasad, May 23, 2022)
051bc0c address comments (karthikprasad, May 25, 2022)
e655df4 address comments (karthikprasad, May 31, 2022)
ba29092 Merge branch 'main' into master (karthikprasad, May 31, 2022)
33ee300 fix flake lint (karthikprasad, May 31, 2022)
1f9297a Merge branch 'master' of https://github.com/karthikprasad/opacus (karthikprasad, May 31, 2022)
test privacy engine checkpointing
karthikprasad committed May 23, 2022
commit ec2fd92f7a4cf566f0eb733815150cdfbcea0f26
opacus/privacy_engine.py (40 changes: 25 additions, 15 deletions)
@@ -23,6 +23,7 @@
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
 from opacus.grad_sample.grad_sample_module import GradSampleModule
 from opacus.optimizers import DPOptimizer, get_optimizer_class
+from opacus.scheduler import _NoiseScheduler
 from opacus.validators.module_validator import ModuleValidator
 from torch import nn, optim
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -500,9 +501,10 @@ def save_checkpoint(
         *,
         path: Union[str, os.PathLike, BinaryIO, IO[bytes]],
         module: GradSampleModule,
-        optimizer: DPOptimizer,
+        optimizer: Optional[DPOptimizer] = None,
+        noise_scheduler: Optional[_NoiseScheduler] = None,
         module_state_dict_kwargs: Optional[Dict[str, Any]] = None,
-        save_kwargs: Optional[Dict[str, Any]] = None,
+        torch_save_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Saves the state_dict of module, optimizer, and accountant at path.
@@ -511,33 +513,41 @@ def save_checkpoint(
             module: GradSampleModule to save; wrapped module's state_dict is saved.
             optimizer: DPOptimizer to save; wrapped optimizer's state_dict is saved.
             module_state_dict_kwargs: dict of kwargs to pass to ``module.state_dict()``
-            save_kwargs: dict of kwargs to pass to ``torch.save()``
+            torch_save_kwargs: dict of kwargs to pass to ``torch.save()``

         """
-        torch.save(
-            {
-                "module_state_dict": module.state_dict(
-                    **(module_state_dict_kwargs or {})
-                ),
-                "optimizer_state_dict": optimizer.state_dict(),
-                "privacy_accountant_state_dict": self.accountant.state_dict(),
-            },
-            path,
-            **(save_kwargs or {}),
+        dict_to_save = {}
+        dict_to_save["module_state_dict"] = module.state_dict(
+            **(module_state_dict_kwargs or {})
         )
+        dict_to_save["privacy_accountant_state_dict"] = self.accountant.state_dict()
+        if optimizer is not None:
+            dict_to_save["optimizer_state_dict"] = optimizer.state_dict()
+        if noise_scheduler is not None:
+            dict_to_save["noise_scheduler_state_dict"] = noise_scheduler.state_dict()
+
+        torch.save(dict_to_save, path, **(torch_save_kwargs or {}))

     def load_checkpoint(
         self,
         *,
         path: Union[str, os.PathLike, BinaryIO, IO[bytes]],
         module: GradSampleModule,
-        optimizer: DPOptimizer,
+        optimizer: Optional[DPOptimizer] = None,
+        noise_scheduler: Optional[_NoiseScheduler] = None,
         module_load_dict_kwargs: Optional[Dict[str, Any]] = None,
         torch_load_kwargs: Optional[Dict[str, Any]] = None,
     ):
         checkpoint = torch.load(path, **(torch_load_kwargs or {}))
         module.load_state_dict(
             checkpoint["module_state_dict"], **(module_load_dict_kwargs or {})
         )
-        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         self.accountant.load_state_dict(checkpoint["privacy_accountant_state_dict"])
+
+        optimizer_state_dict = checkpoint.pop("optimizer_state_dict", {})
+        if optimizer is not None and len(optimizer_state_dict) > 0:
+            optimizer.load_state_dict(optimizer_state_dict)
+
+        noise_scheduler_state_dict = checkpoint.pop("noise_scheduler_state_dict", {})
+        if noise_scheduler is not None and len(noise_scheduler_state_dict) > 0:
+            noise_scheduler.load_state_dict(noise_scheduler_state_dict)
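
For context, a minimal usage sketch of the checkpoint API introduced above. Only save_checkpoint/load_checkpoint and their keyword names come from this diff; the toy model, optimizer, data loader, and the io.BytesIO buffer are illustrative assumptions, wired up through the standard Opacus 1.x PrivacyEngine.make_private() entry point.

import io

import torch
from opacus import PrivacyEngine
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy components; any module/optimizer/loader would do.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(
    TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,))), batch_size=4
)

privacy_engine = PrivacyEngine()
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)

# ... train for a few steps, then checkpoint ...
buffer = io.BytesIO()  # anything torch.save() accepts; a file path also works
privacy_engine.save_checkpoint(path=buffer, module=model, optimizer=optimizer)

# Restoring: module and accountant state are always loaded; optimizer (and
# noise scheduler) state is applied only if the argument is passed and the
# checkpoint contains a matching entry, so older checkpoints load cleanly.
buffer.seek(0)
privacy_engine.load_checkpoint(path=buffer, module=model, optimizer=optimizer)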
opacus/tests/privacy_engine_test.py (123 changes: 94 additions, 29 deletions)
@@ -14,6 +14,7 @@
 # limitations under the License.

 import abc
+import io
 import math
 import unittest
 from abc import ABC
@@ -28,6 +29,7 @@
 from opacus import PrivacyEngine
 from opacus.layers.dp_multihead_attention import DPMultiheadAttention
 from opacus.optimizers.optimizer import _generate_noise
+from opacus.scheduler import _NoiseScheduler, StepNoise
 from opacus.utils.module_utils import are_state_dict_equal
 from opacus.validators.errors import UnsupportedModuleError
 from torch.utils.data import DataLoader, Dataset, TensorDataset
@@ -462,6 +464,69 @@ def test_deterministic_run(self):
             "Model parameters after deterministic run must match",
         )

+    @given(
+        noise_scheduler=st.sampled_from([None, StepNoise])
+    )
+    @settings(deadline=None)
+    def test_checkpoints(self, noise_scheduler: Optional[_NoiseScheduler]):
+        # create a set of components: set 1
+        torch.manual_seed(1)
+        m1, opt1, dl1, pe1 = self._init_private_training(noise_multiplier=1.0)
+        s1 = noise_scheduler(optimizer=opt1, step_size=1, gamma=1.0) if noise_scheduler is not None else None
+        # create a different set of components: set 2
+        torch.manual_seed(2)
+        m2, opt2, _, pe2 = self._init_private_training(noise_multiplier=2.0)
+        s2 = noise_scheduler(optimizer=opt2, step_size=1, gamma=2.0) if noise_scheduler is not None else None
+
+        # check that two sets of components are different
+        self.assertFalse(are_state_dict_equal(m1._module.state_dict(), m2._module.state_dict()))
+        if noise_scheduler:
+            self.assertNotEqual(s1.state_dict(), s2.state_dict())
+        self.assertNotEqual(opt1.noise_multiplier, opt2.noise_multiplier)
+
+        # train set 1 for a few steps
+        self._train_steps(m1, opt1, dl1)
+        if noise_scheduler:
+            s1.step()
+
+        # load into set 2
+        with io.BytesIO() as bytesio:
+            pe1.save_checkpoint(path=bytesio, module=m1, optimizer=opt1, noise_scheduler=s1)
+            bytesio.seek(0)
+            pe2.load_checkpoint(path=bytesio, module=m2, optimizer=opt2, noise_scheduler=s2)
+
+        # check the two sets of components are now the same
+        self.assertTrue(are_state_dict_equal(m1._module.state_dict(), m2._module.state_dict()))
+        if noise_scheduler:
+            self.assertEqual(s1.state_dict(), s2.state_dict())
+        # check that non-state params are still different
+        self.assertNotEqual(opt1.noise_multiplier, opt2.noise_multiplier)
+
+        # train the now loaded set 2 some more (change noise multiplier before doing so)
+        opt2.noise_multiplier = 1.0
+        self._train_steps(m2, opt2, dl1)
+        if noise_scheduler:
+            s2.step()
+
+        # recreate set 1 from scratch (set11) and check it is different from the trained set 2
+        torch.manual_seed(1)
+        m11, opt11, dl11, _ = self._init_private_training(noise_multiplier=1.0)
+        s11 = noise_scheduler(optimizer=opt11, step_size=1, gamma=1.0) if noise_scheduler is not None else None
+        self.assertFalse(are_state_dict_equal(m2._module.state_dict(), m11._module.state_dict()))
+        if noise_scheduler:
+            self.assertNotEqual(s2.state_dict(), s11.state_dict())
+        # train the recreated set for the same number of steps
+        self._train_steps(m11, opt11, dl11)
+        if noise_scheduler:
+            s11.step()
+        self._train_steps(m11, opt11, dl11)
+        if noise_scheduler:
+            s11.step()
+        # check that recreated set is now same as the original set 1 after training
+        self.assertTrue(are_state_dict_equal(m2._module.state_dict(), m11._module.state_dict()))
+        if noise_scheduler:
+            self.assertEqual(s2.state_dict(), s11.state_dict())
+
     @given(
         noise_multiplier=st.floats(0.5, 5.0),
         max_steps=st.integers(8, 10),
@@ -664,26 +729,26 @@ def batch_second_collate(batch):
     return data, labels


-class PrivacyEngineTextTest(BasePrivacyEngineTest, unittest.TestCase):
-    def setUp(self):
-        super().setUp()
-        self.BATCH_FIRST = False
+# class PrivacyEngineTextTest(BasePrivacyEngineTest, unittest.TestCase):
+#     def setUp(self):
+#         super().setUp()
+#         self.BATCH_FIRST = False

-    def _init_data(self):
-        x = torch.randint(0, 100, (12, self.DATA_SIZE))
-        y = torch.randint(0, 12, (self.DATA_SIZE,))
-        ds = MockTextDataset(x, y)
-        return DataLoader(
-            ds,
-            batch_size=self.BATCH_SIZE,
-            collate_fn=batch_second_collate,
-            drop_last=True,
-        )
+#     def _init_data(self):
+#         x = torch.randint(0, 100, (12, self.DATA_SIZE))
+#         y = torch.randint(0, 12, (self.DATA_SIZE,))
+#         ds = MockTextDataset(x, y)
+#         return DataLoader(
+#             ds,
+#             batch_size=self.BATCH_SIZE,
+#             collate_fn=batch_second_collate,
+#             drop_last=True,
+#         )

-    def _init_model(
-        self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
-    ):
-        return SampleAttnNet()
+#     def _init_model(
+#         self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
+#     ):
+#         return SampleAttnNet()


 class SampleTiedWeights(nn.Module):
@@ -712,15 +777,15 @@ def forward(self, x):
         return x


-class PrivacyEngineTiedWeightsTest(BasePrivacyEngineTest, unittest.TestCase):
-    def _init_data(self):
-        ds = TensorDataset(
-            torch.randint(low=0, high=100, size=(self.DATA_SIZE,)),
-            torch.randint(low=0, high=100, size=(self.DATA_SIZE,)),
-        )
-        return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=True)
+# class PrivacyEngineTiedWeightsTest(BasePrivacyEngineTest, unittest.TestCase):
+#     def _init_data(self):
+#         ds = TensorDataset(
+#             torch.randint(low=0, high=100, size=(self.DATA_SIZE,)),
+#             torch.randint(low=0, high=100, size=(self.DATA_SIZE,)),
+#         )
+#         return DataLoader(ds, batch_size=self.BATCH_SIZE, drop_last=True)

-    def _init_model(
-        self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
-    ):
-        return SampleTiedWeights(tie=False)
+#     def _init_model(
+#         self, private=False, state_dict=None, model=None, **privacy_engine_kwargs
+#     ):
+#         return SampleTiedWeights(tie=False)
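
A note on the scheduler sampled in test_checkpoints above: StepNoise rescales the wrapped optimizer's noise_multiplier by gamma every step_size calls to step(), which is why the s1/s2 schedulers built with gamma=1.0 and gamma=2.0 compare as different state dicts. Below is a rough sketch of that behavior, not taken from this diff; it reuses the optimizer from the earlier checkpoint sketch and assumes construction itself leaves the multiplier unchanged.

from opacus.scheduler import StepNoise

scheduler = StepNoise(optimizer=optimizer, step_size=1, gamma=2.0)
print(optimizer.noise_multiplier)  # expected: 1.0 (unchanged at construction)
scheduler.step()
print(optimizer.noise_multiplier)  # expected: 2.0 (scaled by gamma after one step)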