Skip to content
Snippets Groups Projects
Unverified Commit 2d782b49 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Feature] add logging UT and impl (#43)

* first commit

* update logging

* update MessageHub unit test and LogBuffer unitest

* update logging docs

* update logging impl and test

* update test logging

* update test

* update logging test and impl

* Fix logging test

* Update log_buffer.py

* rename statistics argumentes

* fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix meta class

* Fix as comment

* Fix as comment

* Fix as comment

* Fix name declare

* Fix as comment

* Fix as comment

* Fix as comment

* Fix docstring

* Fix as comment

* Fix as comment
parent e1462335
No related branches found
No related tags found
No related merge requests found
...@@ -5,5 +5,6 @@ from .data import * ...@@ -5,5 +5,6 @@ from .data import *
from .dataset import * from .dataset import *
from .fileio import * from .fileio import *
from .hooks import * from .hooks import *
from .logging import *
from .registry import * from .registry import *
from .utils import * from .utils import *
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .base_global_accsessible import BaseGlobalAccessible, MetaGlobalAccessible from .base_global_accsessible import BaseGlobalAccessible, MetaGlobalAccessible
from .log_buffer import LogBuffer
from .logger import MMLogger, print_log
from .message_hub import MessageHub
__all__ = ['MetaGlobalAccessible', 'BaseGlobalAccessible'] __all__ = [
'LogBuffer', 'MessageHub', 'MetaGlobalAccessible', 'BaseGlobalAccessible',
'MMLogger', 'print_log'
]
...@@ -14,28 +14,42 @@ class MetaGlobalAccessible(type): ...@@ -14,28 +14,42 @@ class MetaGlobalAccessible(type):
Examples: Examples:
>>> class SubClass1(metaclass=MetaGlobalAccessible): >>> class SubClass1(metaclass=MetaGlobalAccessible):
>>> def __init__(self, args, **kwargs): >>> def __init__(self, *args, **kwargs):
>>> pass >>> pass
AssertionError: The arguments of the AssertionError: <class '__main__.SubClass1'>.__init__ must have the
``<class '__main__.subclass'>.__init__`` must contain name argument. name argument.
>>> class SubClass2(metaclass=MetaGlobalAccessible): >>> class SubClass2(metaclass=MetaGlobalAccessible):
>>> def __init__(self, a, name=None, *args, **kwargs): >>> def __init__(self, a, name=None, **kwargs):
>>> pass >>> pass
AssertionError: The arguments of the AssertionError:
``<class '__main__.subclass'>.__init__`` must have default values. In <class '__main__.SubClass2'>.__init__, Only the name argument is
allowed to have no default values.
>>> class SubClass3(metaclass=MetaGlobalAccessible): >>> class SubClass3(metaclass=MetaGlobalAccessible):
>>> def __init__(self, a, name=None, *args, **kwargs): >>> def __init__(self, name, **kwargs):
>>> pass # Right format
>>> class SubClass4(metaclass=MetaGlobalAccessible):
>>> def __init__(self, a=1, name='', **kwargs):
>>> pass # Right format >>> pass # Right format
""" """
def __init__(cls, *args): def __init__(cls, *args):
cls._instance_dict = OrderedDict() cls._instance_dict = OrderedDict()
params = inspect.getfullargspec(cls) params = inspect.getfullargspec(cls)
# Make sure `cls('root')` can be implemented. # `inspect.getfullargspec` returns a tuple includes `(args, varargs,
assert 'name' in params[0], \ # varkw, defaults, kwonlyargs, kwonlydefaults, annotations)`.
f'The arguments of the {cls}.__init__ must contain name argument' # To make sure `cls(name='root')` can be implemented, the
assert len(params[3]) == len(params[0]) - 1, \ # `args` and `defaults` should be checked.
f'The arguments of the {cls}.__init__ must have default values' params_names = params[0] if params[0] else []
default_params = params[3] if params[3] else []
assert 'name' in params_names, f'{cls}.__init__ must have the name ' \
'argument'
if len(default_params) == len(params_names) - 2 and 'name' != \
params[0][1]:
raise AssertionError(f'In {cls}.__init__, Only the name argument '
'is allowed to have no default values.')
if len(default_params) < len(params_names) - 2:
raise AssertionError('Besides name, the arguments of the '
f'{cls}.__init__ must have default values')
cls.root = cls(name='root') cls.root = cls(name='root')
super().__init__(*args) super().__init__(*args)
...@@ -52,19 +66,20 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): ...@@ -52,19 +66,20 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible):
>>> def __init__(self, name=''): >>> def __init__(self, name=''):
>>> super().__init__(name) >>> super().__init__(name)
>>> >>>
>>> GlobalAccessible.create_instance('name')
>>> instance_1 = GlobalAccessible.get_instance('name') >>> instance_1 = GlobalAccessible.get_instance('name')
>>> instance_2 = GlobalAccessible.get_instance('name') >>> instance_2 = GlobalAccessible.get_instance('name')
>>> assert id(instance_1) == id(instance_2) >>> assert id(instance_1) == id(instance_2)
Args: Args:
name (str): The name of the instance. Defaults to None. name (str): Name of the instance. Defaults to ''.
""" """
def __init__(self, name: str = '', *args, **kwargs): def __init__(self, name: str = '', **kwargs):
self._name = name self._name = name
@classmethod @classmethod
def create_instance(cls, name: str = None, *args, **kwargs) -> Any: def create_instance(cls, name: str = '', **kwargs) -> Any:
"""Create subclass instance by name, and subclass cannot create """Create subclass instance by name, and subclass cannot create
instances with duplicated names. The created instance will be stored in instances with duplicated names. The created instance will be stored in
``cls._instance_dict``, and can be accessed by ``get_instance``. ``cls._instance_dict``, and can be accessed by ``get_instance``.
...@@ -79,30 +94,29 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): ...@@ -79,30 +94,29 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible):
root root
Args: Args:
name (str, optional): The name of instance. Defaults to None. name (str): Name of instance. Defaults to ''.
Returns: Returns:
object: The subclass instance. object: Subclass instance.
""" """
instance_dict = cls._instance_dict instance_dict = cls._instance_dict
# Create instance and fill the instance in the `instance_dict`. # Create instance and fill the instance in the `instance_dict`.
if name is not None: if name:
assert name not in instance_dict, f'{cls} cannot be created by ' \ assert name not in instance_dict, f'{cls} cannot be created by ' \
f'{name} twice.' f'{name} twice.'
instance = cls(name, *args, **kwargs) instance = cls(name=name, **kwargs)
instance_dict[name] = instance instance_dict[name] = instance
return instance return instance
# Get default root instance. # Get default root instance.
else: else:
if args or kwargs: if kwargs:
raise ValueError('If name is not specified, create_instance ' raise ValueError('If name is not specified, create_instance '
f'will return root {cls} and cannot accept ' f'will return root {cls} and cannot accept '
f'any arguments, but got args: {args}, ' f'any arguments, but got kwargs: {kwargs}')
f'kwargs: {kwargs}')
return cls.root return cls.root
@classmethod @classmethod
def get_instance(cls, name: str = None, current: bool = False) -> Any: def get_instance(cls, name: str = '', current: bool = False) -> Any:
"""Get subclass instance by name if the name exists. if name is not """Get subclass instance by name if the name exists. if name is not
specified, this method will return latest created instance of root specified, this method will return latest created instance of root
instance. instance.
...@@ -114,7 +128,7 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): ...@@ -114,7 +128,7 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible):
name1 name1
>>> instance = GlobalAccessible.create_instance('name2') >>> instance = GlobalAccessible.create_instance('name2')
>>> instance = GlobalAccessible.get_instance(current=True) >>> instance = GlobalAccessible.get_instance(current=True)
>>> instance.instance_name # the latest created instance is name2 >>> instance.instance_name
name2 name2
>>> instance = GlobalAccessible.get_instance() >>> instance = GlobalAccessible.get_instance()
>>> instance.instance_name # get root instance >>> instance.instance_name # get root instance
...@@ -124,19 +138,17 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): ...@@ -124,19 +138,17 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible):
name: name3, please make sure you have created it name: name3, please make sure you have created it
Args: Args:
name (str, optional): The name of instance. Defaults to None. name (str): Name of instance. Defaults to ''.
current(bool): Whether to return the latest created instance current(bool): Whether to return the latest created instance or
or the root instance, if name is not spicified. Defaults to the root instance, if name is not spicified. Defaults to False.
None.
current (bool): Whether to return the latest created instance.
Defaults to False.
Returns: Returns:
object: Corresponding name instance, the latest instance, or root object: Corresponding name instance, the latest instance, or root
instance. instance.
""" """
instance_dict = cls._instance_dict instance_dict = cls._instance_dict
# Get the instance by name. # Get the instance by name.
if name is not None: if name:
assert name in instance_dict, \ assert name in instance_dict, \
f'Cannot get {cls} by name: {name}, please make sure you ' \ f'Cannot get {cls} by name: {name}, please make sure you ' \
'have created it' 'have created it'
...@@ -156,6 +168,6 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): ...@@ -156,6 +168,6 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible):
"""Get the name of instance. """Get the name of instance.
Returns: Returns:
str: The name of instance. str: Name of instance.
""" """
return self._name return self._name
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from functools import partial
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import numpy as np
class BaseLogBuffer:
"""Unified storage format for different log types.
Record the history of log for further statistics. The subclass inherited
from ``BaseLogBuffer`` will implement the specific statistical methods.
Args:
log_history (Sequence): History logs. Defaults to [].
count_history (Sequence): Counts of history logs. Defaults to [].
max_length (int): The max length of history logs. Defaults to 1000000.
"""
_statistics_methods: dict = dict()
def __init__(self,
log_history: Sequence = [],
count_history: Sequence = [],
max_length: int = 1000000):
self.max_length = max_length
assert len(log_history) == len(count_history), \
'The lengths of log_history and count_histroy should be equal'
if len(log_history) > max_length:
warnings.warn(f'The length of history buffer({len(log_history)}) '
f'exceeds the max_length({max_length}), the first '
'few elements will be ignored.')
self._log_history = np.array(log_history[-max_length:])
self._count_history = np.array(count_history[-max_length:])
else:
self._log_history = np.array(log_history)
self._count_history = np.array(count_history)
def update(self, log_val: Union[int, float], count: int = 1) -> None:
"""update the log history. If the length of the buffer exceeds
``self._max_length``, the oldest element will be removed from the
buffer.
Args:
log_val (int or float): The value of log.
count (int): The accumulation times of log, defaults to 1.
``count`` will be used in smooth statistics.
"""
if (not isinstance(log_val, (int, float))
or not isinstance(count, (int, float))):
raise TypeError(f'log_val must be int or float but got '
f'{type(log_val)}, count must be int but got '
f'{type(count)}')
self._log_history = np.append(self._log_history, log_val)
self._count_history = np.append(self._count_history, count)
if len(self._log_history) > self.max_length:
self._log_history = self._log_history[-self.max_length:]
self._count_history = self._count_history[-self.max_length:]
@property
def data(self) -> Tuple[np.ndarray, np.ndarray]:
"""Get the ``_log_history`` and ``_count_history``.
Returns:
Tuple[np.ndarray, np.ndarray]: History logs and the counts of
the history logs.
"""
return self._log_history, self._count_history
@classmethod
def register_statistics(cls, method: Callable) -> Callable:
"""Register custom statistics method to ``_statistics_methods``.
Args:
method (Callable): Custom statistics method.
Returns:
Callable: Original custom statistics method.
"""
method_name = method.__name__
assert method_name not in cls._statistics_methods, \
'method_name cannot be registered twice!'
cls._statistics_methods[method_name] = method
return method
def statistics(self, method_name: str, *arg, **kwargs) -> Any:
"""Access statistics method by name.
Args:
method_name (str): Name of method.
Returns:
Any: Depends on corresponding method.
"""
if method_name not in self._statistics_methods:
raise KeyError(f'{method_name} has not been registered in '
'BaseLogBuffer._statistics_methods')
method = self._statistics_methods[method_name]
# Provide self arguments for registered functions.
method = partial(method, self)
return method(*arg, **kwargs)
class LogBuffer(BaseLogBuffer):
"""``LogBuffer`` inherits from ``BaseLogBuffer`` and provides some basic
statistics methods, such as ``min``, ``max``, ``current`` and ``mean``."""
@BaseLogBuffer.register_statistics
def mean(self, window_size: Optional[int] = None) -> np.ndarray:
"""Return the mean of the latest ``window_size`` values in log
histories. If ``window_size is None``, return the global mean of
history logs.
Args:
window_size (int, optional): Size of statistics window.
Returns:
np.ndarray: Mean value within the window.
"""
if window_size is not None:
assert isinstance(window_size, int), \
'The type of window size should be int, but got ' \
f'{type(window_size)}'
else:
window_size = len(self._log_history)
logs_sum = self._log_history[-window_size:].sum()
counts_sum = self._count_history[-window_size:].sum()
return logs_sum / counts_sum
@BaseLogBuffer.register_statistics
def max(self, window_size: Optional[int] = None) -> np.ndarray:
"""Return the maximum value of the latest ``window_size`` values in log
histories. If ``window_size is None``, return the global maximum value
of history logs.
Args:
window_size (int, optional): Size of statistics window.
Returns:
np.ndarray: The maximum value within the window.
"""
if window_size is not None:
assert isinstance(window_size, int), \
'The type of window size should be int, but got ' \
f'{type(window_size)}'
else:
window_size = len(self._log_history)
return self._log_history[-window_size:].max()
@BaseLogBuffer.register_statistics
def min(self, window_size: Optional[int] = None) -> np.ndarray:
"""Return the minimum value of the latest ``window_size`` values in log
histories. If ``window_size is None``, return the global minimum value
of history logs.
Args:
window_size (int, optional): Size of statistics window.
Returns:
np.ndarray: The minimum value within the window.
"""
if window_size is not None:
assert isinstance(window_size, int), \
'The type of window size should be int, but got ' \
f'{type(window_size)}'
else:
window_size = len(self._log_history)
return self._log_history[-window_size:].min()
@BaseLogBuffer.register_statistics
def current(self) -> np.ndarray:
"""Return the recently updated values in log histories.
Returns:
np.ndarray: Recently updated values in log histories.
"""
if len(self._log_history) == 0:
raise ValueError('LogBuffer._log_history is an empty array! '
'please call update first')
return self._log_history[-1]
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import sys
from logging import Logger, LogRecord
from typing import Optional, Union
import torch.distributed as dist
from termcolor import colored
from .base_global_accsessible import BaseGlobalAccessible
class MMFormatter(logging.Formatter):
"""Colorful format for MMLogger. If the log level is error, the logger will
additionally output the location of the code.
Args:
color (bool): Whether to use colorful format. filehandler is not
allowed to use color format, otherwise it will be garbled.
"""
_color_mapping: dict = dict(
ERROR='red', WARNING='yellow', INFO='white', DEBUG='green')
def __init__(self, color: bool = True, **kwargs):
super().__init__(**kwargs)
# Get prefix format according to color.
error_prefix = self._get_prefix('ERROR', color)
warn_prefix = self._get_prefix('WARNING', color)
info_prefix = self._get_prefix('INFO', color)
debug_prefix = self._get_prefix('DEBUG', color)
# Config output format.
self.err_format = f'%(asctime)s - %(name)s - {error_prefix} - ' \
f'%(pathname)s - %(funcName)s - %(lineno)d - ' \
'%(message)s'
self.warn_format = f'%(asctime)s - %(name)s - {warn_prefix} - %(' \
'message)s'
self.info_format = f'%(asctime)s - %(name)s - {info_prefix} - %(' \
'message)s'
self.debug_format = f'%(asctime)s - %(name)s - {debug_prefix} - %(' \
'message)s'
def _get_prefix(self, level: str, color: bool) -> str:
"""Get the prefix of the target log level.
Args:
level (str): log level.
color (bool): Whether to get colorful prefix.
Returns:
str: The plain or colorful prefix.
"""
if color:
prefix = colored(
level,
self._color_mapping[level],
attrs=['blink', 'underline'])
else:
prefix = level
return prefix
def format(self, record: LogRecord) -> str:
"""Override the `logging.Formatter.format`` method `. Output the
message according to the specified log level.
Args:
record (LogRecord): A LogRecord instance represents an event being
logged.
Returns:
str: Formatted result.
"""
if record.levelno == logging.ERROR:
self._style._fmt = self.err_format
elif record.levelno == logging.WARNING:
self._style._fmt = self.warn_format
elif record.levelno == logging.INFO:
self._style._fmt = self.info_format
elif record.levelno == logging.DEBUG:
self._style._fmt = self.debug_format
result = logging.Formatter.format(self, record)
return result
class MMLogger(Logger, BaseGlobalAccessible):
"""The Logger manager which can create formatted logger and get specified
logger globally. MMLogger is created and accessed in the same way as
BaseGlobalAccessible.
Args:
name (str): Logger name. Defaults to ''.
log_file (str, optional): The log filename. If specified, a
``FileHandler`` will be added to the logger. Defaults to None.
log_level: The log level of the handler. Defaults to 'NOTSET'.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
"""
def __init__(self,
name: str = '',
log_file: Optional[str] = None,
log_level: str = 'NOTSET',
file_mode: str = 'w'):
Logger.__init__(self, name)
BaseGlobalAccessible.__init__(self, name)
# Get rank in DDP mode.
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
# Config stream_handler. If `rank != 0`. stream_handler can only
# export ERROR logs.
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(MMFormatter(color=True))
stream_handler.setLevel(log_level) if rank == 0 else \
stream_handler.setLevel(logging.ERROR)
self.handlers.append(stream_handler)
if log_file is not None:
if rank != 0:
# rename `log_file` with rank prefix.
path_split = log_file.split(os.sep)
path_split[-1] = f'rank{rank}_{path_split[-1]}'
log_file = os.sep.join(path_split)
# Here, the default behaviour of the official logger is 'a'. Thus,
# we provide an interface to change the file mode to the default
# behaviour. `FileHandler` is not supported to have colors,
# otherwise it will appear garbled.
file_handler = logging.FileHandler(log_file, file_mode)
file_handler.setFormatter(MMFormatter(color=False))
file_handler.setLevel(log_level)
self.handlers.append(file_handler)
def print_log(msg,
logger: Optional[Union[Logger, str]] = None,
level=logging.INFO):
"""Print a log message.
Args:
msg (str): The message to be logged.
logger (Logger or str, optional): The logger to be used.
Some special loggers are:
- "silent": no message will be printed.
- "current": Log message via the latest created logger.
- other str: the logger obtained with `MMLogger.get_instance`.
- None: The `print()` method will be used to print log messages.
level (int): Logging level. Only available when `logger` is a Logger
object or "root".
"""
if logger is None:
print(msg)
elif isinstance(logger, logging.Logger):
logger.log(level, msg)
elif logger == 'silent':
pass
elif logger == 'current':
logger_instance = MMLogger.get_instance(current=True)
logger_instance.log(level, msg)
elif isinstance(logger, str):
try:
_logger = MMLogger.get_instance(logger)
_logger.log(level, msg)
except AssertionError:
raise ValueError(f'MMLogger: {logger} has not been created!')
else:
raise TypeError(
'`logger` should be either a logging.Logger object, str, '
f'"silent", "current" or None, but got {type(logger)}')
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import Any, Union
from .base_global_accsessible import BaseGlobalAccessible
from .log_buffer import LogBuffer
class MessageHub(BaseGlobalAccessible):
"""Message hub for component interaction. MessageHub is created and
accessed in the same way as BaseGlobalAccessible.
``MessageHub`` will record log information and runtime information. The
log information refers to the learning rate, loss, etc. of the model
when training a model, which will be stored as ``LogBuffer``. The runtime
information refers to the iter times, meta information of runner etc.,
which will be overwritten by next update.
Args:
name (str): Name of message hub, for global access. Defaults to ''.
"""
def __init__(self, name: str = ''):
self._log_buffers: OrderedDict = OrderedDict()
self._runtime_info: OrderedDict = OrderedDict()
super().__init__(name)
def update_log(self, key: str, value: Union[int, float], count: int = 1) \
-> None:
"""Update log buffer.
Args:
key (str): Key of ``LogBuffer``.
value (int or float): Value of log.
count (int): Accumulation times of log, defaults to 1. `count`
will be used in smooth statistics.
"""
if key in self._log_buffers:
self._log_buffers[key].update(value, count)
else:
self._log_buffers[key] = LogBuffer([value], [count])
def update_info(self, key: str, value: Any) -> None:
"""Update runtime information.
Args:
key (str): Key of runtime information.
value (Any): Value of runtime information.
"""
self._runtime_info[key] = value
@property
def log_buffers(self) -> OrderedDict:
"""Get all ``LogBuffer`` instances.
Note:
Considering the large memory footprint of ``log_buffers`` in the
post-training, ``MessageHub.log_buffers`` will not return the
result of ``copy.deepcopy``.
Returns:
OrderedDict: All ``LogBuffer`` instances.
"""
return self._log_buffers
@property
def runtime_info(self) -> OrderedDict:
"""Get all runtime information.
Returns:
OrderedDict: A copy of all runtime information.
"""
return copy.deepcopy(self._runtime_info)
def get_log(self, key: str) -> LogBuffer:
"""Get ``LogBuffer`` instance by key.
Note:
Considering the large memory footprint of ``log_buffers`` in the
post-training, ``MessageHub.get_log`` will not return the
result of ``copy.deepcopy``.
Args:
key (str): Key of ``LogBuffer``.
Returns:
LogBuffer: Corresponding ``LogBuffer`` instance if the key exists.
"""
if key not in self.log_buffers:
raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
f'instance name is: {MessageHub.instance_name}')
return self._log_buffers[key]
def get_info(self, key: str) -> Any:
"""Get runtime information by key.
Args:
key (str): Key of runtime information.
Returns:
Any: A copy of corresponding runtime information if the key exists.
"""
if key not in self.runtime_info:
raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
f'instance name is: {MessageHub.instance_name}')
return copy.deepcopy(self._runtime_info[key])
...@@ -2,4 +2,5 @@ addict ...@@ -2,4 +2,5 @@ addict
numpy numpy
pyyaml pyyaml
regex;sys_platform=='win32' regex;sys_platform=='win32'
termcolor
yapf yapf
...@@ -23,26 +23,43 @@ class TestGlobalMeta: ...@@ -23,26 +23,43 @@ class TestGlobalMeta:
# error. # error.
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
class SubClassNoName(metaclass=MetaGlobalAccessible): class SubClassNoName1(metaclass=MetaGlobalAccessible):
def __init__(self, *args, **kwargs): def __init__(self, a, *args, **kwargs):
pass pass
# Subclass's constructor contains arguments without default value will # The constructor of subclasses must have default values for all
# raise an error. # arguments except name. Since `MetaGlobalAccessible` cannot tell which
# parameter does not have ha default value, we should test invalid
# subclasses separately.
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
class SubClassNoDefault(metaclass=MetaGlobalAccessible): class SubClassNoDefault1(metaclass=MetaGlobalAccessible):
def __init__(self, a, name='', *args, **kwargs): def __init__(self, a, name='', *args, **kwargs):
pass pass
class GlobalAccessible(metaclass=MetaGlobalAccessible): with pytest.raises(AssertionError):
class SubClassNoDefault2(metaclass=MetaGlobalAccessible):
def __init__(self, a, b, name='', *args, **kwargs):
pass
# Valid subclass.
class GlobalAccessible1(metaclass=MetaGlobalAccessible):
def __init__(self, name):
self.name = name
# Allow name not to be the first arguments.
class GlobalAccessible2(metaclass=MetaGlobalAccessible):
def __init__(self, name=''): def __init__(self, a=1, name=''):
self.name = name self.name = name
assert GlobalAccessible.root.name == 'root' assert GlobalAccessible1.root.name == 'root'
class TestBaseGlobalAccessible: class TestBaseGlobalAccessible:
......
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import re
import sys
from unittest.mock import patch
import pytest
from mmengine import MMLogger, print_log
class TestLogger:
regex_time = r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}'
@patch('torch.distributed.get_rank', lambda: 0)
@patch('torch.distributed.is_initialized', lambda: True)
@patch('torch.distributed.is_available', lambda: True)
def test_init_rank0(self, tmp_path):
logger = MMLogger.create_instance('rank0.pkg1', log_level='INFO')
assert logger.name == 'rank0.pkg1'
assert logger.instance_name == 'rank0.pkg1'
# Logger get from `MMLogger.get_instance` does not inherit from
# `logging.root`
assert logger.parent is None
assert len(logger.handlers) == 1
assert isinstance(logger.handlers[0], logging.StreamHandler)
assert logger.level == logging.NOTSET
assert logger.handlers[0].level == logging.INFO
# If `rank=0`, the `log_level` of stream_handler and file_handler
# depends on the given arguments.
tmp_file = tmp_path / 'tmp_file.log'
logger = MMLogger.create_instance(
'rank0.pkg2', log_level='INFO', log_file=str(tmp_file))
assert isinstance(logger, logging.Logger)
assert len(logger.handlers) == 2
assert isinstance(logger.handlers[0], logging.StreamHandler)
assert isinstance(logger.handlers[1], logging.FileHandler)
logger_pkg3 = MMLogger.get_instance('rank0.pkg2')
assert id(logger_pkg3) == id(logger)
logging.shutdown()
@patch('torch.distributed.get_rank', lambda: 1)
@patch('torch.distributed.is_initialized', lambda: True)
@patch('torch.distributed.is_available', lambda: True)
def test_init_rank1(self, tmp_path):
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
tmp_file = tmp_path / 'tmp_file.log'
log_path = tmp_path / 'rank1_tmp_file.log'
logger = MMLogger.create_instance(
'rank1.pkg2', log_level='INFO', log_file=str(tmp_file))
assert len(logger.handlers) == 2
assert logger.handlers[0].level == logging.ERROR
assert logger.handlers[1].level == logging.INFO
assert os.path.exists(log_path)
logging.shutdown()
@pytest.mark.parametrize('log_level',
[logging.WARNING, logging.INFO, logging.DEBUG])
def test_handler(self, capsys, tmp_path, log_level):
# test stream handler can output correct format logs
logger_name = f'test_stream_{str(log_level)}'
logger = MMLogger.create_instance(logger_name, log_level=log_level)
logger.log(level=log_level, msg='welcome')
out, _ = capsys.readouterr()
# Skip match colored INFO
loglevl_name = logging._levelToName[log_level]
match = re.fullmatch(
self.regex_time + f' - {logger_name} - '
f'(.*){loglevl_name}(.*) - welcome\n', out)
assert match is not None
# test file_handler output plain text without color.
tmp_file = tmp_path / 'tmp_file.log'
logger_name = f'test_file_{log_level}'
logger = MMLogger.create_instance(
logger_name, log_level=log_level, log_file=tmp_file)
logger.log(level=log_level, msg='welcome')
with open(tmp_file, 'r') as f:
log_text = f.read()
match = re.fullmatch(
self.regex_time + f' - {logger_name} - {loglevl_name} - '
f'welcome\n', log_text)
assert match is not None
logging.shutdown()
def test_erro_format(self, capsys):
# test error level log can output file path, function name and
# line number
logger = MMLogger.create_instance('test_error', log_level='INFO')
logger.error('welcome')
lineno = sys._getframe().f_lineno - 1
file_path = __file__
function_name = sys._getframe().f_code.co_name
pattern = self.regex_time + r' - test_error - (.*)ERROR(.*) - '\
f'{file_path} - {function_name} - ' \
f'{lineno} - welcome\n'
out, _ = capsys.readouterr()
match = re.fullmatch(pattern, out)
assert match is not None
def test_print_log(self, capsys, tmp_path):
# caplog cannot record MMLogger's logs.
# Test simple print.
print_log('welcome', logger=None)
out, _ = capsys.readouterr()
assert out == 'welcome\n'
# Test silent logger and skip print.
print_log('welcome', logger='silent')
out, _ = capsys.readouterr()
assert out == ''
logger = MMLogger.create_instance('test_print_log')
# Test using specified logger
print_log('welcome', logger=logger)
out, _ = capsys.readouterr()
match = re.fullmatch(
self.regex_time + ' - test_print_log - (.*)INFO(.*) - '
'welcome\n', out)
assert match is not None
# Test access logger by name.
print_log('welcome', logger='test_print_log')
out, _ = capsys.readouterr()
match = re.fullmatch(
self.regex_time + ' - test_print_log - (.*)INFO(.*) - '
'welcome\n', out)
assert match is not None
# Test access the latest created logger.
print_log('welcome', logger='current')
out, _ = capsys.readouterr()
match = re.fullmatch(
self.regex_time + ' - test_print_log - (.*)INFO(.*) - '
'welcome\n', out)
assert match is not None
# Test invalid logger type.
with pytest.raises(TypeError):
print_log('welcome', logger=dict)
with pytest.raises(ValueError):
print_log('welcome', logger='unknown')
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmengine import LogBuffer
class TestLoggerBuffer:
def test_init(self):
# `BaseLogBuffer` is an abstract class, using `CurrentLogBuffer` to
# test `update` method
log_buffer = LogBuffer()
assert log_buffer.max_length == 1000000
log_history, counts = log_buffer.data
assert len(log_history) == 0
assert len(counts) == 0
# test the length of array exceed `max_length`
logs = np.random.randint(1, 10, log_buffer.max_length + 1)
counts = np.random.randint(1, 10, log_buffer.max_length + 1)
log_buffer = LogBuffer(logs, counts)
log_history, count_history = log_buffer.data
assert len(log_history) == log_buffer.max_length
assert len(count_history) == log_buffer.max_length
assert logs[1] == log_history[0]
assert counts[1] == count_history[0]
# The different lengths of `log_history` and `count_history` will
# raise error
with pytest.raises(AssertionError):
LogBuffer([1, 2], [1])
@pytest.mark.parametrize('array_method',
[torch.tensor, np.array, lambda x: x])
def test_update(self, array_method):
# `BaseLogBuffer` is an abstract class, using `CurrentLogBuffer` to
# test `update` method
log_buffer = LogBuffer()
log_history = array_method([1, 2, 3, 4, 5])
count_history = array_method([5, 5, 5, 5, 5])
for i in range(len(log_history)):
log_buffer.update(float(log_history[i]), float(count_history[i]))
recorded_history, recorded_count = log_buffer.data
for a, b in zip(log_history, recorded_history):
assert float(a) == float(b)
for a, b in zip(count_history, recorded_count):
assert float(a) == float(b)
# test the length of `array` exceed `max_length`
max_array = array_method([[-1] + [1] * (log_buffer.max_length - 1)])
max_count = array_method([[-1] + [1] * (log_buffer.max_length - 1)])
log_buffer = LogBuffer(max_array, max_count)
log_buffer.update(1)
log_history, count_history = log_buffer.data
assert log_history[0] == 1
assert count_history[0] == 1
assert len(log_history) == log_buffer.max_length
assert len(count_history) == log_buffer.max_length
# Update an iterable object will raise a type error, `log_val` and
# `count` should be single value
with pytest.raises(TypeError):
log_buffer.update(array_method([1, 2]))
@pytest.mark.parametrize('statistics_method, log_buffer_type',
[(np.min, 'min'), (np.max, 'max')])
def test_max_min(self, statistics_method, log_buffer_type):
log_history = np.random.randint(1, 5, 20)
count_history = np.ones(20)
log_buffer = LogBuffer(log_history, count_history)
assert statistics_method(log_history[-10:]) == \
getattr(log_buffer, log_buffer_type)(10)
assert statistics_method(log_history) == \
getattr(log_buffer, log_buffer_type)()
def test_mean(self):
log_history = np.random.randint(1, 5, 20)
count_history = np.ones(20)
log_buffer = LogBuffer(log_history, count_history)
assert np.sum(log_history[-10:]) / \
np.sum(count_history[-10:]) == \
log_buffer.mean(10)
assert np.sum(log_history) / \
np.sum(count_history) == \
log_buffer.mean()
def test_current(self):
log_history = np.random.randint(1, 5, 20)
count_history = np.ones(20)
log_buffer = LogBuffer(log_history, count_history)
assert log_history[-1] == log_buffer.current()
# test get empty array
log_buffer = LogBuffer()
with pytest.raises(ValueError):
log_buffer.current()
def test_statistics(self):
log_history = np.array([1, 2, 3, 4, 5])
count_history = np.array([1, 1, 1, 1, 1])
log_buffer = LogBuffer(log_history, count_history)
assert log_buffer.statistics('mean') == 3
assert log_buffer.statistics('min') == 1
assert log_buffer.statistics('max') == 5
assert log_buffer.statistics('current') == 5
# Access unknown method will raise an error.
with pytest.raises(KeyError):
log_buffer.statistics('unknown')
def test_register_statistics(self):
@LogBuffer.register_statistics
def custom_statistics(self):
return -1
log_buffer = LogBuffer()
assert log_buffer.statistics('custom_statistics') == -1
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
from mmengine import MessageHub
class TestMessageHub:
def test_init(self):
message_hub = MessageHub('name')
assert message_hub.instance_name == 'name'
assert len(message_hub.log_buffers) == 0
assert len(message_hub.log_buffers) == 0
def test_update_log(self):
message_hub = MessageHub.create_instance()
# test create target `LogBuffer` by name
message_hub.update_log('name', 1)
log_buffer = message_hub.log_buffers['name']
assert (log_buffer._log_history == np.array([1])).all()
# test update target `LogBuffer` by name
message_hub.update_log('name', 1)
assert (log_buffer._log_history == np.array([1, 1])).all()
# unmatched string will raise a key error
def test_update_info(self):
message_hub = MessageHub.create_instance()
# test runtime value can be overwritten.
message_hub.update_info('key', 2)
assert message_hub.runtime_info['key'] == 2
message_hub.update_info('key', 1)
assert message_hub.runtime_info['key'] == 1
def test_get_log_buffers(self):
message_hub = MessageHub.create_instance()
# Get undefined key will raise error
with pytest.raises(KeyError):
message_hub.get_log('unknown')
# test get log_buffer as wished
log_history = np.array([1, 2, 3, 4, 5])
count = np.array([1, 1, 1, 1, 1])
for i in range(len(log_history)):
message_hub.update_log('test_value', float(log_history[i]),
int(count[i]))
recorded_history, recorded_count = \
message_hub.get_log('test_value').data
assert (log_history == recorded_history).all()
assert (recorded_count == count).all()
def test_get_runtime(self):
message_hub = MessageHub.create_instance()
with pytest.raises(KeyError):
message_hub.get_info('unknown')
recorded_dict = dict(a=1, b=2)
message_hub.update_info('test_value', recorded_dict)
assert message_hub.get_info('test_value') == recorded_dict
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment