Unverified commit 16589ce3, authored by BayMax_BHL and committed by GitHub

[Feature] Add ProfilerHook (#768)


* [Feature] Add profiler hook functionality

* Apply suggestions from code review

* Update mmengine/hooks/profiler_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent c382f8a5
@@ -22,3 +22,4 @@ mmengine.hooks
IterTimerHook
SyncBuffersHook
EmptyCacheHook
ProfilerHook
@@ -7,6 +7,7 @@ from .iter_timer_hook import IterTimerHook
from .logger_hook import LoggerHook
from .naive_visualization_hook import NaiveVisualizationHook
from .param_scheduler_hook import ParamSchedulerHook
from .profiler_hook import ProfilerHook
from .runtime_info_hook import RuntimeInfoHook
from .sampler_seed_hook import DistSamplerSeedHook
from .sync_buffer_hook import SyncBuffersHook
@@ -14,5 +15,5 @@ from .sync_buffer_hook import SyncBuffersHook
__all__ = [
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
-    'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook'
+    'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook', 'ProfilerHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from typing import Callable, Optional, Union

import torch

from mmengine.dist import master_only
from mmengine.hooks import Hook
from mmengine.registry import HOOKS


def check_kineto() -> bool:  # noqa
kineto_exist = False
try:
if torch.autograd.kineto_available():
kineto_exist = True
except AttributeError:
        warnings.warn('Kineto is not available in this PyTorch build')
return kineto_exist
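

# Note added for clarity (not in the original commit): ``check_kineto`` gates
# ProfilerHook because ``torch.profiler`` relies on the Kineto backend, which,
# per the error messages below, ships with PyTorch 1.8.1 or above (1.9.1 or
# above on Windows).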
@HOOKS.register_module()
class ProfilerHook(Hook):
"""A hook to analyze performance during training and inference.
PyTorch Profiler is a tool that allows the collection of the performance
metrics during the training. More details on Profiler can be found at
`official docs <https://pytorch.org/docs/stable/profiler.html
#torch.profiler.profile>`_
Args:
by_epoch (bool): Profile performance by epoch or by iteration.
Defaults to True.
profile_times (int): The period (epoch/iter) recorded by the profiler.
Defaults to 1. For example, profile_iters=10 and by_epoch=False,
indicate that 0-10 iterations are recorded.
activity_with_cpu (bool): Activities to be used in the analysis (CPU)
activity_with_cuda (bool): Activities to be used in the analysis (CUDA)
schedule (dict, optional): Key-word arguments passed to
`torch.profile.schedule <https://pytorch.org/docs/stable/
profiler.html#torch.profiler.schedule>`_.
Defaults to None, which means profiling without a schedule
on_trace_ready (callable, dict, optional): Either a handler or a dict
of generating handler. Defaults to None, which means profiling
without an on_trace_ready.The Callable type needs to construct its
own function that can handle 'torch.autograd.profiler.profile'.
Two officially recommended ways are provided, namely terminal
display or tensorboard display. The terminal display content can be
adjusted through 'EventList.table()'
from 'torch.autograd.profiler_util.py'.
If using tensorboard, save to '{work_dir}/tf_tracing_logs'
by default.
record_shapes (bool): Save information about operator's input shapes.
Defaults to False.
profile_memory (bool): Track tensor memory allocation/deallocation.
Defaults to False.
with_stack (bool): Record source information (file and line number)
for the ops. Defaults to False.
with_flops (bool): Use formula to estimate the FLOPS of specific
operators (matrix multiplication and 2D convolution).
Defaults to False.
json_trace_path (str, optional): Exports the collected trace in Chrome
JSON format. Chrome use 'chrome://tracing' view json file.
Defaults to None, which means profiling does not store json files.
Examples:
>>> # tensorboard trace
>>> trace_config = dict(type='tb_trace')
>>> profiler_hook_cfg = dict(on_trace_ready=trace_config)
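        >>> # alternatively, print a summary table to the terminal; the
        >>> # kwargs are forwarded to ``EventList.table()`` (values below
        >>> # mirror the unit tests and are only illustrative)
        >>> trace_config = dict(type='log_trace',
        ...                     sort_by='self_cpu_time_total', row_limit=10)
        >>> profiler_hook_cfg = dict(on_trace_ready=trace_config)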
"""
    priority = 'VERY_LOW'

    def __init__(self,
*,
by_epoch: bool = True,
profile_times: int = 1,
activity_with_cpu: bool = True,
activity_with_cuda: bool = False,
schedule: Optional[dict] = None,
on_trace_ready: Union[Callable, dict, None] = None,
record_shapes: bool = False,
profile_memory: bool = False,
with_stack: bool = False,
with_flops: bool = False,
json_trace_path: Optional[str] = None) -> None:
try:
from torch import profiler
except ImportError:
            raise ImportError('please upgrade PyTorch to 1.8.1 or above')
if not check_kineto():
            raise ImportError('Due to Kineto support issues, please upgrade '
                              'PyTorch to 1.8.1 or above (Windows users to '
                              '1.9.1 or above)')
assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
self.by_epoch = by_epoch
if profile_times < 1:
            raise ValueError('profile_times should be greater than 0, '
f'but got {profile_times}')
if by_epoch and profile_times > 1:
raise ValueError(
                f'Profiler will profile 0-{profile_times} epochs.\n'
                'Since the profiler will slow down the training, it is '
                'recommended to train 1 epoch with ProfilerHook and adjust '
                'your setting according to the profiler summary.\n'
                'During normal training (epoch > 1), '
                'you may disable the ProfilerHook.')
self.profile_times = profile_times
assert isinstance(activity_with_cpu, bool), \
'``activity_with_cpu`` should be a boolean.'
assert isinstance(activity_with_cuda, bool), \
'``activity_with_cuda`` should be a boolean.'
self.activities = []
if activity_with_cpu:
self.activities.append(profiler.ProfilerActivity.CPU)
if activity_with_cuda:
self.activities.append(profiler.ProfilerActivity.CUDA)
if schedule is not None:
assert isinstance(schedule, dict), '``schedule`` should be a dict.'
self.schedule = profiler.schedule(**schedule)
else:
self.schedule = None
self.on_trace_ready = on_trace_ready
self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.with_stack = with_stack
self.with_flops = with_flops
        self.json_trace_path = json_trace_path

    @master_only
def before_run(self, runner):
"""Initialize the profiler.
Through the runner parameter, the validity of the parameter is further
determined.
"""
max_times = runner.max_epochs if self.by_epoch else runner.max_iters
if max_times < self.profile_times:
raise ValueError(
f'``profile_times`` should not be greater than {max_times}')
on_trace_ready = self._parse_trace_config(runner)
self.profiler = torch.profiler.profile( # noqa
activities=self.activities,
schedule=self.schedule,
on_trace_ready=on_trace_ready,
record_shapes=self.record_shapes,
profile_memory=self.profile_memory,
with_stack=self.with_stack,
with_flops=self.with_flops)
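        # Note added for clarity (not in the original commit): entering the
        # context manager starts profiling; the profiler is closed later in
        # ``_export_chrome_trace`` via ``__exit__``.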
self.profiler.__enter__()
        runner.logger.info('profiler is profiling...')

    def _parse_trace_config(self, runner):
        """Parse the parameter ``on_trace_ready``."""
if self.on_trace_ready is None:
_on_trace_ready = None
elif callable(self.on_trace_ready):
_on_trace_ready = self.on_trace_ready
elif isinstance(self.on_trace_ready, dict):
trace_cfg = self.on_trace_ready.copy()
trace_type = trace_cfg.pop('type')
# Build a log printing handle
if trace_type == 'log_trace':
def _log_handler(_profile):
print(_profile.key_averages().table(**trace_cfg))
_on_trace_ready = _log_handler
elif trace_type == 'tb_trace': # tensorboard_trace handler
try:
import torch_tb_profiler # noqa: F401
except ImportError:
raise ImportError(
'please run ``pip install torch-tb-profiler``')
if 'dir_name' not in trace_cfg:
trace_cfg['dir_name'] = osp.join(runner.log_dir,
'tf_tracing_logs')
elif not osp.isabs(trace_cfg['dir_name']):
trace_cfg['dir_name'] = osp.join(runner.log_dir,
trace_cfg['dir_name'])
                runner.logger.info('trace files of ProfilerHook will be '
                                   f'saved to {trace_cfg["dir_name"]}.')
if self.json_trace_path is not None:
runner.logger.warn(
'When using tensorboard_trace, it is recommended to '
'save json files by setting ``worker_name`` instead of'
' setting ``json_trace_path``')
_on_trace_ready = torch.profiler.tensorboard_trace_handler(
**trace_cfg)
else:
raise ValueError('trace_type should be "log_trace" or '
f'"tb_trace", but got {trace_type}')
else:
raise ValueError(
'``on_trace_ready`` should be a handler, or dict, or None, '
f'but got {self.on_trace_ready}')
return _on_trace_ready
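
    # Example ``on_trace_ready`` dicts accepted above (taken from the unit
    # tests below; the ``dir_name`` value is a placeholder):
    #   dict(type='log_trace', sort_by='self_cpu_time_total', row_limit=10)
    #   dict(type='tb_trace', dir_name='tb')  # relative to runner.log_dir
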
    @master_only
    def after_train_epoch(self, runner):
        """Determine whether to export the collected content."""
        if self.by_epoch and runner.epoch == self.profile_times - 1:
            self._export_chrome_trace(runner)

    @master_only
    def after_train_iter(self, runner, batch_idx, data_batch, outputs):
        """Update the profiler according to the schedule, and determine
        whether to export the collected content."""
if self.schedule is None:
self.profiler.step()
if not self.by_epoch and runner.iter == self.profile_times - 1:
            self._export_chrome_trace(runner)

    def _export_chrome_trace(self, runner):
        """Export the collected content."""
runner.logger.info('profiler may take a few minutes...')
self.profiler.__exit__(None, None, None)
if self.json_trace_path is not None:
self.profiler.export_chrome_trace(self.json_trace_path)
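
A minimal usage sketch (assembled from the docstring above and the tests
below, not part of the diff itself; ``demo.json`` is a placeholder path):

custom_hooks = [
    dict(
        type='ProfilerHook',
        by_epoch=False,       # profile by iteration
        profile_times=10,     # record the first 10 iterations
        on_trace_ready=dict(  # print a summary table to the terminal
            type='log_trace', sort_by='self_cpu_time_total', row_limit=10),
        json_trace_path='demo.json',  # also dump a Chrome trace file
    )
]
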
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as ops
import unittest
from unittest.mock import MagicMock

import torch

import mmengine.hooks
from mmengine.hooks import ProfilerHook
from mmengine.logging import MMLogger
from mmengine.testing import RunnerTestCase
from mmengine.utils import is_installed


@unittest.skipIf(
not mmengine.hooks.profiler_hook.check_kineto(),
    reason='Due to Kineto support issues, please upgrade PyTorch to 1.8.1 '
    'or above (Windows users to 1.9.1 or above)')
class TestProfilerHook(RunnerTestCase):

    def test_init(self):
# Test profile_times_args
ProfilerHook(by_epoch=False, profile_times=1)
with self.assertRaises(ValueError):
ProfilerHook(profile_times=0)
with self.assertRaises(ValueError):
ProfilerHook(by_epoch=True, profile_times=2)
# Test schedule_args
ProfilerHook(schedule=dict(wait=1, warmup=1, active=3, repeat=1))
with self.assertRaises(TypeError):
            ProfilerHook(schedule=dict())

    def test_parse_trace_config(self):
# Test on_trace_ready_args
runner = MagicMock()
hook = ProfilerHook(on_trace_ready=None)
hook.on_trace_ready = None
        hook._parse_trace_config(runner)

        def deal_profile(_profile):
            pass

        hook.on_trace_ready = deal_profile
hook._parse_trace_config(runner)
with self.assertRaises(ValueError):
hook.on_trace_ready = dict(type='unknown')
hook._parse_trace_config(runner)
hook.on_trace_ready = dict(
type='log_trace', sort_by='self_cpu_time_total', row_limit=10)
        hook._parse_trace_config(runner)

    @unittest.skipIf(
not is_installed('torch-tb-profiler'),
reason='required torch-tb-profiler')
def test_parse_trace_config_tensorboard(self):
# Test on_trace_ready_args
runner = MagicMock()
runner.log_dir = self.temp_dir.name
runner.logger = MMLogger.get_instance('test_profiler')
hook = ProfilerHook(on_trace_ready=None)
hook.on_trace_ready = dict(type='tb_trace')
hook._parse_trace_config(runner)
hook.on_trace_ready['dir_name'] = 'tb'
hook._parse_trace_config(runner)
hook.on_trace_ready['dir_name'] = ops.join(self.temp_dir.name, 'tb')
hook._parse_trace_config(runner)
# with self.assertWarns(DeprecationWarning):
hook = ProfilerHook(
on_trace_ready=dict(type='tb_trace'),
json_trace_path=ops.join(self.temp_dir.name, 'demo.json'))
hook._parse_trace_config(runner)
self.epoch_based_cfg['custom_hooks'] = [
dict(
type='ProfilerHook',
                on_trace_ready=dict(
                    type='tb_trace',
                    dir_name=ops.join(self.temp_dir.name, 'tb')))
]
runner = self.build_runner(self.epoch_based_cfg)
        runner.train()

    def test_before_run(self):
runner = MagicMock()
runner.max_epochs = 1000
runner.max_iters = 10000
runner.logger = MMLogger.get_instance('test_profiler')
hook = ProfilerHook()
hook.before_run(runner)
hook.profiler.__exit__(None, None, None)
with self.assertRaises(ValueError):
hook = ProfilerHook(by_epoch=False, profile_times=10001)
hook.before_run(runner)
hook.profiler.__exit__(None, None, None)
with self.assertRaises(ValueError):
hook = ProfilerHook(by_epoch=True, profile_times=1001)
hook.before_run(runner)
            hook.profiler.__exit__(None, None, None)

    def test_export_chrome_trace(self):
runner = MagicMock()
runner.max_epochs = 1000
runner.logger = MMLogger.get_instance('test_profiler')
hook = ProfilerHook(
json_trace_path=ops.join(self.temp_dir.name, 'demo.json'))
hook.before_run(runner)
        hook._export_chrome_trace(runner)

    def test_after_train_epoch(self):
runner = MagicMock()
runner.max_epochs = 1000
runner.logger = MMLogger.get_instance('test_profiler')
runner.epoch = 0
hook = ProfilerHook()
hook.before_run(runner)
hook.profiler.__exit__(None, None, None)
hook.profiler = MagicMock()
hook.after_train_epoch(runner)
        hook.profiler.__exit__.assert_called_once()

    def test_after_train_iter(self):
runner = MagicMock()
runner.max_iters = 10000
runner.logger = MMLogger.get_instance('test_profiler')
runner.iter = 9
hook = ProfilerHook(by_epoch=False, profile_times=10, schedule=None)
hook.before_run(runner)
hook.profiler.__exit__(None, None, None)
hook.profiler = MagicMock()
hook.after_train_iter(runner, 1, 1, 1)
hook.profiler.__exit__.assert_called_once()
hook.profiler.step.assert_called_once()
hook = ProfilerHook(
by_epoch=False,
schedule=dict(wait=1, warmup=1, active=3, repeat=1))
hook.before_run(runner)
hook.profiler.__exit__(None, None, None)
hook.profiler = MagicMock()
hook.after_train_iter(runner, 1, 1, 1)
        hook.profiler.step.assert_not_called()

    def test_with_runner(self):
self.epoch_based_cfg['custom_hooks'] = [
dict(
type='ProfilerHook',
activity_with_cpu=False,
activity_with_cuda=False)
]
runner = self.build_runner(self.epoch_based_cfg)
runner.train()
json_path = ops.join(self.temp_dir.name, 'demo.json')
self.epoch_based_cfg['custom_hooks'] = [
dict(type='ProfilerHook', json_trace_path=json_path)
]
runner = self.build_runner(self.epoch_based_cfg)
runner.train()
        self.assertTrue(
            ops.exists(json_path), 'json file is not generated!')
self.epoch_based_cfg['custom_hooks'] = [
dict(
type='ProfilerHook',
on_trace_ready=dict(
type='log_trace',
sort_by='self_cpu_time_total',
row_limit=10))
]
runner = self.build_runner(self.epoch_based_cfg)
runner.train()
with self.assertRaises(ValueError):
self.epoch_based_cfg['custom_hooks'] = [
dict(type='ProfilerHook', on_trace_ready=0)
]
runner = self.build_runner(self.epoch_based_cfg)
runner.train()
if torch.cuda.is_available():
self.epoch_based_cfg['custom_hooks'] = [
dict(type='ProfilerHook', activity_with_cuda=True)
]
runner = self.build_runner(self.epoch_based_cfg)
runner.train()