From 63a3af4f8cf4de44d3bcab62360a5dcd0e33c219 Mon Sep 17 00:00:00 2001
From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Date: Tue, 1 Mar 2022 17:42:15 +0800
Subject: [PATCH] [Feature]: Add optimizer hook (#70)

* [Feature]: Add optimizer hook

* [Fix]: Update docstring

* [Fix]: Add call with in UT
---
 mmengine/hooks/__init__.py             |   6 +-
 mmengine/hooks/optimizer_hook.py       | 130 ++++++++++++++++++++++++++
 tests/test_hook/test_optimizer_hook.py | 115 +++++++++++++++++++++
 3 files changed, 249 insertions(+), 2 deletions(-)
 create mode 100644 mmengine/hooks/optimizer_hook.py
 create mode 100644 tests/test_hook/test_optimizer_hook.py

diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index ff87dc67..0f27b237 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
-from .sampler_seed_hook import DistSamplerSeedHook
+from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
+from .sampler_seed_hook import DistSamplerSeedHook
 
 __all__ = [
-    'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook'
+    'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
+    'OptimizerHook'
 ]
diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py
new file mode 100644
index 00000000..8689b9fa
--- /dev/null
+++ b/mmengine/hooks/optimizer_hook.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional, Sequence
+
+import torch
+from torch.nn.parameter import Parameter
+from torch.nn.utils import clip_grad
+
+from mmengine.data import BaseDataSample
+from mmengine.registry import HOOKS
+from .hook import Hook
+
+
+@HOOKS.register_module()
+class OptimizerHook(Hook):
+    """A hook that contains custom operations for the optimizer.
+
+    Args:
+        grad_clip (dict, optional): A config dict to control the clip_grad.
+            Defaults to None.
+        detect_anomalous_params (bool): This option is only used for
+            debugging, which will slow down the training speed.
+            Detect anomalous parameters that are not included in
+            the computational graph with ``loss`` as the root.
+            There are two cases:
+                - Parameters were not used during
+                  forward pass.
+                - Parameters were not used to produce
+                  loss.
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 grad_clip: Optional[dict] = None,
+                 detect_anomalous_params: bool = False) -> None:
+        self.grad_clip = grad_clip
+        self.detect_anomalous_params = detect_anomalous_params
+
+    def clip_grads(self, params: List[Parameter]) -> Optional[torch.Tensor]:
+        """Clip the gradients of parameters.
+
+        Args:
+            params (list[Parameter]): Model's parameters.
+
+        Returns:
+            Optional[torch.Tensor]: Total norm of the parameters if there is
+                at least one param requiring gradient, else None.
+        """
+        params = list(
+            filter(lambda p: p.requires_grad and p.grad is not None, params))
+        if len(params) > 0:
+            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
+        return None
+
+    def after_train_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """All operations need to be finished after each training iteration.
+
+        This function will finish the following operations:
+
+        - Detect any anomalous parameters which are not included in the
+          training graph. (optional)
+
+        - Compute the gradient of model parameters.
+
+        - Clip the gradients of each parameter. (optional)
+
+        - Update model parameters with gradients.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample], optional): Data from
+                dataloader. In order to keep this interface consistent with
+                other hooks, we keep ``data_batch`` here. Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+                In order to keep this interface consistent with other hooks,
+                we keep ``outputs`` here. Defaults to None.
+        """
+        runner.optimizer.zero_grad()  # type: ignore
+        if self.detect_anomalous_params:
+            self.detect_anomalous_parameters(
+                runner.outputs['loss'],  # type: ignore
+                runner)
+        runner.outputs['loss'].backward()  # type: ignore
+
+        if self.grad_clip is not None:
+            grad_norm = self.clip_grads(
+                runner.model.parameters())  # type: ignore
+            if grad_norm is not None:
+                # Add grad norm to the logger
+                runner.log_buffer.update(  # type: ignore
+                    {'grad_norm': float(grad_norm)},
+                    runner.outputs['num_samples'])  # type: ignore
+        runner.optimizer.step()  # type: ignore
+
+    def detect_anomalous_parameters(self, loss: torch.Tensor,
+                                    runner: object) -> None:
+        """Detect anomalous parameters that are not included in the graph.
+
+        Args:
+            loss (torch.Tensor): The loss of the current iteration.
+            runner (object): The runner of the training process.
+        """
+        logger = runner.logger  # type: ignore
+        parameters_in_graph = set()
+        visited = set()
+
+        def traverse(grad_fn):
+            if grad_fn is None:
+                return
+            if grad_fn not in visited:
+                visited.add(grad_fn)
+                if hasattr(grad_fn, 'variable'):
+                    parameters_in_graph.add(grad_fn.variable)
+                parents = grad_fn.next_functions
+                if parents is not None:
+                    for parent in parents:
+                        grad_fn = parent[0]
+                        traverse(grad_fn)
+
+        traverse(loss.grad_fn)
+        for n, p in runner.model.named_parameters():  # type: ignore
+            if p not in parameters_in_graph and p.requires_grad:
+                logger.log(
+                    level=logging.ERROR,
+                    msg=f'{n} with shape {p.size()} is not '
+                    f'in the computational graph \n')
diff --git a/tests/test_hook/test_optimizer_hook.py b/tests/test_hook/test_optimizer_hook.py
new file mode 100644
index 00000000..8ab12814
--- /dev/null
+++ b/tests/test_hook/test_optimizer_hook.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import Mock
+
+import torch
+from torch import nn
+
+from mmengine.hooks import OptimizerHook
+
+
+class TestOptimizerHook:
+
+    def test_after_train_iter(self):
+
+        class Model(nn.Module):
+
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(
+                    in_channels=1,
+                    out_channels=2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    dilation=1)
+                self.conv2 = nn.Conv2d(
+                    in_channels=2,
+                    out_channels=2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    dilation=1)
+                self.conv3 = nn.Conv2d(
+                    in_channels=1,
+                    out_channels=2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    dilation=1)
+
+            def forward(self, x):
+                x1 = self.conv1(x)
+                x2 = self.conv2(x1)
+                return x1, x2
+
+        model = Model()
+        x = torch.rand(1, 1, 3, 3)
+
+        dummy_runner = Mock()
+        dummy_runner.optimizer.zero_grad = Mock(return_value=None)
+        dummy_runner.optimizer.step = Mock(return_value=None)
+        dummy_runner.model = model
+        dummy_runner.outputs = dict()
+
+        dummy_runner.outputs['num_samples'] = 0
+
+        class DummyLogger():
+
+            def __init__(self):
+                self.msg = ''
+
+            def log(self, msg=None, **kwargs):
+                self.msg += msg
+
+        dummy_runner.logger = DummyLogger()
+        optimizer_hook = OptimizerHook(
+            dict(max_norm=2), detect_anomalous_params=True)
+
+        dummy_runner.outputs['loss'] = model(x)[0].sum()
+
+        dummy_runner.outputs['loss'].backward = Mock(
+            wraps=dummy_runner.outputs['loss'].backward)
+        optimizer_hook.detect_anomalous_parameters = Mock(
+            wraps=optimizer_hook.detect_anomalous_parameters)
+        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
+
+        optimizer_hook.after_train_iter(dummy_runner)
+        # assert the parameters of conv2 and conv3 are not in the
+        # computational graph whose root is x1.sum()
+        assert 'conv2.weight' in dummy_runner.logger.msg
+        assert 'conv2.bias' in dummy_runner.logger.msg
+        assert 'conv3.weight' in dummy_runner.logger.msg
+        assert 'conv3.bias' in dummy_runner.logger.msg
+        assert 'conv1.weight' not in dummy_runner.logger.msg
+        assert 'conv1.bias' not in dummy_runner.logger.msg
+        dummy_runner.optimizer.step.assert_called()
+        dummy_runner.outputs['loss'].backward.assert_called()
+        optimizer_hook.clip_grads.assert_called()
+        optimizer_hook.detect_anomalous_parameters.assert_called()
+
+        dummy_runner.outputs['loss'] = model(x)[1].sum()
+        dummy_runner.logger.msg = ''
+        optimizer_hook.after_train_iter(dummy_runner)
+        # assert the parameters of conv3 are not in the computational graph
+        assert 'conv3.weight' in dummy_runner.logger.msg
+        assert 'conv3.bias' in dummy_runner.logger.msg
+        assert 'conv2.weight' not in dummy_runner.logger.msg
+        assert 'conv2.bias' not in dummy_runner.logger.msg
+        assert 'conv1.weight' not in dummy_runner.logger.msg
+        assert 'conv1.bias' not in dummy_runner.logger.msg
+
+        # grad_clip is None and detect_anomalous_parameters is False
+        optimizer_hook = OptimizerHook(detect_anomalous_params=False)
+        optimizer_hook.detect_anomalous_parameters = Mock(
+            wraps=optimizer_hook.detect_anomalous_parameters)
+        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
+        dummy_runner.outputs['loss'] = model(x)[0].sum()
+        dummy_runner.outputs['loss'].backward = Mock(
+            wraps=dummy_runner.outputs['loss'].backward)
+
+        optimizer_hook.after_train_iter(dummy_runner)
+
+        dummy_runner.optimizer.step.assert_called()
+        dummy_runner.outputs['loss'].backward.assert_called()
+        optimizer_hook.clip_grads.assert_not_called()
+        optimizer_hook.detect_anomalous_parameters.assert_not_called()
-- 
GitLab
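
The snippet below is not part of the patch; it is a minimal usage sketch of the hook's per-iteration flow (zero_grad, backward, optional gradient clipping, step). The SimpleNamespace runner and the Mock log buffer are stand-ins for the real mmengine Runner, used only to keep the example self-contained.

# Hypothetical usage sketch (not part of the patch). A SimpleNamespace plus a
# Mock log buffer stand in for the real mmengine Runner so the hook's
# per-iteration flow can be exercised directly.
from types import SimpleNamespace
from unittest.mock import Mock

import torch
from torch import nn

from mmengine.hooks import OptimizerHook

model = nn.Linear(4, 2)
runner = SimpleNamespace(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    outputs={'loss': model(torch.rand(3, 4)).sum(), 'num_samples': 3},
    log_buffer=Mock(),  # stand-in; only its update() is called by the hook
    logger=Mock())      # only needed when detect_anomalous_params=True

hook = OptimizerHook(grad_clip=dict(max_norm=1.0))
# zero_grad -> backward -> clip_grads(max_norm=1.0) -> step
hook.after_train_iter(runner)
runner.log_buffer.update.assert_called()  # grad_norm was recorded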