From bc759e55502e1613c9378a71ae1ac809bf05aa16 Mon Sep 17 00:00:00 2001
From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Date: Sat, 26 Feb 2022 14:51:09 +0800
Subject: [PATCH] [Feature]: Add base Hook (#47)

* [Feature]: Add Part3 of Hooks

* [Feature]: Add Hook

* [Fix]: Add docstring and type hint for base hook

* [Fix]: Add test case to not the last iter, inner_iter, epoch

* [Fix]: Add missing type hint

* [Feature]: Add Args and Returns in docstring

* [Fix]: Add missing colon

* [Fix]: Add optional to docstring

* [Fix]: Fix docstring problem

* [Fix]:typo

* [Fix]: Fix lint

Co-authored-by: Your <you@example.com>
---
 mmengine/__init__.py         |   1 +
 mmengine/hooks/__init__.py   |   4 +
 mmengine/hooks/hook.py       | 320 +++++++++++++++++++++++++++++++++++
 tests/test_hook/test_hook.py | 201 ++++++++++++++++++++++
 4 files changed, 526 insertions(+)
 create mode 100644 mmengine/hooks/__init__.py
 create mode 100644 mmengine/hooks/hook.py
 create mode 100644 tests/test_hook/test_hook.py

diff --git a/mmengine/__init__.py b/mmengine/__init__.py
index 55e7b929..7bc5108b 100644
--- a/mmengine/__init__.py
+++ b/mmengine/__init__.py
@@ -4,5 +4,6 @@ from .config import *
 from .data import *
 from .dataset import *
 from .fileio import *
+from .hooks import *
 from .registry import *
 from .utils import *
diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
new file mode 100644
index 00000000..99cd738e
--- /dev/null
+++ b/mmengine/hooks/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import Hook
+
+__all__ = ['Hook']
diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py
new file mode 100644
index 00000000..8321af83
--- /dev/null
+++ b/mmengine/hooks/hook.py
@@ -0,0 +1,320 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence
+
+from mmengine.data import BaseDataSample
+
+
+class Hook:
+    """Base hook class.
+
+    All hooks should inherit from this class.
+    """
+
+    def before_run(self, runner: object) -> None:
+        """All subclasses should override this method, if they need any
+        operations before the training process.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        pass
+
+    def after_run(self, runner: object) -> None:
+        """All subclasses should override this method, if they need any
+        operations after the training process.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        pass
+
+    def before_epoch(self, runner: object) -> None:
+        """All subclasses should override this method, if they need any
+        operations before each epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        pass
+
+    def after_epoch(self, runner: object) -> None:
+        """All subclasses should override this method, if they need any
+        operations after each epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        pass
+
+    def before_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """All subclasses should override this method, if they need any
+        operations before each iter.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample]): Data from dataloader.
+                Defaults to None.
+        """
+        pass
+
+    def after_iter(self,
+                   runner: object,
+                   data_batch: Optional[Sequence[BaseDataSample]] = None,
+                   outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """All subclasses should override this method, if they need any
+        operations after each epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample]): Data from dataloader.
+                Defaults to None.
+            outputs (Sequence[BaseDataSample]): Outputs from model.
+                Defaults to None.
+        """
+        pass
+
+    def before_save_checkpoint(self, runner: object, checkpoint: dict) -> None:
+        """All subclasses should override this method, if they need any
+        operations before saving the checkpoint.
+
+        Args:
+            runner (object): The runner of the training process.
+            checkpoints (dict): Model's checkpoint.
+        """
+        pass
+
+    def after_load_checkpoint(self, runner: object, checkpoint: dict) -> None:
+        """All subclasses should override this method, if they need any
+        operations after loading the checkpoint.
+
+        Args:
+            runner (object): The runner of the training process.
+            checkpoints (dict): Model's checkpoint.
+        """
+        pass
+
+    def before_train_epoch(self, runner: object) -> None:
+        """All subclasses should override this method, if they need any
+        operations before each training epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        self.before_epoch(runner)
+
+    def before_val_epoch(self, runner: object) -> None:
+        """All subclasses should override this method, if they need any
+        operations before each validation epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        self.before_epoch(runner)
+
+    def before_test_epoch(self, runner: object) -> None:
+        """All subclasses should override this method, if they need any
+        operations before each test epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        self.before_epoch(runner)
+
+    def after_train_epoch(self, runner: object) -> None:
+        """All subclasses should override this method, if they need any
+        operations after each training epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        self.after_epoch(runner)
+
+    def after_val_epoch(self, runner: object) -> None:
+        """All subclasses should override this method, if they need any
+        operations after each validation epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        self.after_epoch(runner)
+
+    def after_test_epoch(self, runner: object) -> None:
+        """All subclasses should override this method, if they need any
+        operations after each test epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        self.after_epoch(runner)
+
+    def before_train_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """All subclasses should override this method, if they need any
+        operations before each training iteration.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample], optional): Data from
+                dataloader. Defaults to None.
+        """
+        self.before_iter(runner, data_batch=None)
+
+    def before_val_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """All subclasses should override this method, if they need any
+        operations before each validation iteration.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample], optional): Data from
+                dataloader. Defaults to None.
+        """
+        self.before_iter(runner, data_batch=None)
+
+    def before_test_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """All subclasses should override this method, if they need any
+        operations before each test iteration.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample], optional): Data from
+                dataloader. Defaults to None.
+        """
+        self.before_iter(runner, data_batch=None)
+
+    def after_train_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """All subclasses should override this method, if they need any
+        operations after each training iteration.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample], optional): Data from
+                dataloader. Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+                Defaults to None.
+        """
+        self.after_iter(runner, data_batch=None, outputs=None)
+
+    def after_val_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """All subclasses should override this method, if they need any
+        operations after each validation iteration.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample], optional): Data from
+                dataloader. Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from
+                model. Defaults to None.
+        """
+        self.after_iter(runner, data_batch=None, outputs=None)
+
+    def after_test_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """All subclasses should override this method, if they need any
+        operations after each test iteration.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample], optional): Data from
+                dataloader. Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+                Defaults to None.
+        """
+        self.after_iter(runner, data_batch=None, outputs=None)
+
+    def every_n_epochs(self, runner: object, n: int) -> bool:
+        """Test whether or not current epoch can be evenly divided by n.
+
+        Args:
+            runner (object): The runner of the training process.
+            n (int): Whether or not current epoch can be evenly divided by n.
+
+        Returns:
+            bool: whether or not current epoch can be evenly divided by n.
+        """
+        return (runner.epoch + 1) % n == 0 if n > 0 else False  # type: ignore
+
+    def every_n_inner_iters(self, runner: object, n: int) -> bool:
+        """Test whether or not current inner iteration can be evenly divided by
+        n.
+
+        Args:
+            runner (object): The runner of the training process.
+            n (int): Whether or not current inner iteration can be evenly
+                divided by n.
+
+        Returns:
+            bool: whether or not current inner iteration can be evenly
+            divided by n.
+        """
+        return (runner.inner_iter +  # type: ignore
+                1) % n == 0 if n > 0 else False
+
+    def every_n_iters(self, runner: object, n: int) -> bool:
+        """Test whether or not current iteration can be evenly divided by n.
+
+        Args:
+            runner (object): The runner of the training process.
+            n (int): Whether or not current iteration can be
+                evenly divided by n.
+
+        Returns:
+            bool: Return True if the current iteration can be evenly divided
+            by n, otherwise False.
+        """
+        return (runner.iter + 1) % n == 0 if n > 0 else False  # type: ignore
+
+    def end_of_epoch(self, runner: object) -> bool:
+        """Check whether the current epoch reaches the `max_epochs` or not.
+
+        Args:
+            runner (object): The runner of the training process.
+
+        Returns:
+            bool: whether the end of current epoch or not.
+        """
+        return runner.inner_iter + 1 == len(runner.data_loader)  # type: ignore
+
+    def is_last_epoch(self, runner: object) -> bool:
+        """Test whether or not current epoch is the last epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+
+        Returns:
+            bool: bool: Return True if the current epoch reaches the
+            `max_epochs`, otherwise False.
+        """
+        return runner.epoch + 1 == runner._max_epochs  # type: ignore
+
+    def is_last_iter(self, runner: object) -> bool:
+        """Test whether or not current epoch is the last iteration.
+
+        Args:
+            runner (object): The runner of the training process.
+
+        Returns:
+            bool: whether or not current iteration is the last iteration.
+        """
+        return runner.iter + 1 == runner._max_iters  # type: ignore
diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py
new file mode 100644
index 00000000..5884a161
--- /dev/null
+++ b/tests/test_hook/test_hook.py
@@ -0,0 +1,201 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import Mock
+
+from mmengine.hooks import Hook
+
+
+class TestHook:
+
+    def test_before_run(self):
+        hook = Hook()
+        runner = Mock()
+        hook.before_run(runner)
+
+    def test_after_run(self):
+        hook = Hook()
+        runner = Mock()
+        hook.after_run(runner)
+
+    def test_before_epoch(self):
+        hook = Hook()
+        runner = Mock()
+        hook.before_epoch(runner)
+
+    def test_after_epoch(self):
+        hook = Hook()
+        runner = Mock()
+        hook.after_epoch(runner)
+
+    def test_before_iter(self):
+        hook = Hook()
+        runner = Mock()
+        data_batch = {}
+        hook.before_iter(runner, data_batch)
+
+    def test_after_iter(self):
+        hook = Hook()
+        runner = Mock()
+        data_batch = {}
+        outputs = {}
+        hook.after_iter(runner, data_batch, outputs)
+
+    def test_before_save_checkpoint(self):
+        hook = Hook()
+        runner = Mock()
+        checkpoint = {}
+        hook.before_save_checkpoint(runner, checkpoint)
+
+    def test_after_load_checkpoint(self):
+        hook = Hook()
+        runner = Mock()
+        checkpoint = {}
+        hook.after_load_checkpoint(runner, checkpoint)
+
+    def test_before_train_epoch(self):
+        hook = Hook()
+        runner = Mock()
+        hook.before_train_epoch(runner)
+
+    def test_before_val_epoch(self):
+        hook = Hook()
+        runner = Mock()
+        hook.before_val_epoch(runner)
+
+    def test_before_test_epoch(self):
+        hook = Hook()
+        runner = Mock()
+        hook.before_test_epoch(runner)
+
+    def test_after_train_epoch(self):
+        hook = Hook()
+        runner = Mock()
+        hook.after_train_epoch(runner)
+
+    def test_after_val_epoch(self):
+        hook = Hook()
+        runner = Mock()
+        hook.after_val_epoch(runner)
+
+    def test_after_test_epoch(self):
+        hook = Hook()
+        runner = Mock()
+        hook.after_test_epoch(runner)
+
+    def test_before_train_iter(self):
+        hook = Hook()
+        runner = Mock()
+        data_batch = {}
+        hook.before_train_iter(runner, data_batch)
+
+    def test_before_val_iter(self):
+        hook = Hook()
+        runner = Mock()
+        data_batch = {}
+        hook.before_val_iter(runner, data_batch)
+
+    def test_before_test_iter(self):
+        hook = Hook()
+        runner = Mock()
+        data_batch = {}
+        hook.before_test_iter(runner, data_batch)
+
+    def test_after_train_iter(self):
+        hook = Hook()
+        runner = Mock()
+        data_batch = {}
+        outputs = {}
+        hook.after_train_iter(runner, data_batch, outputs)
+
+    def test_after_val_iter(self):
+        hook = Hook()
+        runner = Mock()
+        data_batch = {}
+        outputs = {}
+        hook.after_val_iter(runner, data_batch, outputs)
+
+    def test_after_test_iter(self):
+        hook = Hook()
+        runner = Mock()
+        data_batch = {}
+        outputs = {}
+        hook.after_test_iter(runner, data_batch, outputs)
+
+    def test_every_n_epochs(self):
+        hook = Hook()
+        runner = Mock()
+
+        for i in range(100):
+            runner.epoch = i
+            return_val = hook.every_n_epochs(runner, 3)
+            if (i + 1) % 3 == 0:
+                assert return_val
+            else:
+                assert not return_val
+
+    def test_every_n_inner_iters(self):
+        hook = Hook()
+        runner = Mock()
+
+        for i in range(100):
+            runner.inner_iter = i
+            return_val = hook.every_n_inner_iters(runner, 3)
+            if (i + 1) % 3 == 0:
+                assert return_val
+            else:
+                assert not return_val
+
+    def test_every_n_iters(self):
+        hook = Hook()
+        runner = Mock()
+        for i in range(100):
+            runner.iter = i
+            return_val = hook.every_n_iters(runner, 3)
+            if (i + 1) % 3 == 0:
+                assert return_val
+            else:
+                assert not return_val
+
+    def test_end_of_epoch(self):
+        hook = Hook()
+        runner = Mock()
+
+        # last inner iter
+        runner.inner_iter = 1
+        runner.data_loader.__len__ = Mock(return_value=2)
+        return_val = hook.end_of_epoch(runner)
+        assert return_val
+
+        # not the last inner iter
+        runner.inner_iter = 0
+        return_val = hook.end_of_epoch(runner)
+        assert not return_val
+
+    def test_is_last_epoch(self):
+        hook = Hook()
+        runner = Mock()
+
+        # last epoch
+        runner.epoch = 1
+        runner._max_epochs = 2
+        return_val = hook.is_last_epoch(runner)
+        assert return_val
+
+        # not the last epoch
+        runner.epoch = 0
+        return_val = hook.is_last_epoch(runner)
+        assert not return_val
+
+    def test_is_last_iter(self):
+        hook = Hook()
+        runner = Mock()
+
+        # last iter
+        runner.iter = 1
+        runner._max_iters = 2
+        return_val = hook.is_last_iter(runner)
+        assert return_val
+
+        # not the last iter
+        runner.iter = 0
+        return_val = hook.is_last_iter(runner)
+        assert not return_val
-- 
GitLab