Skip to content
Snippets Groups Projects
Unverified Commit 1244e486 authored by Yifei Yang's avatar Yifei Yang Committed by GitHub
Browse files

[Feature] Add Iter Timer Hook (#48)


* [Feature]: Add Part3 of Hooks

* [Feature]: Add Hook

* add iter timer hook

* update test

* [Fix]: Add docstring and type hint for base hook

* fix mypy

* improve doc coverage and merge main

Co-authored-by: default avatarseuyou <3463423099@qq.com>
parent 42448425
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .hook import Hook from .hook import Hook
from .iter_timer_hook import IterTimerHook
__all__ = ['Hook'] __all__ = ['Hook', 'IterTimerHook']
# Copyright (c) OpenMMLab. All rights reserved.
import time
from typing import Optional, Sequence
from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook
@HOOKS.register_module()
class IterTimerHook(Hook):
"""A hook that logs the time spent during iteration.
Eg. ``data_time`` for loading data and ``time`` for a model train step.
"""
def before_epoch(self, runner: object) -> None:
"""Record time flag before start a epoch.
Args:
runner (object): The runner of the training process.
"""
self.t = time.time()
def before_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Logging time for loading data and update the time flag.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
"""
# TODO: update for new logging system
runner.log_buffer.update({ # type: ignore
'data_time': time.time() - self.t
})
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Logging time for a iteration and update the time flag.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model.
Defaults to None.
"""
# TODO: update for new logging system
runner.log_buffer.update({ # type: ignore
'time': time.time() - self.t
})
self.t = time.time()
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
from mmengine.hooks import IterTimerHook
class TestIterTimerHook:
def test_before_epoch(self):
Hook = IterTimerHook()
Runner = Mock()
Hook.before_epoch(Runner)
assert isinstance(Hook.t, float)
def test_before_iter(self):
Hook = IterTimerHook()
Runner = Mock()
Runner.log_buffer = dict()
Hook.before_epoch(Runner)
Hook.before_iter(Runner)
assert 'data_time' in Runner.log_buffer
def test_after_iter(self):
Hook = IterTimerHook()
Runner = Mock()
Runner.log_buffer = dict()
Hook.before_epoch(Runner)
Hook.after_iter(Runner)
assert 'time' in Runner.log_buffer
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment