28 commits
188bb42  add state_dict to privacy accountant (karthikprasad, May 17, 2022)
fe891c9  minor fixes and docstrings in accountant (karthikprasad, May 17, 2022)
6083f99  some more minor fixes in accountant (karthikprasad, May 17, 2022)
26348bc  add state dict support for GradSampleModule and save/load checkpoint … (karthikprasad, May 17, 2022)
01c5c4c  import typevar (karthikprasad, May 17, 2022)
13c272e  accountant unit test (karthikprasad, May 18, 2022)
f253a5e  lint fix in test (karthikprasad, May 18, 2022)
6d11754  fix typo (karthikprasad, May 18, 2022)
29e2c28  fix var name in test (karthikprasad, May 18, 2022)
66e8094  fix num steps in test (karthikprasad, May 18, 2022)
5887f78  fix lint again (karthikprasad, May 18, 2022)
64a1632  add-ons to GradSampleModule state_dict (karthikprasad, May 22, 2022)
345d1d7  fixes to GS and test (karthikprasad, May 22, 2022)
ec2fd92  test privacy engine checkpointing (karthikprasad, May 23, 2022)
eb3224b  remove debug comments (karthikprasad, May 23, 2022)
b6d5a86  fix lint (karthikprasad, May 23, 2022)
29ff3a8  fix lint again (karthikprasad, May 23, 2022)
f69819b  Minor fixex in FAQ (#430) (Kevin-Abd, May 20, 2022)
86a8e0d  disable poisson sampling in checkpoints test (karthikprasad, May 23, 2022)
d3591b0  rebase (karthikprasad, May 23, 2022)
8ad0545  fix sort order (karthikprasad, May 23, 2022)
f14e810  fix black (karthikprasad, May 23, 2022)
2416bb1  some more lints (karthikprasad, May 23, 2022)
051bc0c  address comments (karthikprasad, May 25, 2022)
e655df4  address comments (karthikprasad, May 31, 2022)
ba29092  Merge branch 'main' into master (karthikprasad, May 31, 2022)
33ee300  fix flake lint (karthikprasad, May 31, 2022)
1f9297a  Merge branch 'master' of https://github.com/karthikprasad/opacus (karthikprasad, May 31, 2022)
address comments
karthikprasad committed May 25, 2022
commit 051bc0cefd2abd28aad52b70efa77e738152e0c7
opacus/grad_sample/grad_sample_module.py (1 addition, 8 deletions)

@@ -433,26 +433,19 @@ def state_dict(self, *args, **kwargs) -> Dict:
             f"_module.{key}": value
             for key, value in self._module.state_dict(*args, **kwargs).items()
         }
-        ret_state_dict["batch_first"] = self.batch_first
-        ret_state_dict["loss_reduction"] = self.loss_reduction
         return ret_state_dict

     def load_state_dict(self, state_dict: Dict, **kwargs):
         """
         Load the state_dict into the wrapped module
         """
         state_dict = state_dict.copy()
-        self.batch_first = state_dict.pop("batch_first", self.batch_first)
-        self.loss_reduction = state_dict.pop("loss_reduction", self.loss_reduction)
         # remove "_module." prefix before loading into wrapped module
         for key in list(state_dict.keys()):
             if key.startswith("_module."):
                 prefix_stripped_key = key[len("_module.") :]
                 state_dict[prefix_stripped_key] = state_dict.pop(key)
-        self._module.load_state_dict(state_dict, **kwargs)
-        # remove and add hooks with the newly loaded loss_reduction and batch_first
-        self.remove_hooks()
-        self.add_hooks(loss_reduction=self.loss_reduction, batch_first=self.batch_first)
+        return self._module.load_state_dict(state_dict, **kwargs)

     @classmethod
     def is_supported(cls, module: nn.Module) -> bool:
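With the simplified load path above, a state dict round-trips cleanly between wrappers. A minimal sketch of the resulting contract, assuming the opacus.grad_sample import path; the toy Linear model is illustrative:

import torch.nn as nn

from opacus.grad_sample import GradSampleModule

# GradSampleModule registers the wrapped model as self._module, so every
# state_dict key comes back prefixed with "_module.".
gs = GradSampleModule(nn.Linear(4, 2))
print(sorted(gs.state_dict().keys()))  # ['_module.bias', '_module.weight']

# load_state_dict strips the prefix and delegates to the wrapped module; after
# this commit it no longer mutates batch_first/loss_reduction or re-adds hooks.
fresh = GradSampleModule(nn.Linear(4, 2))
fresh.load_state_dict(gs.state_dict())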
opacus/privacy_engine.py (16 additions, 7 deletions)

@@ -503,6 +503,7 @@ def save_checkpoint(
         module: GradSampleModule,
         optimizer: Optional[DPOptimizer] = None,
         noise_scheduler: Optional[_NoiseScheduler] = None,
+        checkpoint_dict: Optional[Dict[str, Any]] = None,
         module_state_dict_kwargs: Optional[Dict[str, Any]] = None,
         torch_save_kwargs: Optional[Dict[str, Any]] = None,
     ):
@@ -516,17 +517,17 @@ def save_checkpoint(
             torch_save_kwargs: dict of kwargs to pass to ``torch.save()``

         """
-        dict_to_save = {}
-        dict_to_save["module_state_dict"] = module.state_dict(
+        checkpoint_dict = checkpoint_dict or {}
+        checkpoint_dict["module_state_dict"] = module.state_dict(
             **(module_state_dict_kwargs or {})
         )
-        dict_to_save["privacy_accountant_state_dict"] = self.accountant.state_dict()
+        checkpoint_dict["privacy_accountant_state_dict"] = self.accountant.state_dict()
         if optimizer is not None:
-            dict_to_save["optimizer_state_dict"] = optimizer.state_dict()
+            checkpoint_dict["optimizer_state_dict"] = optimizer.state_dict()
         if noise_scheduler is not None:
-            dict_to_save["noise_scheduler_state_dict"] = noise_scheduler.state_dict()
+            checkpoint_dict["noise_scheduler_state_dict"] = noise_scheduler.state_dict()

-        torch.save(dict_to_save, path, **(torch_save_kwargs or {}))
+        torch.save(checkpoint_dict, path, **(torch_save_kwargs or {}))

     def load_checkpoint(
         self,
@@ -537,7 +538,7 @@ def load_checkpoint(
         noise_scheduler: Optional[_NoiseScheduler] = None,
         module_load_dict_kwargs: Optional[Dict[str, Any]] = None,
         torch_load_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Dict:
         checkpoint = torch.load(path, **(torch_load_kwargs or {}))
         module.load_state_dict(
             checkpoint["module_state_dict"], **(module_load_dict_kwargs or {})
@@ -547,7 +548,15 @@ def load_checkpoint(
         optimizer_state_dict = checkpoint.pop("optimizer_state_dict", {})
         if optimizer is not None and len(optimizer_state_dict) > 0:
             optimizer.load_state_dict(optimizer_state_dict)
+        elif (optimizer is not None) ^ (len(optimizer_state_dict) > 0):
+            # warn if only one of them is available
+            warnings.warn(
+                f"optimizer_state_dict has {len(optimizer_state_dict)} items"
+                f" but optimizer is {'' if optimizer else 'not'} provided."
+            )

         noise_scheduler_state_dict = checkpoint.pop("noise_scheduler_state_dict", {})
         if noise_scheduler is not None and len(noise_scheduler_state_dict) > 0:
             noise_scheduler.load_state_dict(noise_scheduler_state_dict)
+
+        return checkpoint
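A sketch of how the new checkpoint_dict argument and the Dict return value fit together, assuming the standard make_private() setup; the file name ckpt.pt and the "epoch" entry are illustrative:

import torch
from torch.utils.data import DataLoader, TensorDataset

from opacus import PrivacyEngine

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 2)), batch_size=4)

engine = PrivacyEngine()
model, optimizer, data_loader = engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)

# Caller-owned metadata rides along in the same file: save_checkpoint writes
# its own entries ("module_state_dict", ...) into the dict we pass in.
engine.save_checkpoint(
    path="ckpt.pt",
    module=model,
    optimizer=optimizer,
    checkpoint_dict={"epoch": 3},
)

# load_checkpoint restores module/optimizer/accountant state, then returns the
# remaining dict, so the metadata comes back out alongside the engine's keys.
checkpoint = engine.load_checkpoint(path="ckpt.pt", module=model, optimizer=optimizer)
print(checkpoint["epoch"])  # 3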
opacus/tests/grad_sample_module_test.py (0 additions, 15 deletions)

@@ -232,32 +232,17 @@ def test_submodule_access(self):
     def test_state_dict(self):
         gs_state_dict = self.grad_sample_module.state_dict()
         og_state_dict = self.original_model.state_dict()
-        # check gs specific attributes
-        self.assertEqual(gs_state_dict["batch_first"], True)
-        self.assertEqual(gs_state_dict["loss_reduction"], "mean")
         # check wrapped module state dict
         for key in og_state_dict.keys():
             self.assertTrue(f"_module.{key}" in gs_state_dict)
             assert_allclose(og_state_dict[key], gs_state_dict[f"_module.{key}"])

     def test_load_state_dict(self):
         gs_state_dict = self.grad_sample_module.state_dict()
-        gs_state_dict["loss_reduction"] = "sum"
-        _ = gs_state_dict.pop("batch_first")
-
         new_gs = GradSampleModule(
             SampleConvNet(), batch_first=False, loss_reduction="mean"
         )
-
-        new_gs_hook_before_load = new_gs.autograd_grad_sample_hooks[0]
         new_gs.load_state_dict(gs_state_dict)
-        new_gs_hook_after_load = new_gs.autograd_grad_sample_hooks[0]
-
-        self.assertEqual(new_gs.loss_reduction, "sum")  # value should have changed
-        self.assertEqual(new_gs.batch_first, False)  # old value to be retained
-        self.assertTrue(
-            new_gs_hook_before_load != new_gs_hook_after_load
-        )  # hook is reset
         # wrapped module is the same
         for key in self.original_model.state_dict().keys():
             self.assertTrue(key in new_gs._module.state_dict())
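The trimmed test mirrors the behavior change: a loaded state dict no longer overrides the wrapper's own configuration. A small sketch of that contract, using the same constructor arguments as the test; the assertions are illustrative:

import torch.nn as nn

from opacus.grad_sample import GradSampleModule

src = GradSampleModule(nn.Linear(4, 2), batch_first=True, loss_reduction="mean")
dst = GradSampleModule(nn.Linear(4, 2), batch_first=False, loss_reduction="sum")

dst.load_state_dict(src.state_dict())

# Only parameters/buffers travel through the state dict; batch_first and
# loss_reduction stay whatever the constructor set, and hooks are untouched.
assert dst.batch_first is False
assert dst.loss_reduction == "sum"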
opacus/tests/privacy_engine_test.py (15 additions, 14 deletions)

@@ -491,9 +491,7 @@ def test_checkpoints(self, noise_scheduler: Optional[_NoiseScheduler]):
         )

         # check that two sets of components are different
-        self.assertFalse(
-            are_state_dict_equal(m1._module.state_dict(), m2._module.state_dict())
-        )
+        self.assertFalse(are_state_dict_equal(m1.state_dict(), m2.state_dict()))
         if noise_scheduler:
             self.assertNotEqual(s1.state_dict(), s2.state_dict())
             self.assertNotEqual(opt1.noise_multiplier, opt2.noise_multiplier)
@@ -504,20 +502,27 @@ def test_checkpoints(self, noise_scheduler: Optional[_NoiseScheduler]):
                 s1.step()

         # load into set 2
+        checkpoint_to_save = {"foo": "bar"}
         with io.BytesIO() as bytesio:
             pe1.save_checkpoint(
-                path=bytesio, module=m1, optimizer=opt1, noise_scheduler=s1
+                path=bytesio,
+                module=m1,
+                optimizer=opt1,
+                noise_scheduler=s1,
+                checkpoint_dict=checkpoint_to_save,
             )
             bytesio.seek(0)
-            pe2.load_checkpoint(
+            loaded_checkpoint = pe2.load_checkpoint(
                 path=bytesio, module=m2, optimizer=opt2, noise_scheduler=s2
             )

-        # check the two sets of components are now the same
-        self.assertEqual(pe1.accountant.state_dict(), pe2.accountant.state_dict())
+        # check if loaded checkpoint has dummy dict
         self.assertTrue(
-            are_state_dict_equal(m1._module.state_dict(), m2._module.state_dict())
+            "foo" in loaded_checkpoint and loaded_checkpoint["foo"] == "bar"
         )
+        # check the two sets of components are now the same
+        self.assertEqual(pe1.accountant.state_dict(), pe2.accountant.state_dict())
+        self.assertTrue(are_state_dict_equal(m1.state_dict(), m2.state_dict()))
         if noise_scheduler:
             self.assertEqual(s1.state_dict(), s2.state_dict())
         # check that non-state params are still different
@@ -539,9 +544,7 @@ def test_checkpoints(self, noise_scheduler: Optional[_NoiseScheduler]):
             if noise_scheduler is not None
             else None
         )
-        self.assertFalse(
-            are_state_dict_equal(m2._module.state_dict(), m11._module.state_dict())
-        )
+        self.assertFalse(are_state_dict_equal(m2.state_dict(), m11.state_dict()))
         if noise_scheduler:
             self.assertNotEqual(s2.state_dict(), s11.state_dict())
         # train the recreated set for the same number of steps
@@ -552,9 +555,7 @@ def test_checkpoints(self, noise_scheduler: Optional[_NoiseScheduler]):
             if noise_scheduler:
                 s11.step()
         # check that recreated set is now same as the original set 1 after training
-        self.assertTrue(
-            are_state_dict_equal(m2._module.state_dict(), m11._module.state_dict())
-        )
+        self.assertTrue(are_state_dict_equal(m2.state_dict(), m11.state_dict()))
         if noise_scheduler:
             self.assertEqual(s2.state_dict(), s11.state_dict())