Finish tests. Einsums work for embeddingbag. Functorch is disabled for embeddingbag as it is difficult to make it work.
Alex Sablayrolles committed Oct 31, 2022
commit aeecc27b606dd6bb2374d5af3c48fb70bb6e436d
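
For context on the changes below: `nn.EmbeddingBag` takes a flat index tensor plus an offsets tensor, where `offsets[i]` marks where bag `i` starts, and it is this `(input, offsets)` pair that the grad sampler and the tests now have to carry around. A minimal illustration of that calling convention (not part of this commit):

```python
import torch
import torch.nn as nn

# EmbeddingBag consumes a flat index tensor plus an offsets tensor;
# offsets[i] is the position in `index` where bag i starts.
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode="sum")

index = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # all bags, flattened
offsets = torch.tensor([0, 4])                   # bag 0 = index[0:4], bag 1 = index[4:8]

out = bag(index, offsets)
print(out.shape)  # torch.Size([2, 3]): one row per bag
```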
opacus/grad_sample/embedding.py (6 changes: 4 additions & 2 deletions)
@@ -67,10 +67,12 @@ def compute_embeddingbag_gradsampler(layer, inputs, backprops):
         if i < batch_size - 1:
             end = offset[i + 1]
         else:
-            end = index.shape[0] - 1
+            end = index.shape[0]
 
         if layer.mode == "sum":
-            gsm[i][index[begin:end]] = backprops[i]
+            gsm[i][index[begin:end]] += backprops[i]
+        elif layer.mode == "mean":
+            gsm[i][index[begin:end]] += backprops[i] / (end - begin)
 
     ret = {}
     ret[layer.weight] = gsm
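
The hunk above slices the flat index per bag (`begin = offset[i]`, `end = offset[i + 1]`, or `len(index)` for the last bag) and accumulates `backprops[i]` into the touched rows of the per-sample gradient, scaled by `1 / (end - begin)` in `"mean"` mode. A rough standalone sketch of that accumulation rule, with hypothetical names and using `index_add_` so that a row id repeated within a bag still accumulates:

```python
import torch

def per_sample_embeddingbag_grad(index, offset, backprops, num_embeddings, mode="sum"):
    # backprops: [batch_size, embedding_dim] gradients w.r.t. each bag's output
    batch_size, dim = backprops.shape
    gsm = torch.zeros(batch_size, num_embeddings, dim)
    for i in range(batch_size):
        begin = offset[i]
        end = offset[i + 1] if i < batch_size - 1 else index.shape[0]
        rows = index[begin:end]                       # row ids contributing to bag i
        scale = 1.0 if mode == "sum" else 1.0 / (end - begin)
        contrib = backprops[i].expand(rows.shape[0], dim) * scale
        gsm[i].index_add_(0, rows, contrib)           # accumulates repeated row ids
    return gsm
```

For a sum-reduced loss, summing `gsm` over dimension 0 should recover the regular `weight.grad` from an ordinary backward pass, which is a convenient sanity check.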
opacus/grad_sample/functorch.py (6 changes: 5 additions & 1 deletion)
@@ -1,5 +1,5 @@
 from opacus.layers.dp_rnn import RNNLinear
-
+import torch.nn as nn
 
 def prepare_layer(layer, batch_first=True):
     """
@@ -17,6 +17,10 @@ def prepare_layer(layer, batch_first=True):
         raise NotImplementedError(
             "This layer has buffers and is not supported by Opacus"
         )
+    if type(layer) is nn.EmbeddingBag:
+        raise NotImplementedError(
+            "Functorch does not support EmbeddingBag yet"
+        )
     flayer, _ = make_functional(layer)
 
     def compute_loss_stateless_model(params, activations, backprops):
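
For reference, this is roughly the per-sample-gradient recipe that `prepare_layer` builds with functorch for supported layers such as `nn.Linear`; the snippet below is a hypothetical sketch, not the file's actual code. EmbeddingBag is excluded above because vmap-ing over its ragged `(index, offsets)` inputs does not fit this pattern.

```python
import torch
import torch.nn as nn
from functorch import grad, make_functional, vmap

layer = nn.Linear(4, 2)
flayer, params = make_functional(layer)

def per_sample_loss(params, activation, backprop):
    # Contract the layer output with the upstream gradient, as in the hooks path.
    out = flayer(params, activation.unsqueeze(0)).squeeze(0)
    return (out * backprop).sum()

activations = torch.randn(8, 4)   # one activation per sample
backprops = torch.randn(8, 2)     # one upstream gradient per sample

# vmap over the batch dimension of activations/backprops, not over params.
per_sample_grads = vmap(grad(per_sample_loss), in_dims=(None, 0, 0))(
    params, activations, backprops
)
print(per_sample_grads[0].shape)  # torch.Size([8, 2, 4]): weight gradient per sample
```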
opacus/grad_sample/grad_sample_module.py (20 changes: 8 additions & 12 deletions)
@@ -378,16 +378,14 @@ def rearrange_grad_samples(
 
         batch_dim = 0 if batch_first or type(module) is RNNLinear else 1
 
-        activations = module.activations.pop()
-
         if not hasattr(module, "max_batch_len"):
             # For packed sequences, max_batch_len is set in the forward of the model (e.g. the LSTM)
             # Otherwise we infer it here
             module.max_batch_len = _get_batch_size(
                 module=module,
-                grad_sample=activations[0],
                 batch_dim=batch_dim,
             )
+        activations = module.activations.pop()
 
         n = module.max_batch_len
         if loss_reduction == "mean":
@@ -477,28 +475,26 @@ def validate(
 
 
 def _get_batch_size(
-    *, module: nn.Module, grad_sample: torch.Tensor, batch_dim: int
+    *, module: nn.Module, batch_dim: int
 ) -> int:
     """
     Computes and returns the maximum batch size which is the maximum of the dimension values
-    along 'batch_dim' axis over module.activations + [grad_sample], where module.activations is
+    along 'batch_dim' axis over module.activations, where module.activations is
     a list.
-
-    If module.activations is a not a list, then return grad_sample.shape[batch_dim].
-
     Args:
         module: input module
-        grad_sample: per sample gradient tensor
         batch_dim: batch dimension
 
     Returns:
         Maximum sequence length in a batch
     """
 
     max_batch_len = 0
     for out in module.activations:
-        if out[0].shape[batch_dim] > max_batch_len:
-            max_batch_len = out[0].shape[batch_dim]
+        # out is typically a tuple of one element (x)
+        # for embedding bag, it is a tuple of two elements (x, offsets)
+        # where len(offsets) = batch_size
+        if out[-1].shape[batch_dim] > max_batch_len:
+            max_batch_len = out[-1].shape[batch_dim]
 
-    max_batch_len = max(max_batch_len, grad_sample.shape[batch_dim])
     return max_batch_len
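
The `out[-1]` change above is what makes max-batch-size inference work for EmbeddingBag: the stored activations are `(index, offsets)` rather than `(x,)`, and it is `offsets`, not the flat index tensor, whose length equals the batch size. An illustrative check with hypothetical values:

```python
import torch

# For most layers the stored activation tuple is (x,), so out[-1] is x itself;
# for EmbeddingBag it is (index, offsets), and len(offsets) == batch_size,
# whereas len(index) is the total number of lookups, not the batch size.
index = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])

out = (index, offsets)
batch_dim = 0
print(out[-1].shape[batch_dim])  # 2: the true batch size
print(out[0].shape[batch_dim])   # 8: would overestimate max_batch_len
```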
opacus/tests/grad_samples/common.py (27 changes: 21 additions & 6 deletions)
@@ -67,7 +67,10 @@ def __init__(self, module: nn.Module, loss_reduction: str = "mean"):
         self.criterion = nn.L1Loss(reduction=loss_reduction)
 
     def forward(self, x):
-        x = self.wrapped_module(x)
+        if type(x) is tuple:
+            x = self.wrapped_module(*x)
+        else:
+            x = self.wrapped_module(x)
         if type(x) is PackedSequence:
             loss = _compute_loss_packedsequences(self.criterion, x)
         else:
@@ -106,6 +109,7 @@ def compute_microbatch_grad_sample(
         module: nn.Module,
         batch_first=True,
         loss_reduction="mean",
+        chunk_method=iter,
     ) -> Dict[str, torch.tensor]:
         """
         Computes per-sample gradients with the microbatch method, i.e. by computing normal gradients
@@ -117,6 +121,8 @@ def compute_microbatch_grad_sample(
             module: The ``ModelWithLoss`` that wraps the nn.Module you want to test.
             batch_first: Whether batch size is the first dimension (as opposed to the second).
                 Defaults to True.
+            loss_reduction: What reduction to apply to the loss. Defaults to "mean".
+            chunk_method: The method to use to split the batch into microbatches. Defaults to ``iter``.
 
         Returns:
             Dictionary mapping parameter_name -> per-sample-gradient for that parameter
@@ -136,12 +142,14 @@
 
         # Invariant: x is [B, T, ...]
 
-        for x_i in x:
+        for x_i in chunk_method(x):
             # x_i is [T, ...]
-            x_i = x_i.unsqueeze(
-                0 if batch_first else 1
-            )  # x_i of size [1, T, ...] if batch_first, else [T, 1, ...]
             module.zero_grad()
+            if type(x_i) is not tuple:
+                # EmbeddingBag provides tuples
+                x_i = x_i.unsqueeze(
+                    0 if batch_first else 1
+                )  # x_i of size [1, T, ...] if batch_first, else [T, 1, ...]
             loss_i = module(x_i)
             loss_i.backward()
             for p in module.parameters():
@@ -219,13 +227,17 @@ def run_test(
         atol=10e-6,
         rtol=10e-5,
         ew_compatible=True,
+        chunk_method=iter,
     ):
         grad_sample_modes = ["hooks", "functorch"]
         try:
            import functorch  # noqa
        except ImportError:
            grad_sample_modes = ["hooks"]
 
+        if type(module) is nn.EmbeddingBag:
+            grad_sample_modes = ["hooks"]
+
         for grad_sample_mode in grad_sample_modes:
             for loss_reduction in ["sum", "mean"]:
 
@@ -240,6 +252,7 @@
                     atol=atol,
                     rtol=rtol,
                     grad_sample_mode=grad_sample_mode,
+                    chunk_method=chunk_method,
                 )
         if ew_compatible and batch_first and torch.__version__ >= (1, 13):
             self.run_test_with_reduction(
@@ -250,6 +263,7 @@
                 atol=atol,
                 rtol=rtol,
                 grad_sample_mode="ew",
+                chunk_method=chunk_method,
             )
 
     def run_test_with_reduction(
@@ -261,6 +275,7 @@ def run_test_with_reduction(
         atol=10e-6,
         rtol=10e-5,
         grad_sample_mode="hooks",
+        chunk_method=iter,
     ):
         if type(x) is PackedSequence:
             x_unpacked = _unpack_packedsequences(x)
@@ -272,7 +287,7 @@
             )
         else:
             microbatch_grad_samples = self.compute_microbatch_grad_sample(
-                x, module, batch_first=batch_first, loss_reduction=loss_reduction
+                x, module, batch_first=batch_first, loss_reduction=loss_reduction, chunk_method=chunk_method
             )
 
         opacus_grad_samples = self.compute_opacus_grad_sample(
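
The new `chunk_method` hook exists because EmbeddingBag batches cannot be split into microbatches by simply iterating over a dense first dimension. A hypothetical `chunk_method` (illustrative name, not a helper shipped in this file) that yields one `(index, offsets)` pair per bag might look like this:

```python
import torch

def chunk_embeddingbag(batch):
    # batch is the (index, offsets) pair fed to nn.EmbeddingBag.
    index, offsets = batch
    batch_size = offsets.shape[0]
    for i in range(batch_size):
        begin = offsets[i]
        end = offsets[i + 1] if i < batch_size - 1 else index.shape[0]
        # Each microbatch is a single bag, re-based to start at offset 0, so it
        # is itself a valid (index, offsets) input; the tuple branch added to
        # ModelWithLoss.forward passes it through without unsqueezing.
        yield index[begin:end], torch.tensor([0])
```

A test would then pass `chunk_method=chunk_embeddingbag` to `run_test` together with a tuple input `x = (index, offsets)`.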
opacus/tests/grad_samples/embedding_test.py (2 changes: 1 addition & 1 deletion)
@@ -55,5 +55,5 @@ def test_input_across_dims(
         size = [N, T, Q, R]
 
         emb = nn.Embedding(V, D)
-        x = torch.randint(low=0, high=V - 1, size=size)
+        x = torch.randint(low=0, high=V, size=size)
         self.run_test(x, emb, batch_first=batch_first)
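
The one-character change above matters because `torch.randint`'s `high` bound is exclusive: with `high=V - 1`, embedding row `V - 1` could never be sampled, so the test silently skipped it. A quick illustration:

```python
import torch

V = 5
x = torch.randint(low=0, high=V, size=(10_000,))
print(x.unique())        # 0 .. V - 1; with high=V - 1, index V - 1 was never drawn
assert int(x.max()) < V  # still valid indices for nn.Embedding(V, D)
```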