From bc759e55502e1613c9378a71ae1ac809bf05aa16 Mon Sep 17 00:00:00 2001 From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Date: Sat, 26 Feb 2022 14:51:09 +0800 Subject: [PATCH] [Feature]: Add base Hook (#47) * [Feature]: Add Part3 of Hooks * [Feature]: Add Hook * [Fix]: Add docstring and type hint for base hook * [Fix]: Add test cases for not-the-last iter, inner_iter and epoch * [Fix]: Add missing type hint * [Feature]: Add Args and Returns in docstring * [Fix]: Add missing colon * [Fix]: Add optional to docstring * [Fix]: Fix docstring problem * [Fix]: typo * [Fix]: Fix lint Co-authored-by: Your <you@example.com> --- mmengine/__init__.py | 1 + mmengine/hooks/__init__.py | 4 + mmengine/hooks/hook.py | 320 +++++++++++++++++++++++++++++++++++ tests/test_hook/test_hook.py | 201 ++++++++++++++++++++++ 4 files changed, 526 insertions(+) create mode 100644 mmengine/hooks/__init__.py create mode 100644 mmengine/hooks/hook.py create mode 100644 tests/test_hook/test_hook.py diff --git a/mmengine/__init__.py b/mmengine/__init__.py index 55e7b929..7bc5108b 100644 --- a/mmengine/__init__.py +++ b/mmengine/__init__.py @@ -4,5 +4,6 @@ from .config import * from .data import * from .dataset import * from .fileio import * +from .hooks import * from .registry import * from .utils import * diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py new file mode 100644 index 00000000..99cd738e --- /dev/null +++ b/mmengine/hooks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .hook import Hook + +__all__ = ['Hook'] diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py new file mode 100644 index 00000000..8321af83 --- /dev/null +++ b/mmengine/hooks/hook.py @@ -0,0 +1,320 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmengine.data import BaseDataSample + + +class Hook: + """Base hook class. + + All hooks should inherit from this class. + """ + + def before_run(self, runner: object) -> None: + """All subclasses should override this method, if they need any + operations before the training process. + + Args: + runner (object): The runner of the training process. + """ + pass + + def after_run(self, runner: object) -> None: + """All subclasses should override this method, if they need any + operations after the training process. + + Args: + runner (object): The runner of the training process. + """ + pass + + def before_epoch(self, runner: object) -> None: + """All subclasses should override this method, if they need any + operations before each epoch. + + Args: + runner (object): The runner of the training process. + """ + pass + + def after_epoch(self, runner: object) -> None: + """All subclasses should override this method, if they need any + operations after each epoch. + + Args: + runner (object): The runner of the training process. + """ + pass + + def before_iter( + self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + """All subclasses should override this method, if they need any + operations before each iter. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample]): Data from dataloader. + Defaults to None. + """ + pass + + def after_iter(self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """All subclasses should override this method, if they need any + operations after each iter.
+ + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample]): Data from dataloader. + Defaults to None. + outputs (Sequence[BaseDataSample]): Outputs from model. + Defaults to None. + """ + pass + + def before_save_checkpoint(self, runner: object, checkpoint: dict) -> None: + """All subclasses should override this method, if they need any + operations before saving the checkpoint. + + Args: + runner (object): The runner of the training process. + checkpoint (dict): Model's checkpoint. + """ + pass + + def after_load_checkpoint(self, runner: object, checkpoint: dict) -> None: + """All subclasses should override this method, if they need any + operations after loading the checkpoint. + + Args: + runner (object): The runner of the training process. + checkpoint (dict): Model's checkpoint. + """ + pass + + def before_train_epoch(self, runner: object) -> None: + """All subclasses should override this method, if they need any + operations before each training epoch. + + Args: + runner (object): The runner of the training process. + """ + self.before_epoch(runner) + + def before_val_epoch(self, runner: object) -> None: + """All subclasses should override this method, if they need any + operations before each validation epoch. + + Args: + runner (object): The runner of the training process. + """ + self.before_epoch(runner) + + def before_test_epoch(self, runner: object) -> None: + """All subclasses should override this method, if they need any + operations before each test epoch. + + Args: + runner (object): The runner of the training process. + """ + self.before_epoch(runner) + + def after_train_epoch(self, runner: object) -> None: + """All subclasses should override this method, if they need any + operations after each training epoch. + + Args: + runner (object): The runner of the training process. + """ + self.after_epoch(runner) + + def after_val_epoch(self, runner: object) -> None: + """All subclasses should override this method, if they need any + operations after each validation epoch. + + Args: + runner (object): The runner of the training process. + """ + self.after_epoch(runner) + + def after_test_epoch(self, runner: object) -> None: + """All subclasses should override this method, if they need any + operations after each test epoch. + + Args: + runner (object): The runner of the training process. + """ + self.after_epoch(runner) + + def before_train_iter( + self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + """All subclasses should override this method, if they need any + operations before each training iteration. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample], optional): Data from + dataloader. Defaults to None. + """ + self.before_iter(runner, data_batch=data_batch) + + def before_val_iter( + self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + """All subclasses should override this method, if they need any + operations before each validation iteration. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample], optional): Data from + dataloader. Defaults to None. + """ + self.before_iter(runner, data_batch=data_batch) + + def before_test_iter( + self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + """All subclasses should override this method, if they need any + operations before each test iteration.
+ + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample], optional): Data from + dataloader. Defaults to None. + """ + self.before_iter(runner, data_batch=data_batch) + + def after_train_iter( + self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """All subclasses should override this method, if they need any + operations after each training iteration. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample], optional): Data from + dataloader. Defaults to None. + outputs (Sequence[BaseDataSample], optional): Outputs from model. + Defaults to None. + """ + self.after_iter(runner, data_batch=data_batch, outputs=outputs) + + def after_val_iter( + self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """All subclasses should override this method, if they need any + operations after each validation iteration. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample], optional): Data from + dataloader. Defaults to None. + outputs (Sequence[BaseDataSample], optional): Outputs from + model. Defaults to None. + """ + self.after_iter(runner, data_batch=data_batch, outputs=outputs) + + def after_test_iter( + self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """All subclasses should override this method, if they need any + operations after each test iteration. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample], optional): Data from + dataloader. Defaults to None. + outputs (Sequence[BaseDataSample], optional): Outputs from model. + Defaults to None. + """ + self.after_iter(runner, data_batch=data_batch, outputs=outputs) + + def every_n_epochs(self, runner: object, n: int) -> bool: + """Test whether or not current epoch can be evenly divided by n. + + Args: + runner (object): The runner of the training process. + n (int): The epoch interval to test against. + + Returns: + bool: whether or not current epoch can be evenly divided by n. + """ + return (runner.epoch + 1) % n == 0 if n > 0 else False # type: ignore + + def every_n_inner_iters(self, runner: object, n: int) -> bool: + """Test whether or not current inner iteration can be evenly divided by + n. + + Args: + runner (object): The inner-iteration interval (within the current + epoch) to test against. + + Returns: + bool: whether or not current inner iteration can be evenly + divided by n. + """ + return (runner.inner_iter + # type: ignore + 1) % n == 0 if n > 0 else False + + def every_n_iters(self, runner: object, n: int) -> bool: + """Test whether or not current iteration can be evenly divided by n. + + Args: + runner (object): The runner of the training process. + n (int): The global iteration interval to test + against. + + Returns: + bool: Return True if the current iteration can be evenly divided + by n, otherwise False. + """ + return (runner.iter + 1) % n == 0 if n > 0 else False # type: ignore + + def end_of_epoch(self, runner: object) -> bool: + """Check whether the current inner iteration is the last one of the + current epoch. + + Args: + runner (object): The runner of the training process.
+ + Returns: + bool: Return True if the current inner iteration is the last one + of the epoch, otherwise False. + """ + return runner.inner_iter + 1 == len(runner.data_loader) # type: ignore + + def is_last_epoch(self, runner: object) -> bool: + """Test whether or not current epoch is the last epoch. + + Args: + runner (object): The runner of the training process. + + Returns: + bool: Return True if the current epoch reaches the + `max_epochs`, otherwise False. + """ + return runner.epoch + 1 == runner._max_epochs # type: ignore + + def is_last_iter(self, runner: object) -> bool: + """Test whether or not the current iteration is the last one. + + Args: + runner (object): The runner of the training process. + + Returns: + bool: whether or not current iteration is the last iteration. + """ + return runner.iter + 1 == runner._max_iters # type: ignore diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py new file mode 100644 index 00000000..5884a161 --- /dev/null +++ b/tests/test_hook/test_hook.py @@ -0,0 +1,201 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import Mock + +from mmengine.hooks import Hook + + +class TestHook: + + def test_before_run(self): + hook = Hook() + runner = Mock() + hook.before_run(runner) + + def test_after_run(self): + hook = Hook() + runner = Mock() + hook.after_run(runner) + + def test_before_epoch(self): + hook = Hook() + runner = Mock() + hook.before_epoch(runner) + + def test_after_epoch(self): + hook = Hook() + runner = Mock() + hook.after_epoch(runner) + + def test_before_iter(self): + hook = Hook() + runner = Mock() + data_batch = {} + hook.before_iter(runner, data_batch) + + def test_after_iter(self): + hook = Hook() + runner = Mock() + data_batch = {} + outputs = {} + hook.after_iter(runner, data_batch, outputs) + + def test_before_save_checkpoint(self): + hook = Hook() + runner = Mock() + checkpoint = {} + hook.before_save_checkpoint(runner, checkpoint) + + def test_after_load_checkpoint(self): + hook = Hook() + runner = Mock() + checkpoint = {} + hook.after_load_checkpoint(runner, checkpoint) + + def test_before_train_epoch(self): + hook = Hook() + runner = Mock() + hook.before_train_epoch(runner) + + def test_before_val_epoch(self): + hook = Hook() + runner = Mock() + hook.before_val_epoch(runner) + + def test_before_test_epoch(self): + hook = Hook() + runner = Mock() + hook.before_test_epoch(runner) + + def test_after_train_epoch(self): + hook = Hook() + runner = Mock() + hook.after_train_epoch(runner) + + def test_after_val_epoch(self): + hook = Hook() + runner = Mock() + hook.after_val_epoch(runner) + + def test_after_test_epoch(self): + hook = Hook() + runner = Mock() + hook.after_test_epoch(runner) + + def test_before_train_iter(self): + hook = Hook() + runner = Mock() + data_batch = {} + hook.before_train_iter(runner, data_batch) + + def test_before_val_iter(self): + hook = Hook() + runner = Mock() + data_batch = {} + hook.before_val_iter(runner, data_batch) + + def test_before_test_iter(self): + hook = Hook() + runner = Mock() + data_batch = {} + hook.before_test_iter(runner, data_batch) + + def test_after_train_iter(self): + hook = Hook() + runner = Mock() + data_batch = {} + outputs = {} + hook.after_train_iter(runner, data_batch, outputs) + + def test_after_val_iter(self): + hook = Hook() + runner = Mock() + data_batch = {} + outputs = {} + hook.after_val_iter(runner, data_batch, outputs) + + def test_after_test_iter(self): + hook = Hook() + runner = Mock() + data_batch = {} + outputs = {} + hook.after_test_iter(runner, data_batch, outputs) + +
def test_every_n_epochs(self): + hook = Hook() + runner = Mock() + + for i in range(100): + runner.epoch = i + return_val = hook.every_n_epochs(runner, 3) + if (i + 1) % 3 == 0: + assert return_val + else: + assert not return_val + + def test_every_n_inner_iters(self): + hook = Hook() + runner = Mock() + + for i in range(100): + runner.inner_iter = i + return_val = hook.every_n_inner_iters(runner, 3) + if (i + 1) % 3 == 0: + assert return_val + else: + assert not return_val + + def test_every_n_iters(self): + hook = Hook() + runner = Mock() + for i in range(100): + runner.iter = i + return_val = hook.every_n_iters(runner, 3) + if (i + 1) % 3 == 0: + assert return_val + else: + assert not return_val + + def test_end_of_epoch(self): + hook = Hook() + runner = Mock() + + # last inner iter + runner.inner_iter = 1 + runner.data_loader.__len__ = Mock(return_value=2) + return_val = hook.end_of_epoch(runner) + assert return_val + + # not the last inner iter + runner.inner_iter = 0 + return_val = hook.end_of_epoch(runner) + assert not return_val + + def test_is_last_epoch(self): + hook = Hook() + runner = Mock() + + # last epoch + runner.epoch = 1 + runner._max_epochs = 2 + return_val = hook.is_last_epoch(runner) + assert return_val + + # not the last epoch + runner.epoch = 0 + return_val = hook.is_last_epoch(runner) + assert not return_val + + def test_is_last_iter(self): + hook = Hook() + runner = Mock() + + # last iter + runner.iter = 1 + runner._max_iters = 2 + return_val = hook.is_last_iter(runner) + assert return_val + + # not the last iter + runner.iter = 0 + return_val = hook.is_last_iter(runner) + assert not return_val -- GitLab
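
A minimal usage sketch of the base class added by this patch (not part of the patch itself): the subclass name LoggingHook, its log_interval argument, and the print-based logging are illustrative assumptions; only the Hook base class, its overridable methods, and the every_n_iters/is_last_epoch helpers come from mmengine/hooks/hook.py above.

# Illustrative subclass of the Hook base class added by this patch.
# LoggingHook and log_interval are hypothetical names, not part of mmengine.
from mmengine.hooks import Hook


class LoggingHook(Hook):

    def __init__(self, log_interval: int = 10):
        self.log_interval = log_interval

    def before_run(self, runner: object) -> None:
        # Called once before the whole training process starts.
        print('training starts')

    def after_train_iter(self, runner, data_batch=None, outputs=None) -> None:
        # every_n_iters() is the base-class helper: it returns True when
        # (runner.iter + 1) is evenly divisible by the given interval.
        if self.every_n_iters(runner, self.log_interval):
            print(f'finished iteration {runner.iter + 1}')

    def after_train_epoch(self, runner: object) -> None:
        # is_last_epoch() compares runner.epoch + 1 against runner._max_epochs.
        if self.is_last_epoch(runner):
            print('finished the last training epoch')

A runner is expected to invoke these methods at the matching points of its training, validation, and test loops; any method a subclass does not override falls back to the no-op defaults defined in the base class.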