diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py index 99cd738ed257d13d959ad7488b3147e644afa806..39364154babedcdca6be714bd239088020c6159e 100644 --- a/mmengine/hooks/__init__.py +++ b/mmengine/hooks/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .hook import Hook +from .iter_timer_hook import IterTimerHook -__all__ = ['Hook'] +__all__ = ['Hook', 'IterTimerHook'] diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc84465441b2eeba7d48b94a7e6ae635c0a8302 --- /dev/null +++ b/mmengine/hooks/iter_timer_hook.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import time +from typing import Optional, Sequence + +from mmengine.data import BaseDataSample +from mmengine.registry import HOOKS +from .hook import Hook + + +@HOOKS.register_module() +class IterTimerHook(Hook): + """A hook that logs the time spent during iteration. + + Eg. ``data_time`` for loading data and ``time`` for a model train step. + """ + + def before_epoch(self, runner: object) -> None: + """Record time flag before start a epoch. + + Args: + runner (object): The runner of the training process. + """ + self.t = time.time() + + def before_iter( + self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + """Logging time for loading data and update the time flag. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample]): Data from dataloader. + Defaults to None. + """ + # TODO: update for new logging system + runner.log_buffer.update({ # type: ignore + 'data_time': time.time() - self.t + }) + + def after_iter(self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """Logging time for a iteration and update the time flag. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample]): Data from dataloader. + Defaults to None. + outputs (Sequence[BaseDataSample]): Outputs from model. + Defaults to None. + """ + # TODO: update for new logging system + runner.log_buffer.update({ # type: ignore + 'time': time.time() - self.t + }) + self.t = time.time() diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..5e3b6e71b2c077e950914bceea69c1621ce95f09 --- /dev/null +++ b/tests/test_hook/test_iter_timer_hook.py @@ -0,0 +1,29 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import Mock + +from mmengine.hooks import IterTimerHook + + +class TestIterTimerHook: + + def test_before_epoch(self): + Hook = IterTimerHook() + Runner = Mock() + Hook.before_epoch(Runner) + assert isinstance(Hook.t, float) + + def test_before_iter(self): + Hook = IterTimerHook() + Runner = Mock() + Runner.log_buffer = dict() + Hook.before_epoch(Runner) + Hook.before_iter(Runner) + assert 'data_time' in Runner.log_buffer + + def test_after_iter(self): + Hook = IterTimerHook() + Runner = Mock() + Runner.log_buffer = dict() + Hook.before_epoch(Runner) + Hook.after_iter(Runner) + assert 'time' in Runner.log_buffer