From 1244e486ae33783ac1ebde07ee9df0c583e7364d Mon Sep 17 00:00:00 2001 From: Yifei Yang <2744335995@qq.com> Date: Tue, 1 Mar 2022 12:02:34 +0800 Subject: [PATCH] [Feature] Add Iter Timer Hook (#48) * [Feature]: Add Part3 of Hooks * [Feature]: Add Hook * add iter timer hook * update test * [Fix]: Add docstring and type hint for base hook * fix mypy * improve doc coverage and merge main Co-authored-by: seuyou <3463423099@qq.com> --- mmengine/hooks/__init__.py | 3 +- mmengine/hooks/iter_timer_hook.py | 58 +++++++++++++++++++++++++ tests/test_hook/test_iter_timer_hook.py | 29 +++++++++++++ 3 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 mmengine/hooks/iter_timer_hook.py create mode 100644 tests/test_hook/test_iter_timer_hook.py diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py index 99cd738e..39364154 100644 --- a/mmengine/hooks/__init__.py +++ b/mmengine/hooks/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .hook import Hook +from .iter_timer_hook import IterTimerHook -__all__ = ['Hook'] +__all__ = ['Hook', 'IterTimerHook'] diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py new file mode 100644 index 00000000..ecc84465 --- /dev/null +++ b/mmengine/hooks/iter_timer_hook.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time +from typing import Optional, Sequence + +from mmengine.data import BaseDataSample +from mmengine.registry import HOOKS +from .hook import Hook + + +@HOOKS.register_module() +class IterTimerHook(Hook): + """A hook that logs the time spent during iteration. + + Eg. ``data_time`` for loading data and ``time`` for a model train step. + """ + + def before_epoch(self, runner: object) -> None: + """Record time flag before start a epoch. + + Args: + runner (object): The runner of the training process. + """ + self.t = time.time() + + def before_iter( + self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + """Logging time for loading data and update the time flag. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample]): Data from dataloader. + Defaults to None. + """ + # TODO: update for new logging system + runner.log_buffer.update({ # type: ignore + 'data_time': time.time() - self.t + }) + + def after_iter(self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """Logging time for a iteration and update the time flag. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample]): Data from dataloader. + Defaults to None. + outputs (Sequence[BaseDataSample]): Outputs from model. + Defaults to None. + """ + # TODO: update for new logging system + runner.log_buffer.update({ # type: ignore + 'time': time.time() - self.t + }) + self.t = time.time() diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py new file mode 100644 index 00000000..5e3b6e71 --- /dev/null +++ b/tests/test_hook/test_iter_timer_hook.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import Mock + +from mmengine.hooks import IterTimerHook + + +class TestIterTimerHook: + + def test_before_epoch(self): + Hook = IterTimerHook() + Runner = Mock() + Hook.before_epoch(Runner) + assert isinstance(Hook.t, float) + + def test_before_iter(self): + Hook = IterTimerHook() + Runner = Mock() + Runner.log_buffer = dict() + Hook.before_epoch(Runner) + Hook.before_iter(Runner) + assert 'data_time' in Runner.log_buffer + + def test_after_iter(self): + Hook = IterTimerHook() + Runner = Mock() + Runner.log_buffer = dict() + Hook.before_epoch(Runner) + Hook.after_iter(Runner) + assert 'time' in Runner.log_buffer -- GitLab