Skip to content

Commit f43033d

Browse files
sayanghosh authored and facebook-github-bot committed
Replace torch einsum with opt_einsum (#440)
Summary: Pull Request resolved: #440 We are using optimized einsums in place of Pytorch einsums. As per https://optimized-einsum.readthedocs.io/en/stable/ opt einsums are faster and our results on Opacus benchmarking also corroborate it. Differential Revision: D37128344 fbshipit-source-id: 891c1cc3e1348a4965a068d6fd1375eb584805b9
1 parent 7689ff5 commit f43033d

File tree

11 files changed

+30
-17
lines changed

11 files changed

+30
-17
lines changed

opacus/grad_sample/conv.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020
import torch.nn as nn
2121
from opacus.utils.tensor_utils import unfold2d, unfold3d
22+
from opt_einsum import contract
2223

2324
from .utils import register_grad_sampler
2425

@@ -70,7 +71,7 @@ def compute_conv_grad_sample(
7071
ret = {}
7172
if layer.weight.requires_grad:
7273
# n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
73-
grad_sample = torch.einsum("noq,npq->nop", backprops, activations)
74+
grad_sample = contract("noq,npq->nop", backprops, activations)
7475
# rearrange the above tensor and extract diagonals.
7576
grad_sample = grad_sample.view(
7677
n,
@@ -80,7 +81,7 @@ def compute_conv_grad_sample(
8081
int(layer.in_channels / layer.groups),
8182
np.prod(layer.kernel_size),
8283
)
83-
grad_sample = torch.einsum("ngrg...->ngr...", grad_sample).contiguous()
84+
grad_sample = contract("ngrg...->ngr...", grad_sample).contiguous()
8485
shape = [n] + list(layer.weight.shape)
8586
ret[layer.weight] = grad_sample.view(shape)
8687

opacus/grad_sample/dp_rnn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020
import torch.nn as nn
2121
from opacus.layers.dp_rnn import RNNLinear
22+
from opt_einsum import contract
2223

2324
from .utils import register_grad_sampler
2425

@@ -40,8 +41,8 @@ def compute_rnn_linear_grad_sample(
4041
"""
4142
ret = {}
4243
if layer.weight.requires_grad:
43-
gs = torch.einsum("n...i,n...j->nij", backprops, activations)
44+
gs = contract("n...i,n...j->nij", backprops, activations)
4445
ret[layer.weight] = gs
4546
if layer.bias is not None and layer.bias.requires_grad:
46-
ret[layer.bias] = torch.einsum("n...k->nk", backprops)
47+
ret[layer.bias] = contract("n...k->nk", backprops)
4748
return ret

opacus/grad_sample/group_norm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020
import torch.nn as nn
2121
import torch.nn.functional as F
22+
from opt_einsum import contract
2223

2324
from .utils import register_grad_sampler
2425

@@ -40,7 +41,7 @@ def compute_group_norm_grad_sample(
4041
ret = {}
4142
if layer.weight.requires_grad:
4243
gs = F.group_norm(activations, layer.num_groups, eps=layer.eps) * backprops
43-
ret[layer.weight] = torch.einsum("ni...->ni", gs)
44+
ret[layer.weight] = contract("ni...->ni", gs)
4445
if layer.bias is not None and layer.bias.requires_grad:
45-
ret[layer.bias] = torch.einsum("ni...->ni", backprops)
46+
ret[layer.bias] = contract("ni...->ni", backprops)
4647
return ret

opacus/grad_sample/instance_norm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919
import torch.nn as nn
2020
import torch.nn.functional as F
21+
from opt_einsum import contract
2122

2223
from .utils import register_grad_sampler
2324

@@ -49,7 +50,7 @@ def compute_instance_norm_grad_sample(
4950
ret = {}
5051
if layer.weight.requires_grad:
5152
gs = F.instance_norm(activations, eps=layer.eps) * backprops
52-
ret[layer.weight] = torch.einsum("ni...->ni", gs)
53+
ret[layer.weight] = contract("ni...->ni", gs)
5354
if layer.bias is not None and layer.bias.requires_grad:
54-
ret[layer.bias] = torch.einsum("ni...->ni", backprops)
55+
ret[layer.bias] = contract("ni...->ni", backprops)
5556
return ret

opacus/grad_sample/linear.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import torch
1919
import torch.nn as nn
20+
from opt_einsum import contract
2021

2122
from .utils import register_grad_sampler
2223

@@ -35,8 +36,8 @@ def compute_linear_grad_sample(
3536
"""
3637
ret = {}
3738
if layer.weight.requires_grad:
38-
gs = torch.einsum("n...i,n...j->nij", backprops, activations)
39+
gs = contract("n...i,n...j->nij", backprops, activations)
3940
ret[layer.weight] = gs
4041
if layer.bias is not None and layer.bias.requires_grad:
41-
ret[layer.bias] = torch.einsum("n...k->nk", backprops)
42+
ret[layer.bias] = contract("n...k->nk", backprops)
4243
return ret

opacus/optimizers/adaclipoptimizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Callable, Optional
1919

2020
import torch
21+
from opt_einsum import contract
2122
from torch.optim import Optimizer
2223

2324
from .optimizer import (
@@ -108,7 +109,7 @@ def clip_and_accumulate(self):
108109
_check_processed_flag(p.grad_sample)
109110

110111
grad_sample = _get_flat_grad_sample(p)
111-
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
112+
grad = contract("i,i...", per_sample_clip_factor, grad_sample)
112113

113114
if p.summed_grad is not None:
114115
p.summed_grad += grad

opacus/optimizers/ddp_perlayeroptimizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Callable, List, Optional
1919

2020
import torch
21+
from opt_einsum import contract
2122
from torch import nn
2223
from torch.optim import Optimizer
2324

@@ -28,7 +29,7 @@ def _clip_and_accumulate_parameter(p: nn.Parameter, max_grad_norm: float):
2829
per_sample_norms = p.grad_sample.view(len(p.grad_sample), -1).norm(2, dim=-1)
2930
per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)
3031

31-
grad = torch.einsum("i,i...", per_sample_clip_factor, p.grad_sample)
32+
grad = contract("i,i...", per_sample_clip_factor, p.grad_sample)
3233
if p.summed_grad is not None:
3334
p.summed_grad += grad
3435
else:

opacus/optimizers/optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
import logging
1818
from typing import Callable, List, Optional, Union
1919

20-
import torch
2120
from opacus.optimizers.utils import params
21+
22+
import torch
23+
from opt_einsum import contract
2224
from torch import nn
2325
from torch.optim import Optimizer
2426

@@ -404,7 +406,7 @@ def clip_and_accumulate(self):
404406
_check_processed_flag(p.grad_sample)
405407

406408
grad_sample = _get_flat_grad_sample(p)
407-
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
409+
grad = contract("i,i...", per_sample_clip_factor, grad_sample)
408410

409411
if p.summed_grad is not None:
410412
p.summed_grad += grad

opacus/optimizers/perlayeroptimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
from typing import List, Optional
1818

19-
import torch
2019
from opacus.optimizers.utils import params
20+
21+
import torch
22+
from opt_einsum import contract
2123
from torch.optim import Optimizer
2224

2325
from .optimizer import DPOptimizer, _check_processed_flag, _mark_as_processed
@@ -61,7 +63,7 @@ def clip_and_accumulate(self):
6163
per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(
6264
max=1.0
6365
)
64-
grad = torch.einsum("i,i...", per_sample_clip_factor, p.grad_sample)
66+
grad = contract("i,i...", per_sample_clip_factor, p.grad_sample)
6567

6668
if p.summed_grad is not None:
6769
p.summed_grad += grad

opacus/tests/privacy_engine_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from opacus.utils.module_utils import are_state_dict_equal
3434
from opacus.validators.errors import UnsupportedModuleError
3535
from opacus.validators.module_validator import ModuleValidator
36+
from opt_einsum import contract
3637
from torch.utils.data import DataLoader, Dataset, TensorDataset
3738
from torchvision import models, transforms
3839
from torchvision.datasets import FakeData
@@ -48,7 +49,7 @@ def get_grad_sample_aggregated(tensor: torch.Tensor, loss_type: str = "mean"):
4849
if loss_type not in ("sum", "mean"):
4950
raise ValueError(f"loss_type = {loss_type}. Only 'sum' and 'mean' supported")
5051

51-
grad_sample_aggregated = torch.einsum("i...->...", tensor.grad_sample)
52+
grad_sample_aggregated = contract("i...->...", tensor.grad_sample)
5253
if loss_type == "mean":
5354
b_sz = tensor.grad_sample.shape[0]
5455
grad_sample_aggregated /= b_sz

0 commit comments

Comments (0)