Skip to content

Commit 91dd689

Browse files
author
Alex Sablayrolles
committed
Fix tests
1 parent dba12e1 commit 91dd689

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

.circleci/config.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ commands:
162162
pip install tensorboard
163163
python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --sample-rate 0.04 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device <<parameters.device>>
164164
python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
165+
python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --sample-rate 0.04 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device <<parameters.device>> --grad_sample_mode no_op
166+
python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
165167
when: always
166168
- store_test_results:
167169
path: runs/cifar10/test-reports

examples/cifar10.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
import torch.utils.data
3232
import torch.utils.data.distributed
3333
import torchvision.transforms as transforms
34-
from functorch import grad_and_value, make_functional, vmap
3534
from opacus import PrivacyEngine
3635
from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
3736
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -140,6 +139,8 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
140139
top1_acc = []
141140

142141
if args.grad_sample_mode == "no_op":
142+
from functorch import grad_and_value, make_functional, vmap
143+
143144
# Functorch prepare
144145
fmodel, _fparams = make_functional(model)
145146

0 commit comments

Comments
 (0)