From 1244e486ae33783ac1ebde07ee9df0c583e7364d Mon Sep 17 00:00:00 2001
From: Yifei Yang <2744335995@qq.com>
Date: Tue, 1 Mar 2022 12:02:34 +0800
Subject: [PATCH] [Feature] Add Iter Timer Hook (#48)

* [Feature]: Add Part3 of Hooks

* [Feature]: Add Hook

* add iter timer hook

* update test

* [Fix]: Add docstring and type hint for base hook

* fix mypy

* improve doc coverage and merge main

Co-authored-by: seuyou <3463423099@qq.com>
---
 mmengine/hooks/__init__.py              |  3 +-
 mmengine/hooks/iter_timer_hook.py       | 58 +++++++++++++++++++++++++
 tests/test_hook/test_iter_timer_hook.py | 29 +++++++++++++
 3 files changed, 89 insertions(+), 1 deletion(-)
 create mode 100644 mmengine/hooks/iter_timer_hook.py
 create mode 100644 tests/test_hook/test_iter_timer_hook.py

diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index 99cd738e..39364154 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .hook import Hook
+from .iter_timer_hook import IterTimerHook
 
-__all__ = ['Hook']
+__all__ = ['Hook', 'IterTimerHook']
diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py
new file mode 100644
index 00000000..ecc84465
--- /dev/null
+++ b/mmengine/hooks/iter_timer_hook.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+from typing import Optional, Sequence
+
+from mmengine.data import BaseDataSample
+from mmengine.registry import HOOKS
+from .hook import Hook
+
+
+@HOOKS.register_module()
+class IterTimerHook(Hook):
+    """A hook that logs the time spent during iteration.
+
+    Eg. ``data_time`` for loading data and ``time`` for a model train step.
+    """
+
+    def before_epoch(self, runner: object) -> None:
+        """Record time flag before start a epoch.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        self.t = time.time()
+
+    def before_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """Logging time for loading data and update the time flag.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample]): Data from dataloader.
+                Defaults to None.
+        """
+        # TODO: update for new logging system
+        runner.log_buffer.update({  # type: ignore
+            'data_time': time.time() - self.t
+        })
+
+    def after_iter(self,
+                   runner: object,
+                   data_batch: Optional[Sequence[BaseDataSample]] = None,
+                   outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """Logging time for a iteration and update the time flag.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample]): Data from dataloader.
+                Defaults to None.
+            outputs (Sequence[BaseDataSample]): Outputs from model.
+                Defaults to None.
+        """
+        # TODO: update for new logging system
+        runner.log_buffer.update({  # type: ignore
+            'time': time.time() - self.t
+        })
+        self.t = time.time()
diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py
new file mode 100644
index 00000000..5e3b6e71
--- /dev/null
+++ b/tests/test_hook/test_iter_timer_hook.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import Mock
+
+from mmengine.hooks import IterTimerHook
+
+
+class TestIterTimerHook:
+
+    def test_before_epoch(self):
+        Hook = IterTimerHook()
+        Runner = Mock()
+        Hook.before_epoch(Runner)
+        assert isinstance(Hook.t, float)
+
+    def test_before_iter(self):
+        Hook = IterTimerHook()
+        Runner = Mock()
+        Runner.log_buffer = dict()
+        Hook.before_epoch(Runner)
+        Hook.before_iter(Runner)
+        assert 'data_time' in Runner.log_buffer
+
+    def test_after_iter(self):
+        Hook = IterTimerHook()
+        Runner = Mock()
+        Runner.log_buffer = dict()
+        Hook.before_epoch(Runner)
+        Hook.after_iter(Runner)
+        assert 'time' in Runner.log_buffer
-- 
GitLab