Commit e194ad7

HuanyuZhang authored and facebook-github-bot committed
Add gradient sample mode to the logging system
Summary: We add the gradient sample mode of each submodule to the logging system, which is especially useful when checking the compatibility of a complex model architecture.

Differential Revision: D70255075
1 parent 0a70a1d commit e194ad7

File tree

1 file changed: +40 -1 lines


opacus/grad_sample/grad_sample_module_fast_gradient_clipping.py

Lines changed: 40 additions & 1 deletion
@@ -26,7 +26,8 @@
     create_or_accumulate_grad_sample,
     promote_current_grad_sample,
 )
-from opacus.utils.module_utils import requires_grad, trainable_parameters
+from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN
+from opacus.utils.module_utils import requires_grad, trainable_modules, trainable_parameters
 
 
 logger = logging.getLogger(__name__)
@@ -109,6 +110,8 @@ def __init__(
                 If ``strict`` is set to ``True`` and module ``m`` (or any of its
                 submodules) includes a buffer.
         """
+        if logger.isEnabledFor(logging.INFO):
+            self.module_gradient_sample_mode(module=m, force_functorch=force_functorch, use_ghost_clipping=use_ghost_clipping)
 
         super().__init__(
             m,
@@ -233,7 +236,43 @@ def capture_backprops_hook(
         if len(module.activations) == 0:
             if hasattr(module, "max_batch_len"):
                 del module.max_batch_len
+
+    def module_gradient_sample_mode(
+        self, module: nn.Module, *, force_functorch=False, use_ghost_clipping=True
+    ):
+        """
+        Add logs to track gradient sample mode for each part of the module, including 1) Ghost Clipping, 2) Fast Gradient Clipping (hook mode), and 3) Fast Gradient Clipping (functorch mode).
 
+        Args:
+            module: nn.Module to be checked
+            force_functorch: If set to ``True``, will use functorch to compute
+                all per sample gradients. Otherwise, functorch will be used only
+                for layers without registered grad sampler methods.
+            use_ghost_clipping: If set to ``True``, Ghost Clipping
+                will be used for clipping gradients of supported layers. If ``False``, Fast
+                Gradient Clipping will be used for all layers.
+        """
+        for m_name, m in trainable_modules(module):
+            if type(m) in [DPRNN, DPLSTM, DPGRU]:
+                logger.info(
+                    f"Module name: {m_name}, module type: {type(m)}. No hook or functorch is added."
+                )
+
+            elif use_ghost_clipping and type(m) in self.NORM_SAMPLERS:
+                logger.info(
+                    f"Module name: {m_name}, module type: {type(m)}, under Ghost Clipping."
+                )
+
+            else:
+                if not force_functorch and type(m) in self.GRAD_SAMPLERS:
+                    logger.info(
+                        f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (hook mode)."
+                    )
+                else:
+                    logger.info(
+                        f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (functorch mode)."
+                    )
+
     @property
     def per_sample_gradient_norms(self) -> torch.Tensor:
         """Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings"""
