 import torch.utils.data
 import torch.utils.data.distributed
 import torchvision.transforms as transforms
+from functorch import grad_and_value, make_functional, vmap
 from opacus import PrivacyEngine
 from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
 from torch.nn.parallel import DistributedDataParallel as DDP
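For readers new to functorch: `make_functional` turns an `nn.Module` into a stateless callable that takes its parameters explicitly, which is what lets `grad_and_value` and `vmap` differentiate per sample later in this diff. A minimal sketch with a toy model (not from this PR):

```python
import torch
import torch.nn as nn
from functorch import make_functional

net = nn.Linear(4, 2)
fmodel, params = make_functional(net)    # stateless callable + parameter tuple
out = fmodel(params, torch.randn(8, 4))  # parameters are now explicit inputs
```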
@@ -138,25 +139,53 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
     losses = []
     top1_acc = []

+    if args.grad_sample_mode == "no_op":
+        # functorch prep: make a stateless version of the model so that
+        # per-sample gradients can be computed with vmap below
+        fmodel, _fparams = make_functional(model)
+
+        def compute_loss_stateless_model(params, sample, target):
+            batch = sample.unsqueeze(0)
+            targets = target.unsqueeze(0)
+
+            predictions = fmodel(params, batch)
+            loss = criterion(predictions, targets)
+            return loss
+
+        ft_compute_grad = grad_and_value(compute_loss_stateless_model)
+        ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
+        # Use model.parameters() instead of fparams, as fparams does not seem
+        # to track the dynamically updated parameters
+        params = list(model.parameters())
+
     for i, (images, target) in enumerate(tqdm(train_loader)):

         images = images.to(device)
         target = target.to(device)

         # compute output
         output = model(images)
-        loss = criterion(output, target)
-        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
-        labels = target.detach().cpu().numpy()

-        # measure accuracy and record loss
-        acc1 = accuracy(preds, labels)
+        if args.grad_sample_mode == "no_op":
+            # compute per-sample gradients and losses with functorch, then
+            # hand them to Opacus via p.grad_sample
+            per_sample_grads, per_sample_losses = ft_compute_sample_grad(
+                params, images, target
+            )
+            per_sample_grads = [g.detach() for g in per_sample_grads]
+            loss = torch.mean(per_sample_losses)
+            for p, g in zip(params, per_sample_grads):
+                p.grad_sample = g
+        else:
+            loss = criterion(output, target)
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
+
+            # measure accuracy and record loss
+            acc1 = accuracy(preds, labels)
+            top1_acc.append(acc1)

-        losses.append(loss.item())
-        top1_acc.append(acc1)
+            # compute gradient and do SGD step; the functorch loss in the
+            # "no_op" branch carries no autograd graph, so backward() only
+            # runs here
+            loss.backward()

-        # compute gradient and do SGD step
-        loss.backward()
+        losses.append(loss.item())

         # make sure we take a step after processing the last mini-batch in the
         # epoch to ensure we start the next epoch with a clean state
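A self-contained sketch of the `vmap(grad_and_value(...))` pattern the new branch relies on; the toy model and shapes are illustrative, not from this PR:

```python
import torch
import torch.nn as nn
from functorch import grad_and_value, make_functional, vmap

net = nn.Linear(4, 2)
fmodel, params = make_functional(net)

def single_example_loss(params, sample, target):
    # unsqueeze(0) adds a batch dimension of 1, as in the diff above
    out = fmodel(params, sample.unsqueeze(0))
    return nn.functional.cross_entropy(out, target.unsqueeze(0))

x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
# in_dims=(None, 0, 0): share params, map over the leading dim of x and y
per_sample = vmap(grad_and_value(single_example_loss), in_dims=(None, 0, 0))
grads, losses = per_sample(params, x, y)
print(losses.shape)    # torch.Size([8])       -- one loss per sample
print(grads[0].shape)  # torch.Size([8, 2, 4]) -- per-sample weight gradients
```

Each tensor in `grads` carries a leading batch dimension, which is the layout the training loop above stores in `p.grad_sample` for Opacus to clip and noise.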
@@ -331,6 +360,7 @@ def main():
         noise_multiplier=args.sigma,
         max_grad_norm=max_grad_norm,
         clipping=clipping,
+        grad_sample_mode=args.grad_sample_mode,
     )

     # Store some logs
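For context, a hedged sketch of the `make_private` call this hunk touches, assuming the module, optimizer, and data loader variables built earlier in `main()`:

```python
# grad_sample_mode="no_op" makes Opacus wrap the model in a pass-through
# GradSampleModule that computes nothing itself; the functorch branch in
# train() fills p.grad_sample instead, and the DP optimizer still clips
# and adds noise based on those per-sample gradients.
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=args.sigma,
    max_grad_norm=max_grad_norm,
    clipping=clipping,
    grad_sample_mode=args.grad_sample_mode,  # "hooks" (default) or "no_op"
)
```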
@@ -388,6 +418,7 @@ def main():

 def parse_args():
     parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument("--grad_sample_mode", type=str, default="hooks")
     parser.add_argument(
         "-j",
         "--workers",
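One possible tightening, not part of this diff: constrain the new flag so a typo fails fast. The choices below are only the two modes this example's `train()` distinguishes; Opacus itself accepts other grad sample modes.

```python
# hypothetical variant of the argument added above (not in the diff)
parser.add_argument(
    "--grad_sample_mode",
    type=str,
    default="hooks",
    choices=["hooks", "no_op"],  # the two paths handled in train()
    help="per-sample gradient mode: Opacus hooks, or no_op "
    "(the script computes grad samples itself via functorch)",
)
```

The functorch path is then selected with `--grad_sample_mode no_op`.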