Skip to content

Commit 9dd6f0d

Browse files
author
Alex Sablayrolles
committed
Add functorch to cifar example
1 parent 4bfbbf9 commit 9dd6f0d

File tree

1 file changed

+40
-9
lines changed

1 file changed

+40
-9
lines changed

examples/cifar10.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import torch.utils.data
3232
import torch.utils.data.distributed
3333
import torchvision.transforms as transforms
34+
from functorch import grad_and_value, make_functional, vmap
3435
from opacus import PrivacyEngine
3536
from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
3637
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -138,25 +139,53 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
138139
losses = []
139140
top1_acc = []
140141

142+
if args.grad_sample_mode == "no_op":
143+
# Prepare a functional (stateless) version of the model with functorch
144+
fmodel, _fparams = make_functional(model)
145+
146+
def compute_loss_stateless_model(params, sample, target):
147+
batch = sample.unsqueeze(0)
148+
targets = target.unsqueeze(0)
149+
150+
predictions = fmodel(params, batch)
151+
loss = criterion(predictions, targets)
152+
return loss
153+
154+
ft_compute_grad = grad_and_value(compute_loss_stateless_model)
155+
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
156+
# Using model.parameters() instead of fparams
157+
# as fparams does not seem to point to the dynamically updated parameters
158+
params = list(model.parameters())
159+
141160
for i, (images, target) in enumerate(tqdm(train_loader)):
142161

143162
images = images.to(device)
144163
target = target.to(device)
145164

146165
# compute output
147166
output = model(images)
148-
loss = criterion(output, target)
149-
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
150-
labels = target.detach().cpu().numpy()
151167

152-
# measure accuracy and record loss
153-
acc1 = accuracy(preds, labels)
168+
if args.grad_sample_mode == "no_op":
169+
per_sample_grads, per_sample_losses = ft_compute_sample_grad(
170+
params, images, target
171+
)
172+
per_sample_grads = [g.detach() for g in per_sample_grads]
173+
loss = torch.mean(per_sample_losses)
174+
for (p, g) in zip(params, per_sample_grads):
175+
p.grad_sample = g
176+
else:
177+
loss = criterion(output, target)
178+
preds = np.argmax(output.detach().cpu().numpy(), axis=1)
179+
labels = target.detach().cpu().numpy()
180+
181+
# measure and record accuracy
182+
acc1 = accuracy(preds, labels)
183+
top1_acc.append(acc1)
154184

155-
losses.append(loss.item())
156-
top1_acc.append(acc1)
185+
# compute gradient and do SGD step
186+
loss.backward()
157187

158-
# compute gradient and do SGD step
159-
loss.backward()
188+
losses.append(loss.item())
160189

161190
# make sure we take a step after processing the last mini-batch in the
162191
# epoch to ensure we start the next epoch with a clean state
@@ -331,6 +360,7 @@ def main():
331360
noise_multiplier=args.sigma,
332361
max_grad_norm=max_grad_norm,
333362
clipping=clipping,
363+
grad_sample_mode=args.grad_sample_mode,
334364
)
335365

336366
# Store some logs
@@ -388,6 +418,7 @@ def main():
388418

389419
def parse_args():
390420
parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
421+
parser.add_argument("--grad_sample_mode", type=str, default="hooks")
391422
parser.add_argument(
392423
"-j",
393424
"--workers",

0 commit comments

Comments
 (0)