@@ -140,9 +140,11 @@ def get_clipping_coef(self) -> torch.Tensor:
 
     def get_norm_sample(self) -> torch.Tensor:
         """Get per-example gradient norms."""
-        norm_sample = torch.stack(
-            [param._norm_sample for param in self.trainable_parameters], dim=0
-        ).norm(2, dim=0)
+        norm_samples = [param._norm_sample for param in self.trainable_parameters]
+        if norm_samples:
+            target_device = norm_samples[0].device
+            norm_samples = [norm.to(target_device) for norm in norm_samples]
+        norm_sample = torch.stack(norm_samples, dim=0).norm(2, dim=0)
         self.per_sample_gradient_norms = norm_sample
         return norm_sample
 
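For reference, a minimal standalone sketch (not part of this PR) of the pattern the hunk above applies: when a model is sharded across GPUs, each parameter's per-sample norm tensor lives on that parameter's device, and torch.stack requires all inputs on one device, so the norms are moved to the first tensor's device before stacking. The helper name combine_norms is illustrative only.

import torch
from typing import List


def combine_norms(per_param_norms: List[torch.Tensor]) -> torch.Tensor:
    """Aggregate per-parameter, per-sample norms into one per-sample norm.

    Mirrors the pattern in the diff above: move every norm tensor to a common
    device before stacking, since torch.stack refuses to mix devices.
    """
    if not per_param_norms:
        return torch.zeros(0)
    target_device = per_param_norms[0].device
    aligned = [n.to(target_device) for n in per_param_norms]
    # Stack to shape (num_params, batch_size), then reduce over parameters
    # with an L2 norm to get one norm per sample.
    return torch.stack(aligned, dim=0).norm(2, dim=0)


# Hypothetical usage with norms on two devices (needs >= 2 GPUs):
# norms_gpu0 = torch.randn(4, device="cuda:0").abs()
# norms_gpu1 = torch.randn(4, device="cuda:1").abs()
# per_sample_norm = combine_norms([norms_gpu0, norms_gpu1])  # shape (4,)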
opacus/optimizers/adaclipoptimizer.py (8 additions, 1 deletion)
@@ -97,6 +97,11 @@ def clip_and_accumulate(self):
         per_param_norms = [
             g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
         ]
+
+        if per_param_norms:
+            target_device = per_param_norms[0].device
+            per_param_norms = [norm.to(target_device) for norm in per_param_norms]
+
         per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
         per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
             max=1.0
@@ -112,7 +117,9 @@ def clip_and_accumulate(self):
         for p in self.params:
             _check_processed_flag(p.grad_sample)
             grad_sample = self._get_flat_grad_sample(p)
-            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
+
+            clip_factor_on_device = per_sample_clip_factor.to(grad_sample.device)
+            grad = torch.einsum("i,i...", clip_factor_on_device, grad_sample)
 
             if p.summed_grad is not None:
                 p.summed_grad += grad
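A compact sketch (again, not the Opacus implementation) of the clip-and-accumulate math the two hunks above touch: per-sample norms come from stacking per-parameter norms, each sample gets a factor max_grad_norm / (norm + 1e-6) clamped at 1.0, and torch.einsum("i,i...") contracts the batch index to produce the summed clipped gradient. The factor is moved to each gradient's device first, mirroring the fix. Function and argument names here are assumptions.

import torch


def clip_per_sample_grads(grad_samples_by_param, max_grad_norm):
    """Simplified per-sample clipping in the spirit of the hunks above.

    grad_samples_by_param: list of tensors shaped (batch_size, *param_shape),
    possibly living on different devices.
    """
    if not grad_samples_by_param:
        return []

    # Per-parameter, per-sample L2 norms, moved to one device before stacking.
    per_param_norms = [
        g.reshape(len(g), -1).norm(2, dim=-1) for g in grad_samples_by_param
    ]
    target_device = per_param_norms[0].device
    per_param_norms = [n.to(target_device) for n in per_param_norms]
    per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)

    # Factor that caps each sample's total gradient norm at max_grad_norm.
    per_sample_clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)

    clipped = []
    for g in grad_samples_by_param:
        # Move the factor to the gradient's device, then scale each sample's
        # gradient and sum over the batch: "i,i..." contracts the batch index.
        factor = per_sample_clip_factor.to(g.device)
        clipped.append(torch.einsum("i,i...", factor, g))
    return clipped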
opacus/optimizers/optimizer.py (9 additions, 2 deletions)
@@ -444,6 +444,11 @@ def clip_and_accumulate(self):
         per_param_norms = [
             g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
         ]
+
+        if per_param_norms:
+            target_device = per_param_norms[0].device
+            per_param_norms = [norm.to(target_device) for norm in per_param_norms]
+
         per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
         per_sample_clip_factor = (
             self.max_grad_norm / (per_sample_norms + 1e-6)
@@ -457,8 +462,10 @@ def clip_and_accumulate(self):
             # for mixed precision, optimizer parameters are usually in FP32
             # lower precision grads will be cast up to FP32
             grad_sample = grad_sample.to(p.dtype)
-            per_sample_clip_factor = per_sample_clip_factor.to(p.dtype)
-            grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)
+            clip_factor_on_device = per_sample_clip_factor.to(grad_sample.device).to(
+                p.dtype
+            )
+            grad = torch.einsum("i,i...", clip_factor_on_device, grad_sample)
 
             if p.summed_grad is not None:
                 p.summed_grad += grad
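The optimizer.py variant additionally handles mixed precision: per-sample grads are cast up to the parameter dtype, and the clip factor is moved to the grad's device before matching that dtype. A small CPU-only illustration under assumed shapes and dtypes (the tensors are stand-ins; in the multi-GPU case they would live on different devices):

import torch

# Stand-in FP32 parameter and FP16 per-sample grads (batch of 4, 5 weights each).
p = torch.nn.Parameter(torch.zeros(5, dtype=torch.float32))
grad_sample = torch.randn(4, 5, dtype=torch.float16)
per_sample_clip_factor = torch.rand(4)

# Lower-precision grads are cast up to the parameter dtype ...
grad_sample = grad_sample.to(p.dtype)
# ... and the clip factor is moved to the grad's device before matching dtype.
clip_factor_on_device = per_sample_clip_factor.to(grad_sample.device).to(p.dtype)
grad = torch.einsum("i,i...", clip_factor_on_device, grad_sample)  # shape (5,)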
opacus/tests/grad_sample_module_fast_gradient_clipping_test.py (94 additions, 0 deletions)
@@ -422,3 +422,97 @@ def test_gradient_calculation(self):
         logging.info(f"Max difference between (vanilla) Opacus and FGC = {max(diff)}")
         msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs")
+    def test_multidevice_get_norm_sample(self):
+        """Test that get_norm_sample handles parameters on different devices."""
+        device1 = torch.device("cuda:0")
+        device2 = torch.device("cuda:1")
+
+        # Create a simple model with parameters on different devices
+        class MultiDeviceModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(10, 20).to(device1)
+                self.fc2 = nn.Linear(20, 5).to(device2)
+
+            def forward(self, x):
+                x = x.to(device1)
+                x = torch.relu(self.fc1(x))
+                x = x.to(device2)
+                return self.fc2(x)
+
+        model = MultiDeviceModel()
+        grad_sample_module = GradSampleModuleFastGradientClipping(
+            model, max_grad_norm=1.0, use_ghost_clipping=False
+        )
+
+        # Simulate _norm_sample on different devices
+        batch_size = 4
+        for param in grad_sample_module.trainable_parameters:
+            param._norm_sample = torch.randn(batch_size, device=param.device)
+
+        # This should not raise any device mismatch errors
+        try:
+            norm_sample = grad_sample_module.get_norm_sample()
+            success = True
+        except RuntimeError as e:
+            if "Expected all tensors to be on the same device" in str(e):
+                success = False
+                self.fail(f"Device mismatch error in get_norm_sample: {e}")
+            else:
+                raise
+
+        self.assertTrue(
+            success, "get_norm_sample should handle multi-device parameters"
+        )
+        self.assertEqual(norm_sample.shape[0], batch_size)
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, "Need at least 2 GPUs")
+    def test_multidevice_get_clipping_coef(self):
+        """Test that get_clipping_coef handles parameters on different devices."""
+        device1 = torch.device("cuda:0")
+        device2 = torch.device("cuda:1")
+
+        # Create a simple model with parameters on different devices
+        class MultiDeviceModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(10, 20).to(device1)
+                self.fc2 = nn.Linear(20, 5).to(device2)
+
+            def forward(self, x):
+                x = x.to(device1)
+                x = torch.relu(self.fc1(x))
+                x = x.to(device2)
+                return self.fc2(x)
+
+        model = MultiDeviceModel()
+        max_grad_norm = 1.0
+        grad_sample_module = GradSampleModuleFastGradientClipping(
+            model, max_grad_norm=max_grad_norm, use_ghost_clipping=False
+        )
+
+        # Simulate _norm_sample on different devices
+        batch_size = 4
+        for param in grad_sample_module.trainable_parameters:
+            # Create norms with values that will require clipping
+            param._norm_sample = torch.ones(batch_size, device=param.device) * 2.0
+
+        # This should not raise any device mismatch errors
+        try:
+            clipping_coef = grad_sample_module.get_clipping_coef()
+            success = True
+        except RuntimeError as e:
+            if "Expected all tensors to be on the same device" in str(e):
+                success = False
+                self.fail(f"Device mismatch error in get_clipping_coef: {e}")
+            else:
+                raise
+
+        self.assertTrue(
+            success, "get_clipping_coef should handle multi-device parameters"
+        )
+        self.assertEqual(clipping_coef.shape[0], batch_size)
+        # Verify clipping coefficients are correct
+        self.assertTrue(torch.all(clipping_coef <= 1.0))