diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index ff87dc6765c1bfe8038a373332c2a34c5fc12867..0f27b2378b60c49f9f5292a375e67c4bd901ef09 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
-from .sampler_seed_hook import DistSamplerSeedHook
+from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
+from .sampler_seed_hook import DistSamplerSeedHook
 
 __all__ = [
-    'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook'
+    'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
+    'OptimizerHook'
 ]
diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..8689b9fa3f43bb708538f33484bdf9ba5253b66a
--- /dev/null
+++ b/mmengine/hooks/optimizer_hook.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+from typing import List, Optional, Sequence
+
+import torch
+from torch.nn.parameter import Parameter
+from torch.nn.utils import clip_grad
+
+from mmengine.data import BaseDataSample
+from mmengine.registry import HOOKS
+from .hook import Hook
+
+
+@HOOKS.register_module()
+class OptimizerHook(Hook):
+    """A hook that contains custom operations for the optimizer.
+
+    Args:
+        grad_clip (dict, optional): A config dict whose items are passed as
+            keyword arguments to ``torch.nn.utils.clip_grad_norm_`` to clip
+            the gradients. If None, gradients are not clipped.
+            Defaults to None.
+        detect_anomalous_params (bool): Whether to detect anomalous
+            parameters, i.e. parameters that are not included in the
+            computational graph with ``loss`` as the root. This covers
+            two cases:
+
+            - Parameters were not used during the forward pass.
+            - Parameters were not used to produce the loss.
+
+            This option is intended for debugging only and slows down
+            training. Defaults to False.
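+
+    Examples:
+        >>> # An illustrative configuration; the clipping values below are
+        >>> # placeholders, not recommended settings.
+        >>> from mmengine.hooks import OptimizerHook
+        >>> hook = OptimizerHook(
+        ...     grad_clip=dict(max_norm=35, norm_type=2),
+        ...     detect_anomalous_params=False)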
+    """
+
+    def __init__(self,
+                 grad_clip: Optional[dict] = None,
+                 detect_anomalous_params: bool = False) -> None:
+        self.grad_clip = grad_clip
+        self.detect_anomalous_params = detect_anomalous_params
+
+    def clip_grads(self, params: List[Parameter]) -> Optional[torch.Tensor]:
+        """Clip the gradients of parameters.
+
+        Args:
+            params (list[Parameter]): Model's parameters.
+
+        Returns:
+            Optional[torch.Tensor]: Total norm of the gradients if at
+            least one parameter requires gradient and has a gradient,
+            else None.
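+
+        Example:
+            >>> # Illustrative only: clip one parameter's gradient in place
+            >>> # and return the total gradient norm.
+            >>> import torch
+            >>> hook = OptimizerHook(grad_clip=dict(max_norm=1.0))
+            >>> p = torch.nn.Parameter(torch.ones(2))
+            >>> p.grad = torch.ones(2)
+            >>> total_norm = hook.clip_grads([p])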
+        """
+        params = list(
+            filter(lambda p: p.requires_grad and p.grad is not None, params))
+        if len(params) > 0:
+            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
+        return None
+
+    def after_train_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """Perform optimizer-related operations after each training iteration.
+
+        This function performs the following operations:
+
+        - Clear the gradients from the previous iteration.
+
+        - Detect anomalous parameters that are not included in the
+          computational graph with ``loss`` as the root. (optional)
+
+        - Compute the gradients of the model parameters.
+
+        - Clip the gradients of the parameters. (optional)
+
+        - Update the model parameters with the gradients.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample], optional): Data from
+                dataloader. In order to keep this interface consistent with
+                other hooks, we keep ``data_batch`` here. Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+                In order to keep this interface consistent with other hooks,
+                we keep ``outputs`` here. Defaults to None.
+        """
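+        # Clear the gradients accumulated in the previous iteration.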
+        runner.optimizer.zero_grad()  # type: ignore
+        if self.detect_anomalous_params:
+            self.detect_anomalous_parameters(
+                runner.outputs['loss'],  # type: ignore
+                runner)
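+        # Back-propagate from the loss to populate the parameter gradients.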
+        runner.outputs['loss'].backward()  # type: ignore
+
+        if self.grad_clip is not None:
+            grad_norm = self.clip_grads(
+                runner.model.parameters())  # type: ignore
+            if grad_norm is not None:
+                # Add grad norm to the logger
+                runner.log_buffer.update(  # type: ignore
+                    {'grad_norm': float(grad_norm)},
+                    runner.outputs['num_samples'])  # type: ignore
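+        # Update the model parameters with the (possibly clipped) gradients.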
+        runner.optimizer.step()  # type: ignore
+
+    def detect_anomalous_parameters(self, loss: torch.Tensor,
+                                    runner: object) -> None:
+        """Detect anomalous parameters that are not included in the graph.
+
+        Args:
+            loss (torch.Tensor): The loss of current iteration.
+            runner (object): The runner of the training process.
+        """
+        logger = runner.logger  # type: ignore
+        parameters_in_graph = set()
+        visited = set()
+
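+        # Walk the autograd graph rooted at ``loss.grad_fn``. Nodes that
+        # expose a ``variable`` attribute are ``AccumulateGrad`` nodes, and
+        # that variable is a leaf tensor (i.e. a model parameter) reachable
+        # from the loss.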
+        def traverse(grad_fn):
+            if grad_fn is None:
+                return
+            if grad_fn not in visited:
+                visited.add(grad_fn)
+                if hasattr(grad_fn, 'variable'):
+                    parameters_in_graph.add(grad_fn.variable)
+                parents = grad_fn.next_functions
+                if parents is not None:
+                    for parent in parents:
+                        grad_fn = parent[0]
+                        traverse(grad_fn)
+
+        traverse(loss.grad_fn)
+        for n, p in runner.model.named_parameters():  # type: ignore
+            if p not in parameters_in_graph and p.requires_grad:
+                logger.log(
+                    level=logging.ERROR,
+                    msg=f'{n} with shape {p.size()} is not '
+                    f'in the computational graph \n')
diff --git a/tests/test_hook/test_optimizer_hook.py b/tests/test_hook/test_optimizer_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ab1281492afab64590b99e0f72902cad5f8d7b8
--- /dev/null
+++ b/tests/test_hook/test_optimizer_hook.py
@@ -0,0 +1,115 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import Mock
+
+import torch
+from torch import nn
+
+from mmengine.hooks import OptimizerHook
+
+
+class TestOptimizerHook:
+
+    def test_after_train_iter(self):
+
+        class Model(nn.Module):
+
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(
+                    in_channels=1,
+                    out_channels=2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    dilation=1)
+                self.conv2 = nn.Conv2d(
+                    in_channels=2,
+                    out_channels=2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    dilation=1)
+                self.conv3 = nn.Conv2d(
+                    in_channels=1,
+                    out_channels=2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    dilation=1)
+
+            def forward(self, x):
+                x1 = self.conv1(x)
+                x2 = self.conv2(x1)
+                return x1, x2
+
+        model = Model()
+        x = torch.rand(1, 1, 3, 3)
+
+        dummy_runner = Mock()
+        dummy_runner.optimizer.zero_grad = Mock(return_value=None)
+        dummy_runner.optimizer.step = Mock(return_value=None)
+        dummy_runner.model = model
+        dummy_runner.outputs = dict()
+
+        dummy_runner.outputs['num_samples'] = 0
+
+        class DummyLogger():
+
+            def __init__(self):
+                self.msg = ''
+
+            def log(self, msg=None, **kwargs):
+                self.msg += msg
+
+        dummy_runner.logger = DummyLogger()
+        optimizer_hook = OptimizerHook(
+            dict(max_norm=2), detect_anomalous_params=True)
+
+        dummy_runner.outputs['loss'] = model(x)[0].sum()
+
+        dummy_runner.outputs['loss'].backward = Mock(
+            wraps=dummy_runner.outputs['loss'].backward)
+        optimizer_hook.detect_anomalous_parameters = Mock(
+            wraps=optimizer_hook.detect_anomalous_parameters)
+        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
+
+        optimizer_hook.after_train_iter(dummy_runner)
+        # assert that the parameters of conv2 and conv3 are not in the
+        # computational graph whose root is x1.sum().
+        assert 'conv2.weight' in dummy_runner.logger.msg
+        assert 'conv2.bias' in dummy_runner.logger.msg
+        assert 'conv3.weight' in dummy_runner.logger.msg
+        assert 'conv3.bias' in dummy_runner.logger.msg
+        assert 'conv1.weight' not in dummy_runner.logger.msg
+        assert 'conv1.bias' not in dummy_runner.logger.msg
+        dummy_runner.optimizer.step.assert_called()
+        dummy_runner.outputs['loss'].backward.assert_called()
+        optimizer_hook.clip_grads.assert_called()
+        optimizer_hook.detect_anomalous_parameters.assert_called()
+
+        dummy_runner.outputs['loss'] = model(x)[1].sum()
+        dummy_runner.logger.msg = ''
+        optimizer_hook.after_train_iter(dummy_runner)
+        # assert the parameters of conv3 are not in the computational graph
+        assert 'conv3.weight' in dummy_runner.logger.msg
+        assert 'conv3.bias' in dummy_runner.logger.msg
+        assert 'conv2.weight' not in dummy_runner.logger.msg
+        assert 'conv2.bias' not in dummy_runner.logger.msg
+        assert 'conv1.weight' not in dummy_runner.logger.msg
+        assert 'conv1.bias' not in dummy_runner.logger.msg
+
+        # grad_clip is None and detect_anomalous_params is False
+        optimizer_hook = OptimizerHook(detect_anomalous_params=False)
+        optimizer_hook.detect_anomalous_parameters = Mock(
+            wraps=optimizer_hook.detect_anomalous_parameters)
+        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
+        dummy_runner.outputs['loss'] = model(x)[0].sum()
+        dummy_runner.outputs['loss'].backward = Mock(
+            wraps=dummy_runner.outputs['loss'].backward)
+
+        optimizer_hook.after_train_iter(dummy_runner)
+
+        dummy_runner.optimizer.step.assert_called()
+        dummy_runner.outputs['loss'].backward.assert_called()
+        optimizer_hook.clip_grads.assert_not_called()
+        optimizer_hook.detect_anomalous_parameters.assert_not_called()