Skip to content
Snippets Groups Projects
Unverified Commit 05fd5e81 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhance] Refine MMLogger and change save dir of MMLogger and Visualizer (#205)

* MMLogger can call get_current_instance without get_instance, change log path

* fix docstring

* fix docstring and update UT

* Fix runner

* fix docstring and lint

* fix ut below python3.8

* resolve circle import
parent 2475c7a0
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
import sys
from logging import Logger, LogRecord
from typing import Optional, Union
......@@ -9,7 +8,7 @@ from typing import Optional, Union
import torch.distributed as dist
from termcolor import colored
from mmengine.utils import ManagerMixin, mkdir_or_exist
from mmengine.utils import ManagerMixin
class MMFormatter(logging.Formatter):
......@@ -105,8 +104,6 @@ class MMLogger(Logger, ManagerMixin):
config.
- Different from ``logging.Logger``, ``MMLogger`` will not log warrning
or error message without ``Handler``.
- If `log_file=/path/to/tmp.log`, all logs will be saved to
`/path/to/tmp/tmp.log`
Examples:
>>> logger = MMLogger.get_instance(name='MMLogger',
......@@ -164,15 +161,6 @@ class MMLogger(Logger, ManagerMixin):
self.handlers.append(stream_handler)
if log_file is not None:
# If `log_file=/path/to/tmp.log`, all logs will be saved to
# `/path/to/tmp/tmp.log`
log_dir = osp.dirname(log_file)
filename = osp.basename(log_file)
filename_list = filename.split('.')
sub_file_name = '.'.join(filename_list[:-1])
log_dir = osp.join(log_dir, sub_file_name)
mkdir_or_exist(log_dir)
log_file = osp.join(log_dir, filename)
if rank != 0:
# rename `log_file` with rank suffix.
path_split = log_file.split(os.sep)
......@@ -199,6 +187,21 @@ class MMLogger(Logger, ManagerMixin):
file_handler.setLevel(log_level)
self.handlers.append(file_handler)
@classmethod
def get_current_instance(cls) -> 'MMLogger':
"""Get latest created ``MMLogger`` instance.
:obj:`MMLogger` can call :meth:`get_current_instance` before any
instance has been created, and return a logger with the instance name
"mmengine".
Returns:
MMLogger: Configured logger instance.
"""
if not cls._instance_dict:
cls.get_instance('mmengine')
return super(MMLogger, cls).get_current_instance()
def callHandlers(self, record: LogRecord) -> None:
"""Pass a record to all relevant handlers.
......
......@@ -6,7 +6,6 @@ import numpy as np
import torch
from mmengine.utils import ManagerMixin
from mmengine.visualization.utils import check_type
from .history_buffer import HistoryBuffer
......@@ -80,6 +79,21 @@ class MessageHub(ManagerMixin):
('Key in `resumed_keys` must contained in `log_scalars` or '
f'`runtime_info`, but got {key}')
@classmethod
def get_current_instance(cls) -> 'MessageHub':
"""Get latest created ``MessageHub`` instance.
:obj:`MessageHub` can call :meth:`get_current_instance` before any
instance has been created, and return a message hub with the instance
name "mmengine".
Returns:
MessageHub: Empty ``MessageHub`` instance.
"""
if not cls._instance_dict:
cls.get_instance('mmengine')
return super(MessageHub, cls).get_current_instance()
def update_scalar(self,
key: str,
value: Union[int, float, np.ndarray, torch.Tensor],
......@@ -116,7 +130,7 @@ class MessageHub(ManagerMixin):
could be resumed. Defaults to True.
"""
self._set_resumed_keys(key, resumed)
checked_value = self._get_valid_value(key, value)
checked_value = self._get_valid_value(value)
assert isinstance(count, int), (
f'The type of count must be int. but got {type(count): {count}}')
if key in self._log_scalars:
......@@ -153,13 +167,11 @@ class MessageHub(ManagerMixin):
if isinstance(log_val, dict):
assert 'value' in log_val, \
f'value must be defined in {log_val}'
count = self._get_valid_value(log_name,
log_val.get('count', 1))
checked_value = self._get_valid_value(log_name,
log_val['value'])
count = self._get_valid_value(log_val.get('count', 1))
checked_value = self._get_valid_value(log_val['value'])
else:
count = 1
checked_value = self._get_valid_value(log_name, log_val)
checked_value = self._get_valid_value(log_val)
assert isinstance(count,
int), ('The type of count must be int. but got '
f'{type(count): {count}}')
......@@ -268,13 +280,12 @@ class MessageHub(ManagerMixin):
# return copy.deepcopy(self._runtime_info[key])
return self._runtime_info[key]
def _get_valid_value(self, key: str,
value: Union[torch.Tensor, np.ndarray, int, float]) \
def _get_valid_value(
self, value: Union[torch.Tensor, np.ndarray, int, float]) \
-> Union[int, float]:
"""Convert value to python built-in type.
Args:
key (str): name of log.
value (torch.Tensor or np.ndarray or int or float): value of log.
Returns:
......@@ -287,7 +298,7 @@ class MessageHub(ManagerMixin):
assert value.numel() == 1
value = value.item()
else:
check_type(key, value, (int, float))
assert isinstance(value, (int, float))
return value # type: ignore
def __getstate__(self):
......
......@@ -316,6 +316,8 @@ class Runner:
self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
else:
self._experiment_name = self.timestamp
self._log_dir = osp.join(self.work_dir, self.timestamp)
mmengine.mkdir_or_exist(self._log_dir)
# Used to reset registries location. See :meth:`Registry.build` for
# more details.
self.default_scope = DefaultScope.get_instance(
......@@ -628,7 +630,7 @@ class Runner:
MMLogger: A MMLogger object build from ``logger``.
"""
if log_file is None:
log_file = osp.join(self.work_dir, f'{self._experiment_name}.log')
log_file = osp.join(self._log_dir, f'{self._experiment_name}.log')
log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs)
log_cfg.setdefault('name', self._experiment_name)
......@@ -678,7 +680,7 @@ class Runner:
visualizer = dict(
name=self._experiment_name,
vis_backends=[
dict(type='LocalVisBackend', save_dir=self._work_dir)
dict(type='LocalVisBackend', save_dir=self._log_dir)
])
return Visualizer.get_instance(**visualizer)
......@@ -688,7 +690,7 @@ class Runner:
if isinstance(visualizer, dict):
# ensure visualizer containing name key
visualizer.setdefault('name', self._experiment_name)
visualizer.setdefault('save_dir', self._work_dir)
visualizer.setdefault('save_dir', self._log_dir)
return VISUALIZERS.build(visualizer)
else:
raise TypeError(
......
......@@ -3,6 +3,7 @@ import logging
import os
import re
import sys
from collections import OrderedDict
from unittest.mock import patch
import pytest
......@@ -52,7 +53,7 @@ class TestLogger:
def test_init_rank1(self, tmp_path):
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
tmp_file = tmp_path / 'tmp_file.log'
log_path = tmp_path / 'tmp_file' / 'tmp_file_rank1.log'
log_path = tmp_path / 'tmp_file_rank1.log'
logger = MMLogger.get_instance(
'rank1.pkg2', log_level='INFO', log_file=str(tmp_file))
assert len(logger.handlers) == 1
......@@ -88,7 +89,7 @@ class TestLogger:
logger = MMLogger.get_instance(
instance_name, log_level=log_level, log_file=tmp_file)
logger.log(level=log_level, msg='welcome')
with open(tmp_path / 'tmp_file' / 'tmp_file.log', 'r') as f:
with open(tmp_path / 'tmp_file.log', 'r') as f:
log_text = f.read()
match = re.fullmatch(
self.file_handler_regex_time +
......@@ -150,3 +151,15 @@ class TestLogger:
print_log('welcome', logger=dict)
with pytest.raises(ValueError):
print_log('welcome', logger='unknown')
def test_get_instance(self):
# Test get root mmengine logger.
MMLogger._instance_dict = OrderedDict()
root_logger = MMLogger.get_current_instance()
mmdet_logger = MMLogger.get_instance('mmdet')
assert root_logger.name == mmdet_logger.name
assert id(root_logger) != id(mmdet_logger)
assert id(MMLogger.get_instance('mmengine')) == id(root_logger)
# Test original `get_current_instance` function.
MMLogger.get_instance('mmdet')
assert MMLogger.get_current_instance().instance_name == 'mmdet'
......@@ -86,7 +86,7 @@ class TestMessageHub:
assert loss_bbox.current() == 3
assert loss_iou.mean() == 0.5
with pytest.raises(TypeError):
with pytest.raises(AssertionError):
loss_dict = dict(error_type=[])
message_hub.update_scalars(loss_dict)
......@@ -112,3 +112,12 @@ class TestMessageHub:
instance.get_info('iter')
instance.get_scalar('loss')
def test_get_instance(self):
# Test get root mmengine message hub.
MessageHub._instance_dict = OrderedDict()
root_logger = MessageHub.get_current_instance()
assert id(MessageHub.get_instance('mmengine')) == id(root_logger)
# Test original `get_current_instance` function.
MessageHub.get_instance('mmdet')
assert MessageHub.get_current_instance().instance_name == 'mmdet'
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment