Closed
Changes from 1 commit
Commits
38 commits
b8a2ff7
support empty batches in memory manager and optimizer
Oct 25, 2022
2e1b9d7
restore warning
Oct 25, 2022
df9d1ab
disable functorch test for 1.13+
Oct 25, 2022
0268fa1
Merge branch 'main' of github.com:pytorch/opacus into ffuuugor_522
Oct 26, 2022
b952c2a
0-batch tests
Oct 27, 2022
5c7fc6f
lint
Oct 27, 2022
64f08ad
EW test fix
Oct 27, 2022
df7c355
docstring up
Oct 27, 2022
b64b06a
Implement per sample grads util and refactor code
Oct 27, 2022
9b4d5ee
Merge branch 'pytorch:main' into per-sample-grad-correctness-util
psolikov Oct 27, 2022
3f9f9cd
Add docs and refactor
Oct 28, 2022
765b84e
Apply code style fixes
Oct 28, 2022
16477ac
Merge branch 'main' into per-sample-grad-correctness-util
psolikov Oct 28, 2022
585be68
Fix flake8 errors
Oct 28, 2022
82c8f52
Implement per sample grads util and refactor code
Oct 27, 2022
1d957fa
Fixed issue with missing argument in MNIST example (#520)
Oct 27, 2022
4e3a979
Add docs and refactor
Oct 28, 2022
3d0a5db
Apply code style fixes
Oct 28, 2022
36dd386
Functorch gradients: investigation and fix (#510)
Oct 28, 2022
c06ebec
Fix flake8 errors
Oct 28, 2022
5168e20
Add type hints
Oct 31, 2022
b71fb30
Refactor
Oct 31, 2022
cdcae86
Update docstrings
Oct 31, 2022
ab1d6a7
Fix reduction modes for EW
Oct 31, 2022
206a042
Rebase on #530, separate utils tests, refactor
Nov 1, 2022
f9a35de
Optimize imports
Nov 1, 2022
a8aac48
Merge remote-tracking branch 'origin/per-sample-grad-correctness-util…
Nov 1, 2022
f7880d8
Fix test
Nov 1, 2022
1ae50cb
Add utility description to tutorial
Nov 1, 2022
8c67f40
Fix grad samples test
Nov 7, 2022
0faf661
Fixed isort warnings
Nov 7, 2022
a95c95a
Fix grad samples zero batch test
Nov 7, 2022
786a093
Skip functorch test when unavailable
Nov 7, 2022
42866b1
Merge branch 'main' into per-sample-grad-correctness-util
Nov 8, 2022
6402c18
Fix merge
Nov 8, 2022
0b9a1ec
Isort fix
Nov 8, 2022
05fdb4f
Fix docstring
Nov 8, 2022
fd8fbde
Merge branch 'main' into per-sample-grad-correctness-util
psolikov Nov 17, 2022
Update docstrings
Pavel Solikov committed Oct 31, 2022
commit cdcae861a37f5250d776996aa7a0903bb9547a1e
26 changes: 19 additions & 7 deletions opacus/utils/per_sample_gradients_utils.py
@@ -96,8 +96,8 @@ def compute_microbatch_grad_sample(
     as this method is obviously correct, but slow.
 
     Args:
-        x: The tensor in input to the ``module``
-        module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
+        x: Sample input batch
+        module: The nn.Module you want to test.
         batch_first: Whether batch size is the first dimension (as opposed to the second).
             Defaults to True.
         loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
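A minimal usage sketch for compute_microbatch_grad_sample, assuming the helper is importable from opacus.utils.per_sample_gradients_utils and returns a dict mapping parameter names to per-sample gradient tensors (neither detail is visible in this hunk):

import torch
import torch.nn as nn

from opacus.utils.per_sample_gradients_utils import compute_microbatch_grad_sample

# Hypothetical toy setup; only the arguments documented above
# (x, module, batch_first, loss_reduction) are used.
batch_size, features = 8, 16
x = torch.randn(batch_size, features)  # sample input batch
model = nn.Linear(features, 4)         # the nn.Module under test

# Slow but trivially correct: gradients are computed one sample at a time.
micro_grads = compute_microbatch_grad_sample(
    x, model, batch_first=True, loss_reduction="mean"
)
for name, per_sample in micro_grads.items():  # assumed dict-of-tensors return
    print(name, tuple(per_sample.shape))      # leading dim should equal batch_size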
@@ -162,8 +162,8 @@ def compute_opacus_grad_sample(
     Runs Opacus to compute per-sample gradients and return them for testing purposes.
 
     Args:
-        x: The tensor in input to the ``module``
-        module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
+        x: Sample input batch
+        module: The nn.Module you want to test.
         batch_first: Whether batch size is the first dimension (as opposed to the second).
             Defaults to True.
         loss_reduction: What reduction to apply to the loss. Defaults to "mean".
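Continuing the sketch above for the optimized path, under the same assumptions about the return type (this is the result that check_per_sample_gradients_are_correct, shown further below, compares against the micro-batch reference):

# Assumed importable from the same module as the micro-batch helper.
from opacus.utils.per_sample_gradients_utils import compute_opacus_grad_sample

# Opacus' optimized per-sample gradients on the same inputs as above.
opacus_grads = compute_opacus_grad_sample(
    x, model, batch_first=True, loss_reduction="mean"
)
# Assuming both helpers key their result dicts identically by parameter name,
# the two results should agree within numerical tolerance; the tolerances
# here are illustrative.
for name in micro_grads:
    print(name, torch.allclose(micro_grads[name], opacus_grads[name], atol=1e-6, rtol=1e-5))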
@@ -218,16 +218,28 @@ def check_per_sample_gradients_are_correct(
 ) -> bool:
     """
     A utility to check whether per sample gradients are computed correctly with a particular model.
+    The check is performed by comparing the result of the slow but reliable micro-batch method
+    `compute_microbatch_grad_sample` with the result of the optimized Opacus method.
+
     Args:
-        x: The tensor in input to the ``module``
+        x: Sample input batch
         module: The ``ModelWithLoss`` that wraps the nn.Module you want to check.
         batch_first: Whether batch size is the first dimension (as opposed to the second).
             Defaults to True.
-        atol: The relative tolerance parameter (numpy).
-        rtol: The absolute tolerance parameter (numpy).
+        atol: The absolute tolerance parameter (torch.allclose).
+        rtol: The relative tolerance parameter (torch.allclose).
         grad_sample_mode: What sampling method to use to get gradients.
 
     Returns: True if per sample gradients were computed correctly. False otherwise.
+
+    Example:
+        >>> x_shape = [N, Z, W]
+        >>> x = torch.randn(x_shape)
+        >>> model = nn.Linear(W, W + 2)
+        >>> assert check_per_sample_gradients_are_correct(
+        ...     x,
+        ...     model
+        ... )  # This will fail only if the Opacus per-sample gradients do not match the micro-batch gradients.
     """
     if grad_sample_mode == "functorch":
         import functorch  # noqa
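A concrete, hedged variant of the docstring example above, with the symbolic dimensions N, Z, W given explicit values; the keyword arguments come from the Args list, their values here are illustrative, and grad_sample_mode="hooks" is assumed to be the standard hook-based mode:

import torch
import torch.nn as nn

from opacus.utils.per_sample_gradients_utils import check_per_sample_gradients_are_correct

N, Z, W = 4, 3, 10        # batch size, sequence length, feature width (illustrative)
x = torch.randn(N, Z, W)  # sample input batch
model = nn.Linear(W, W + 2)

# True when the optimized per-sample gradients match the slow micro-batch
# reference within the given tolerances.
assert check_per_sample_gradients_are_correct(
    x,
    model,
    batch_first=True,
    atol=1e-8,
    rtol=1e-5,
    grad_sample_mode="hooks",
)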