Commit dcc4acf

sayanghosh authored and facebook-github-bot committed
Replace torch.einsum with opt_einsum

Differential Revision: D37128344
fbshipit-source-id: ff1105b150dfa023b57a7eec96fb5fa8ae58f1c9
1 parent d079ffd commit dcc4acf

11 files changed: +28 −19 lines
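
Context note (not part of the commit): opt_einsum's contract is intended as a drop-in replacement for torch.einsum; it takes the same subscript string and operands, dispatches to the torch backend when given torch tensors, and additionally optimizes the contraction path for multi-operand expressions. A minimal sketch, assuming opt_einsum is installed:

import torch
from opt_einsum import contract

# Same subscripts, same operands, same result; only the backend call path differs.
backprops = torch.randn(8, 4, 5)
activations = torch.randn(8, 4, 3)

assert torch.allclose(
    torch.einsum("n...i,n...j->nij", backprops, activations),
    contract("n...i,n...j->nij", backprops, activations),
)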

opacus/grad_sample/conv.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 import torch
 import torch.nn as nn
 from opacus.utils.tensor_utils import unfold2d, unfold3d
+from opt_einsum import contract

 from .utils import register_grad_sampler

opacus/grad_sample/dp_rnn.py

Lines changed: 3 additions & 2 deletions
@@ -19,6 +19,7 @@
 import torch
 import torch.nn as nn
 from opacus.layers.dp_rnn import RNNLinear
+from opt_einsum import contract

 from .utils import register_grad_sampler

@@ -40,8 +41,8 @@ def compute_rnn_linear_grad_sample(
     """
     ret = {}
     if layer.weight.requires_grad:
-        gs = torch.einsum("n...i,n...j->nij", backprops, activations)
+        gs = contract("n...i,n...j->nij", backprops, activations)
         ret[layer.weight] = gs
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
+        ret[layer.bias] = contract("n...k->nk", backprops)
     return ret

opacus/grad_sample/group_norm.py

Lines changed: 3 additions & 2 deletions
@@ -19,6 +19,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from opt_einsum import contract

 from .utils import register_grad_sampler

@@ -40,7 +41,7 @@ def compute_group_norm_grad_sample(
     ret = {}
     if layer.weight.requires_grad:
         gs = F.group_norm(activations, layer.num_groups, eps=layer.eps) * backprops
-        ret[layer.weight] = torch.einsum("ni...->ni", gs)
+        ret[layer.weight] = contract("ni...->ni", gs)
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = torch.einsum("ni...->ni", backprops)
+        ret[layer.bias] = contract("ni...->ni", backprops)
     return ret
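
Illustrative sketch (not from the commit): the "ni...->ni" contraction keeps the batch and channel dimensions and sums over everything else, equivalent to flattening the trailing dimensions and summing:

import torch
from opt_einsum import contract

gs = torch.randn(4, 8, 16, 16)  # (batch, channels, *spatial)
reduced = contract("ni...->ni", gs)
assert torch.allclose(reduced, gs.flatten(start_dim=2).sum(dim=-1))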

opacus/grad_sample/instance_norm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919
import torch.nn as nn
2020
import torch.nn.functional as F
21+
from opt_einsum import contract
2122

2223
from .utils import register_grad_sampler
2324

@@ -49,7 +50,7 @@ def compute_instance_norm_grad_sample(
4950
ret = {}
5051
if layer.weight.requires_grad:
5152
gs = F.instance_norm(activations, eps=layer.eps) * backprops
52-
ret[layer.weight] = torch.einsum("ni...->ni", gs)
53+
ret[layer.weight] = contract("ni...->ni", gs)
5354
if layer.bias is not None and layer.bias.requires_grad:
54-
ret[layer.bias] = torch.einsum("ni...->ni", backprops)
55+
ret[layer.bias] = contract("ni...->ni", backprops)
5556
return ret

opacus/grad_sample/linear.py

Lines changed: 3 additions & 2 deletions
@@ -17,6 +17,7 @@

 import torch
 import torch.nn as nn
+from opt_einsum import contract

 from .utils import register_grad_sampler

@@ -35,8 +36,8 @@ def compute_linear_grad_sample(
     """
     ret = {}
     if layer.weight.requires_grad:
-        gs = torch.einsum("n...i,n...j->nij", backprops, activations)
+        gs = contract("n...i,n...j->nij", backprops, activations)
         ret[layer.weight] = gs
     if layer.bias is not None and layer.bias.requires_grad:
-        ret[layer.bias] = torch.einsum("n...k->nk", backprops)
+        ret[layer.bias] = contract("n...k->nk", backprops)
     return ret
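
Illustrative sketch (not from the commit; the toy layer and shapes are made up): the "n...i,n...j->nij" contraction used above reproduces the per-sample weight gradients of nn.Linear computed one example at a time with autograd:

import torch
import torch.nn as nn
from opt_einsum import contract

torch.manual_seed(0)
layer = nn.Linear(3, 2)
x = torch.randn(5, 3)

y = layer(x)
# dL/dy for L = y.sum() is a tensor of ones with shape (5, 2)
backprops = torch.autograd.grad(y.sum(), y)[0]

# Per-sample dL/dW, shape (5, 2, 3)
gs = contract("n...i,n...j->nij", backprops, x)

for n in range(len(x)):
    (g_n,) = torch.autograd.grad(layer(x[n]).sum(), layer.weight)
    assert torch.allclose(gs[n], g_n, atol=1e-6)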

opacus/optimizers/adaclipoptimizer.py

Lines changed: 3 additions & 3 deletions
@@ -18,17 +18,17 @@
 from typing import Callable, Optional

 import torch
+from opt_einsum import contract
 from torch.optim import Optimizer

 from .optimizer import (
-    DPOptimizer,
     _check_processed_flag,
     _generate_noise,
     _get_flat_grad_sample,
     _mark_as_processed,
+    DPOptimizer,
 )

-
 logger = logging.getLogger(__name__)

@@ -108,7 +108,7 @@ def clip_and_accumulate(self):
             _check_processed_flag(p.grad_sample)

             grad_sample = _get_flat_grad_sample(p)
-            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
+            grad = contract("i,i...", per_sample_clip_factor, grad_sample)

             if p.summed_grad is not None:
                 p.summed_grad += grad

opacus/optimizers/ddp_perlayeroptimizer.py

Lines changed: 3 additions & 2 deletions
@@ -18,17 +18,18 @@
 from typing import Callable, List, Optional

 import torch
+from opt_einsum import contract
 from torch import nn
 from torch.optim import Optimizer

-from .optimizer import DPOptimizer, _generate_noise
+from .optimizer import _generate_noise, DPOptimizer


 def _clip_and_accumulate_parameter(p: nn.Parameter, max_grad_norm: float):
     per_sample_norms = p.grad_sample.view(len(p.grad_sample), -1).norm(2, dim=-1)
     per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)

-    grad = torch.einsum("i,i...", per_sample_clip_factor, p.grad_sample)
+    grad = contract("i,i...", per_sample_clip_factor, p.grad_sample)
     if p.summed_grad is not None:
         p.summed_grad += grad
     else:

opacus/optimizers/optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -19,10 +19,10 @@

 import torch
 from opacus.optimizers.utils import params
+from opt_einsum import contract
 from torch import nn
 from torch.optim import Optimizer

-
 logger = logging.getLogger(__name__)

@@ -404,7 +404,7 @@ def clip_and_accumulate(self):
             _check_processed_flag(p.grad_sample)

             grad_sample = _get_flat_grad_sample(p)
-            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
+            grad = contract("i,i...", per_sample_clip_factor, grad_sample)

             if p.summed_grad is not None:
                 p.summed_grad += grad
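
Illustrative sketch (not from the commit): with no explicit output subscripts, "i,i..." sums over the shared index i, so the call scales each per-sample gradient by its clip factor and sums over the batch dimension:

import torch
from opt_einsum import contract

grad_sample = torch.randn(6, 4, 3)  # (batch, *param_shape)
per_sample_norms = grad_sample.reshape(6, -1).norm(2, dim=-1)
per_sample_clip_factor = (1.0 / (per_sample_norms + 1e-6)).clamp(max=1.0)

grad = contract("i,i...", per_sample_clip_factor, grad_sample)
reference = (per_sample_clip_factor.view(-1, 1, 1) * grad_sample).sum(dim=0)
assert torch.allclose(grad, reference, atol=1e-6)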

opacus/optimizers/perlayeroptimizer.py

Lines changed: 3 additions & 2 deletions
@@ -18,9 +18,10 @@

 import torch
 from opacus.optimizers.utils import params
+from opt_einsum import contract
 from torch.optim import Optimizer

-from .optimizer import DPOptimizer, _check_processed_flag, _mark_as_processed
+from .optimizer import _check_processed_flag, _mark_as_processed, DPOptimizer


 class DPPerLayerOptimizer(DPOptimizer):

@@ -61,7 +62,7 @@ def clip_and_accumulate(self):
             per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(
                 max=1.0
             )
-            grad = torch.einsum("i,i...", per_sample_clip_factor, p.grad_sample)
+            grad = contract("i,i...", per_sample_clip_factor, p.grad_sample)

             if p.summed_grad is not None:
                 p.summed_grad += grad

opacus/tests/privacy_engine_test.py

Lines changed: 3 additions & 2 deletions
@@ -29,9 +29,10 @@
 from opacus import PrivacyEngine
 from opacus.layers.dp_multihead_attention import DPMultiheadAttention
 from opacus.optimizers.optimizer import _generate_noise
-from opacus.scheduler import StepNoise, _NoiseScheduler
+from opacus.scheduler import _NoiseScheduler, StepNoise
 from opacus.utils.module_utils import are_state_dict_equal
 from opacus.validators.errors import UnsupportedModuleError
+from opt_einsum import contract
 from torch.utils.data import DataLoader, Dataset, TensorDataset
 from torchvision import models, transforms
 from torchvision.datasets import FakeData

@@ -47,7 +48,7 @@ def get_grad_sample_aggregated(tensor: torch.Tensor, loss_type: str = "mean"):
     if loss_type not in ("sum", "mean"):
         raise ValueError(f"loss_type = {loss_type}. Only 'sum' and 'mean' supported")

-    grad_sample_aggregated = torch.einsum("i...->...", tensor.grad_sample)
+    grad_sample_aggregated = contract("i...->...", tensor.grad_sample)
     if loss_type == "mean":
         b_sz = tensor.grad_sample.shape[0]
         grad_sample_aggregated /= b_sz
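
Illustrative sketch (not from the commit): "i...->..." sums the per-sample gradients over the batch dimension, the same as grad_sample.sum(dim=0):

import torch
from opt_einsum import contract

grad_sample = torch.randn(5, 2, 3)
assert torch.allclose(contract("i...->...", grad_sample), grad_sample.sum(dim=0))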
