@@ -26,7 +26,8 @@
     create_or_accumulate_grad_sample,
     promote_current_grad_sample,
 )
-from opacus.utils.module_utils import requires_grad, trainable_parameters
+from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN
+from opacus.utils.module_utils import requires_grad, trainable_modules, trainable_parameters
 
 
 logger = logging.getLogger(__name__)
@@ -109,6 +110,8 @@ def __init__( |
                 If ``strict`` is set to ``True`` and module ``m`` (or any of its
                 submodules) includes a buffer.
         """
+        if logger.isEnabledFor(logging.INFO):
+            self.module_gradient_sample_mode(module=m, force_functorch=force_functorch, use_ghost_clipping=use_ghost_clipping)
 
         super().__init__(
             m,
@@ -233,7 +236,43 @@ def capture_backprops_hook( |
         if len(module.activations) == 0:
             if hasattr(module, "max_batch_len"):
                 del module.max_batch_len
+
+    def module_gradient_sample_mode(
+        self, module: nn.Module, *, force_functorch=False, use_ghost_clipping=True
+    ):
+        """
+        Add logs to track gradient sample mode for each part of the module, including 1) Ghost Clipping, 2) Fast Gradient Clipping (hook mode), and 3) Fast Gradient Clipping (functorch mode).
 
+        Args:
+            module: nn.Module to be checked
+            force_functorch: If set to ``True``, will use functorch to compute
+                all per sample gradients. Otherwise, functorch will be used only
+                for layers without registered grad sampler methods.
+            use_ghost_clipping: If set to ``True``, Ghost Clipping
+                will be used for clipping gradients of supported layers. If ``False``, Fast
+                Gradient Clipping will be used for all layers.
+        """
+        for m_name, m in trainable_modules(module):
+            if type(m) in [DPRNN, DPLSTM, DPGRU]:
+                logger.info(
+                    f"Module name: {m_name}, module type: {type(m)}. No hook or functorch is added."
+                )
+
+            elif use_ghost_clipping and type(m) in self.NORM_SAMPLERS:
+                logger.info(
+                    f"Module name: {m_name}, module type: {type(m)}, under Ghost Clipping."
+                )
+
+            else:
+                if not force_functorch and type(m) in self.GRAD_SAMPLERS:
+                    logger.info(
+                        f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (hook mode)."
+                    )
+                else:
+                    logger.info(
+                        f"Module name: {m_name}, module type: {type(m)}, under Fast Gradient Clipping (functorch mode)."
+                    )
+
     @property
     def per_sample_gradient_norms(self) -> torch.Tensor:
         """Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings"""
|