diff --git a/mmengine/__init__.py b/mmengine/__init__.py index 7bc5108bc4278e6edd729937e39d160aeb67ae27..093d93392c84b108a2a61de24140a51501878a5c 100644 --- a/mmengine/__init__.py +++ b/mmengine/__init__.py @@ -5,5 +5,6 @@ from .data import * from .dataset import * from .fileio import * from .hooks import * +from .logging import * from .registry import * from .utils import * diff --git a/mmengine/logging/__init__.py b/mmengine/logging/__init__.py index d9a8ce094836d7727acd2e7129e611314e4835d1..3744edccbef5def91fb5ab400dd971c6b7c56c4a 100644 --- a/mmengine/logging/__init__.py +++ b/mmengine/logging/__init__.py @@ -1,4 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base_global_accsessible import BaseGlobalAccessible, MetaGlobalAccessible +from .log_buffer import LogBuffer +from .logger import MMLogger, print_log +from .message_hub import MessageHub -__all__ = ['MetaGlobalAccessible', 'BaseGlobalAccessible'] +__all__ = [ + 'LogBuffer', 'MessageHub', 'MetaGlobalAccessible', 'BaseGlobalAccessible', + 'MMLogger', 'print_log' +] diff --git a/mmengine/logging/base_global_accsessible.py b/mmengine/logging/base_global_accsessible.py index 4f0a8472d87bdb923d5910b8884dc09f69895b44..147426488c51b84e4e8ced45e12995f7913c3b33 100644 --- a/mmengine/logging/base_global_accsessible.py +++ b/mmengine/logging/base_global_accsessible.py @@ -14,28 +14,42 @@ class MetaGlobalAccessible(type): Examples: >>> class SubClass1(metaclass=MetaGlobalAccessible): - >>> def __init__(self, args, **kwargs): + >>> def __init__(self, *args, **kwargs): >>> pass - AssertionError: The arguments of the - ``<class '__main__.subclass'>.__init__`` must contain name argument. + AssertionError: <class '__main__.SubClass1'>.__init__ must have the + name argument. >>> class SubClass2(metaclass=MetaGlobalAccessible): - >>> def __init__(self, a, name=None, *args, **kwargs): + >>> def __init__(self, a, name=None, **kwargs): >>> pass - AssertionError: The arguments of the - ``<class '__main__.subclass'>.__init__`` must have default values. + AssertionError: + In <class '__main__.SubClass2'>.__init__, Only the name argument is + allowed to have no default values. >>> class SubClass3(metaclass=MetaGlobalAccessible): - >>> def __init__(self, a, name=None, *args, **kwargs): + >>> def __init__(self, name, **kwargs): + >>> pass # Right format + >>> class SubClass4(metaclass=MetaGlobalAccessible): + >>> def __init__(self, a=1, name='', **kwargs): >>> pass # Right format """ def __init__(cls, *args): cls._instance_dict = OrderedDict() params = inspect.getfullargspec(cls) - # Make sure `cls('root')` can be implemented. - assert 'name' in params[0], \ - f'The arguments of the {cls}.__init__ must contain name argument' - assert len(params[3]) == len(params[0]) - 1, \ - f'The arguments of the {cls}.__init__ must have default values' + # `inspect.getfullargspec` returns a tuple includes `(args, varargs, + # varkw, defaults, kwonlyargs, kwonlydefaults, annotations)`. + # To make sure `cls(name='root')` can be implemented, the + # `args` and `defaults` should be checked. 
+ params_names = params[0] if params[0] else [] + default_params = params[3] if params[3] else [] + assert 'name' in params_names, f'{cls}.__init__ must have the name ' \ + 'argument' + if len(default_params) == len(params_names) - 2 and 'name' != \ + params[0][1]: + raise AssertionError(f'In {cls}.__init__, Only the name argument ' + 'is allowed to have no default values.') + if len(default_params) < len(params_names) - 2: + raise AssertionError('Besides name, the arguments of the ' + f'{cls}.__init__ must have default values') cls.root = cls(name='root') super().__init__(*args) @@ -52,19 +66,20 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): >>> def __init__(self, name=''): >>> super().__init__(name) >>> + >>> GlobalAccessible.create_instance('name') >>> instance_1 = GlobalAccessible.get_instance('name') >>> instance_2 = GlobalAccessible.get_instance('name') >>> assert id(instance_1) == id(instance_2) Args: - name (str): The name of the instance. Defaults to None. + name (str): Name of the instance. Defaults to ''. """ - def __init__(self, name: str = '', *args, **kwargs): + def __init__(self, name: str = '', **kwargs): self._name = name @classmethod - def create_instance(cls, name: str = None, *args, **kwargs) -> Any: + def create_instance(cls, name: str = '', **kwargs) -> Any: """Create subclass instance by name, and subclass cannot create instances with duplicated names. The created instance will be stored in ``cls._instance_dict``, and can be accessed by ``get_instance``. @@ -79,30 +94,29 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): root Args: - name (str, optional): The name of instance. Defaults to None. + name (str): Name of instance. Defaults to ''. Returns: - object: The subclass instance. + object: Subclass instance. """ instance_dict = cls._instance_dict # Create instance and fill the instance in the `instance_dict`. - if name is not None: + if name: assert name not in instance_dict, f'{cls} cannot be created by ' \ f'{name} twice.' - instance = cls(name, *args, **kwargs) + instance = cls(name=name, **kwargs) instance_dict[name] = instance return instance # Get default root instance. else: - if args or kwargs: + if kwargs: raise ValueError('If name is not specified, create_instance ' f'will return root {cls} and cannot accept ' - f'any arguments, but got args: {args}, ' - f'kwargs: {kwargs}') + f'any arguments, but got kwargs: {kwargs}') return cls.root @classmethod - def get_instance(cls, name: str = None, current: bool = False) -> Any: + def get_instance(cls, name: str = '', current: bool = False) -> Any: """Get subclass instance by name if the name exists. if name is not specified, this method will return latest created instance of root instance. @@ -114,7 +128,7 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): name1 >>> instance = GlobalAccessible.create_instance('name2') >>> instance = GlobalAccessible.get_instance(current=True) - >>> instance.instance_name # the latest created instance is name2 + >>> instance.instance_name name2 >>> instance = GlobalAccessible.get_instance() >>> instance.instance_name # get root instance @@ -124,19 +138,17 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): name: name3, please make sure you have created it Args: - name (str, optional): The name of instance. Defaults to None. - current(bool): Whether to return the latest created instance - or the root instance, if name is not spicified. Defaults to - None. - current (bool): Whether to return the latest created instance. 
- Defaults to False. + name (str): Name of instance. Defaults to ''. + current(bool): Whether to return the latest created instance or + the root instance, if name is not spicified. Defaults to False. + Returns: object: Corresponding name instance, the latest instance, or root instance. """ instance_dict = cls._instance_dict # Get the instance by name. - if name is not None: + if name: assert name in instance_dict, \ f'Cannot get {cls} by name: {name}, please make sure you ' \ 'have created it' @@ -156,6 +168,6 @@ class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): """Get the name of instance. Returns: - str: The name of instance. + str: Name of instance. """ return self._name diff --git a/mmengine/logging/log_buffer.py b/mmengine/logging/log_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..c55236480dacc0101c63e96d7a8bfdf2284d7b56 --- /dev/null +++ b/mmengine/logging/log_buffer.py @@ -0,0 +1,181 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from functools import partial +from typing import Any, Callable, Optional, Sequence, Tuple, Union + +import numpy as np + + +class BaseLogBuffer: + """Unified storage format for different log types. + + Record the history of log for further statistics. The subclass inherited + from ``BaseLogBuffer`` will implement the specific statistical methods. + + Args: + log_history (Sequence): History logs. Defaults to []. + count_history (Sequence): Counts of history logs. Defaults to []. + max_length (int): The max length of history logs. Defaults to 1000000. + """ + _statistics_methods: dict = dict() + + def __init__(self, + log_history: Sequence = [], + count_history: Sequence = [], + max_length: int = 1000000): + + self.max_length = max_length + assert len(log_history) == len(count_history), \ + 'The lengths of log_history and count_histroy should be equal' + if len(log_history) > max_length: + warnings.warn(f'The length of history buffer({len(log_history)}) ' + f'exceeds the max_length({max_length}), the first ' + 'few elements will be ignored.') + self._log_history = np.array(log_history[-max_length:]) + self._count_history = np.array(count_history[-max_length:]) + else: + self._log_history = np.array(log_history) + self._count_history = np.array(count_history) + + def update(self, log_val: Union[int, float], count: int = 1) -> None: + """update the log history. If the length of the buffer exceeds + ``self._max_length``, the oldest element will be removed from the + buffer. + + Args: + log_val (int or float): The value of log. + count (int): The accumulation times of log, defaults to 1. + ``count`` will be used in smooth statistics. + """ + if (not isinstance(log_val, (int, float)) + or not isinstance(count, (int, float))): + raise TypeError(f'log_val must be int or float but got ' + f'{type(log_val)}, count must be int but got ' + f'{type(count)}') + self._log_history = np.append(self._log_history, log_val) + self._count_history = np.append(self._count_history, count) + if len(self._log_history) > self.max_length: + self._log_history = self._log_history[-self.max_length:] + self._count_history = self._count_history[-self.max_length:] + + @property + def data(self) -> Tuple[np.ndarray, np.ndarray]: + """Get the ``_log_history`` and ``_count_history``. + + Returns: + Tuple[np.ndarray, np.ndarray]: History logs and the counts of + the history logs. 
+ """ + return self._log_history, self._count_history + + @classmethod + def register_statistics(cls, method: Callable) -> Callable: + """Register custom statistics method to ``_statistics_methods``. + + Args: + method (Callable): Custom statistics method. + + Returns: + Callable: Original custom statistics method. + """ + method_name = method.__name__ + assert method_name not in cls._statistics_methods, \ + 'method_name cannot be registered twice!' + cls._statistics_methods[method_name] = method + return method + + def statistics(self, method_name: str, *arg, **kwargs) -> Any: + """Access statistics method by name. + + Args: + method_name (str): Name of method. + + Returns: + Any: Depends on corresponding method. + """ + if method_name not in self._statistics_methods: + raise KeyError(f'{method_name} has not been registered in ' + 'BaseLogBuffer._statistics_methods') + method = self._statistics_methods[method_name] + # Provide self arguments for registered functions. + method = partial(method, self) + return method(*arg, **kwargs) + + +class LogBuffer(BaseLogBuffer): + """``LogBuffer`` inherits from ``BaseLogBuffer`` and provides some basic + statistics methods, such as ``min``, ``max``, ``current`` and ``mean``.""" + + @BaseLogBuffer.register_statistics + def mean(self, window_size: Optional[int] = None) -> np.ndarray: + """Return the mean of the latest ``window_size`` values in log + histories. If ``window_size is None``, return the global mean of + history logs. + + Args: + window_size (int, optional): Size of statistics window. + + Returns: + np.ndarray: Mean value within the window. + """ + if window_size is not None: + assert isinstance(window_size, int), \ + 'The type of window size should be int, but got ' \ + f'{type(window_size)}' + else: + window_size = len(self._log_history) + logs_sum = self._log_history[-window_size:].sum() + counts_sum = self._count_history[-window_size:].sum() + return logs_sum / counts_sum + + @BaseLogBuffer.register_statistics + def max(self, window_size: Optional[int] = None) -> np.ndarray: + """Return the maximum value of the latest ``window_size`` values in log + histories. If ``window_size is None``, return the global maximum value + of history logs. + + Args: + window_size (int, optional): Size of statistics window. + + Returns: + np.ndarray: The maximum value within the window. + """ + if window_size is not None: + assert isinstance(window_size, int), \ + 'The type of window size should be int, but got ' \ + f'{type(window_size)}' + else: + window_size = len(self._log_history) + return self._log_history[-window_size:].max() + + @BaseLogBuffer.register_statistics + def min(self, window_size: Optional[int] = None) -> np.ndarray: + """Return the minimum value of the latest ``window_size`` values in log + histories. If ``window_size is None``, return the global minimum value + of history logs. + + Args: + window_size (int, optional): Size of statistics window. + + Returns: + np.ndarray: The minimum value within the window. + """ + if window_size is not None: + assert isinstance(window_size, int), \ + 'The type of window size should be int, but got ' \ + f'{type(window_size)}' + else: + window_size = len(self._log_history) + return self._log_history[-window_size:].min() + + @BaseLogBuffer.register_statistics + def current(self) -> np.ndarray: + """Return the recently updated values in log histories. + + Returns: + np.ndarray: Recently updated values in log histories. 
+ """ + if len(self._log_history) == 0: + raise ValueError('LogBuffer._log_history is an empty array! ' + 'please call update first') + return self._log_history[-1] diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2aa13e9f1441a5d27c8b0a63ae493144bb4533 --- /dev/null +++ b/mmengine/logging/logger.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +import os +import sys +from logging import Logger, LogRecord +from typing import Optional, Union + +import torch.distributed as dist +from termcolor import colored + +from .base_global_accsessible import BaseGlobalAccessible + + +class MMFormatter(logging.Formatter): + """Colorful format for MMLogger. If the log level is error, the logger will + additionally output the location of the code. + + Args: + color (bool): Whether to use colorful format. filehandler is not + allowed to use color format, otherwise it will be garbled. + """ + _color_mapping: dict = dict( + ERROR='red', WARNING='yellow', INFO='white', DEBUG='green') + + def __init__(self, color: bool = True, **kwargs): + + super().__init__(**kwargs) + # Get prefix format according to color. + error_prefix = self._get_prefix('ERROR', color) + warn_prefix = self._get_prefix('WARNING', color) + info_prefix = self._get_prefix('INFO', color) + debug_prefix = self._get_prefix('DEBUG', color) + # Config output format. + self.err_format = f'%(asctime)s - %(name)s - {error_prefix} - ' \ + f'%(pathname)s - %(funcName)s - %(lineno)d - ' \ + '%(message)s' + self.warn_format = f'%(asctime)s - %(name)s - {warn_prefix} - %(' \ + 'message)s' + self.info_format = f'%(asctime)s - %(name)s - {info_prefix} - %(' \ + 'message)s' + self.debug_format = f'%(asctime)s - %(name)s - {debug_prefix} - %(' \ + 'message)s' + + def _get_prefix(self, level: str, color: bool) -> str: + """Get the prefix of the target log level. + + Args: + level (str): log level. + color (bool): Whether to get colorful prefix. + + Returns: + str: The plain or colorful prefix. + """ + if color: + prefix = colored( + level, + self._color_mapping[level], + attrs=['blink', 'underline']) + else: + prefix = level + return prefix + + def format(self, record: LogRecord) -> str: + """Override the `logging.Formatter.format`` method `. Output the + message according to the specified log level. + + Args: + record (LogRecord): A LogRecord instance represents an event being + logged. + + Returns: + str: Formatted result. + """ + if record.levelno == logging.ERROR: + self._style._fmt = self.err_format + elif record.levelno == logging.WARNING: + self._style._fmt = self.warn_format + elif record.levelno == logging.INFO: + self._style._fmt = self.info_format + elif record.levelno == logging.DEBUG: + self._style._fmt = self.debug_format + + result = logging.Formatter.format(self, record) + return result + + +class MMLogger(Logger, BaseGlobalAccessible): + """The Logger manager which can create formatted logger and get specified + logger globally. MMLogger is created and accessed in the same way as + BaseGlobalAccessible. + + Args: + name (str): Logger name. Defaults to ''. + log_file (str, optional): The log filename. If specified, a + ``FileHandler`` will be added to the logger. Defaults to None. + log_level: The log level of the handler. Defaults to 'NOTSET'. + file_mode (str): The file mode used in opening log file. + Defaults to 'w'. 
+ """ + + def __init__(self, + name: str = '', + log_file: Optional[str] = None, + log_level: str = 'NOTSET', + file_mode: str = 'w'): + Logger.__init__(self, name) + BaseGlobalAccessible.__init__(self, name) + # Get rank in DDP mode. + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + else: + rank = 0 + + # Config stream_handler. If `rank != 0`. stream_handler can only + # export ERROR logs. + stream_handler = logging.StreamHandler(stream=sys.stdout) + stream_handler.setFormatter(MMFormatter(color=True)) + stream_handler.setLevel(log_level) if rank == 0 else \ + stream_handler.setLevel(logging.ERROR) + self.handlers.append(stream_handler) + + if log_file is not None: + if rank != 0: + # rename `log_file` with rank prefix. + path_split = log_file.split(os.sep) + path_split[-1] = f'rank{rank}_{path_split[-1]}' + log_file = os.sep.join(path_split) + # Here, the default behaviour of the official logger is 'a'. Thus, + # we provide an interface to change the file mode to the default + # behaviour. `FileHandler` is not supported to have colors, + # otherwise it will appear garbled. + file_handler = logging.FileHandler(log_file, file_mode) + file_handler.setFormatter(MMFormatter(color=False)) + file_handler.setLevel(log_level) + self.handlers.append(file_handler) + + +def print_log(msg, + logger: Optional[Union[Logger, str]] = None, + level=logging.INFO): + """Print a log message. + + Args: + msg (str): The message to be logged. + logger (Logger or str, optional): The logger to be used. + Some special loggers are: + - "silent": no message will be printed. + - "current": Log message via the latest created logger. + - other str: the logger obtained with `MMLogger.get_instance`. + - None: The `print()` method will be used to print log messages. + level (int): Logging level. Only available when `logger` is a Logger + object or "root". + """ + if logger is None: + print(msg) + elif isinstance(logger, logging.Logger): + logger.log(level, msg) + elif logger == 'silent': + pass + elif logger == 'current': + logger_instance = MMLogger.get_instance(current=True) + logger_instance.log(level, msg) + elif isinstance(logger, str): + try: + _logger = MMLogger.get_instance(logger) + _logger.log(level, msg) + except AssertionError: + raise ValueError(f'MMLogger: {logger} has not been created!') + else: + raise TypeError( + '`logger` should be either a logging.Logger object, str, ' + f'"silent", "current" or None, but got {type(logger)}') diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5087e889df11f69c7de5d0446acd21b594b48b --- /dev/null +++ b/mmengine/logging/message_hub.py @@ -0,0 +1,107 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from collections import OrderedDict +from typing import Any, Union + +from .base_global_accsessible import BaseGlobalAccessible +from .log_buffer import LogBuffer + + +class MessageHub(BaseGlobalAccessible): + """Message hub for component interaction. MessageHub is created and + accessed in the same way as BaseGlobalAccessible. + + ``MessageHub`` will record log information and runtime information. The + log information refers to the learning rate, loss, etc. of the model + when training a model, which will be stored as ``LogBuffer``. The runtime + information refers to the iter times, meta information of runner etc., + which will be overwritten by next update. + + Args: + name (str): Name of message hub, for global access. 
Defaults to ''. + """ + + def __init__(self, name: str = ''): + self._log_buffers: OrderedDict = OrderedDict() + self._runtime_info: OrderedDict = OrderedDict() + super().__init__(name) + + def update_log(self, key: str, value: Union[int, float], count: int = 1) \ + -> None: + """Update log buffer. + + Args: + key (str): Key of ``LogBuffer``. + value (int or float): Value of log. + count (int): Accumulation times of log, defaults to 1. `count` + will be used in smooth statistics. + """ + if key in self._log_buffers: + self._log_buffers[key].update(value, count) + else: + self._log_buffers[key] = LogBuffer([value], [count]) + + def update_info(self, key: str, value: Any) -> None: + """Update runtime information. + + Args: + key (str): Key of runtime information. + value (Any): Value of runtime information. + """ + self._runtime_info[key] = value + + @property + def log_buffers(self) -> OrderedDict: + """Get all ``LogBuffer`` instances. + + Note: + Considering the large memory footprint of ``log_buffers`` in the + post-training, ``MessageHub.log_buffers`` will not return the + result of ``copy.deepcopy``. + + Returns: + OrderedDict: All ``LogBuffer`` instances. + """ + return self._log_buffers + + @property + def runtime_info(self) -> OrderedDict: + """Get all runtime information. + + Returns: + OrderedDict: A copy of all runtime information. + """ + return copy.deepcopy(self._runtime_info) + + def get_log(self, key: str) -> LogBuffer: + """Get ``LogBuffer`` instance by key. + + Note: + Considering the large memory footprint of ``log_buffers`` in the + post-training, ``MessageHub.get_log`` will not return the + result of ``copy.deepcopy``. + + Args: + key (str): Key of ``LogBuffer``. + + Returns: + LogBuffer: Corresponding ``LogBuffer`` instance if the key exists. + """ + if key not in self.log_buffers: + raise KeyError(f'{key} is not found in Messagehub.log_buffers: ' + f'instance name is: {MessageHub.instance_name}') + return self._log_buffers[key] + + def get_info(self, key: str) -> Any: + """Get runtime information by key. + + Args: + key (str): Key of runtime information. + + Returns: + Any: A copy of corresponding runtime information if the key exists. + """ + if key not in self.runtime_info: + raise KeyError(f'{key} is not found in Messagehub.log_buffers: ' + f'instance name is: {MessageHub.instance_name}') + return copy.deepcopy(self._runtime_info[key]) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index f80e094dbb217ec06c7154e670825e7d1f27eeab..831e85089f4697ad1734c7a6aad4585338e21d03 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -2,4 +2,5 @@ addict numpy pyyaml regex;sys_platform=='win32' +termcolor yapf diff --git a/tests/test_logging/test_global_meta.py b/tests/test_logging/test_global_meta.py index 2a36b8e8d8c7caf749f98ed24281e4a06097afe9..3cb7e1252f29089d7b74633ac68f3c2ec45caedc 100644 --- a/tests/test_logging/test_global_meta.py +++ b/tests/test_logging/test_global_meta.py @@ -23,26 +23,43 @@ class TestGlobalMeta: # error. with pytest.raises(AssertionError): - class SubClassNoName(metaclass=MetaGlobalAccessible): + class SubClassNoName1(metaclass=MetaGlobalAccessible): - def __init__(self, *args, **kwargs): + def __init__(self, a, *args, **kwargs): pass - # Subclass's constructor contains arguments without default value will - # raise an error. + # The constructor of subclasses must have default values for all + # arguments except name. 
Since `MetaGlobalAccessible` cannot tell which + # parameter does not have ha default value, we should test invalid + # subclasses separately. with pytest.raises(AssertionError): - class SubClassNoDefault(metaclass=MetaGlobalAccessible): + class SubClassNoDefault1(metaclass=MetaGlobalAccessible): def __init__(self, a, name='', *args, **kwargs): pass - class GlobalAccessible(metaclass=MetaGlobalAccessible): + with pytest.raises(AssertionError): + + class SubClassNoDefault2(metaclass=MetaGlobalAccessible): + + def __init__(self, a, b, name='', *args, **kwargs): + pass + + # Valid subclass. + class GlobalAccessible1(metaclass=MetaGlobalAccessible): + + def __init__(self, name): + self.name = name + + # Allow name not to be the first arguments. + + class GlobalAccessible2(metaclass=MetaGlobalAccessible): - def __init__(self, name=''): + def __init__(self, a=1, name=''): self.name = name - assert GlobalAccessible.root.name == 'root' + assert GlobalAccessible1.root.name == 'root' class TestBaseGlobalAccessible: diff --git a/tests/test_logging/test_logger.py b/tests/test_logging/test_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..a81078dd3ff41ca8754cb9426ffa983f2d098407 --- /dev/null +++ b/tests/test_logging/test_logger.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +import os +import re +import sys +from unittest.mock import patch + +import pytest + +from mmengine import MMLogger, print_log + + +class TestLogger: + regex_time = r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}' + + @patch('torch.distributed.get_rank', lambda: 0) + @patch('torch.distributed.is_initialized', lambda: True) + @patch('torch.distributed.is_available', lambda: True) + def test_init_rank0(self, tmp_path): + logger = MMLogger.create_instance('rank0.pkg1', log_level='INFO') + assert logger.name == 'rank0.pkg1' + assert logger.instance_name == 'rank0.pkg1' + # Logger get from `MMLogger.get_instance` does not inherit from + # `logging.root` + assert logger.parent is None + assert len(logger.handlers) == 1 + assert isinstance(logger.handlers[0], logging.StreamHandler) + assert logger.level == logging.NOTSET + assert logger.handlers[0].level == logging.INFO + # If `rank=0`, the `log_level` of stream_handler and file_handler + # depends on the given arguments. + tmp_file = tmp_path / 'tmp_file.log' + logger = MMLogger.create_instance( + 'rank0.pkg2', log_level='INFO', log_file=str(tmp_file)) + assert isinstance(logger, logging.Logger) + assert len(logger.handlers) == 2 + assert isinstance(logger.handlers[0], logging.StreamHandler) + assert isinstance(logger.handlers[1], logging.FileHandler) + logger_pkg3 = MMLogger.get_instance('rank0.pkg2') + assert id(logger_pkg3) == id(logger) + logging.shutdown() + + @patch('torch.distributed.get_rank', lambda: 1) + @patch('torch.distributed.is_initialized', lambda: True) + @patch('torch.distributed.is_available', lambda: True) + def test_init_rank1(self, tmp_path): + # If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`. 
+ tmp_file = tmp_path / 'tmp_file.log' + log_path = tmp_path / 'rank1_tmp_file.log' + logger = MMLogger.create_instance( + 'rank1.pkg2', log_level='INFO', log_file=str(tmp_file)) + assert len(logger.handlers) == 2 + assert logger.handlers[0].level == logging.ERROR + assert logger.handlers[1].level == logging.INFO + assert os.path.exists(log_path) + logging.shutdown() + + @pytest.mark.parametrize('log_level', + [logging.WARNING, logging.INFO, logging.DEBUG]) + def test_handler(self, capsys, tmp_path, log_level): + # test stream handler can output correct format logs + logger_name = f'test_stream_{str(log_level)}' + logger = MMLogger.create_instance(logger_name, log_level=log_level) + logger.log(level=log_level, msg='welcome') + out, _ = capsys.readouterr() + # Skip match colored INFO + loglevl_name = logging._levelToName[log_level] + match = re.fullmatch( + self.regex_time + f' - {logger_name} - ' + f'(.*){loglevl_name}(.*) - welcome\n', out) + assert match is not None + + # test file_handler output plain text without color. + tmp_file = tmp_path / 'tmp_file.log' + logger_name = f'test_file_{log_level}' + logger = MMLogger.create_instance( + logger_name, log_level=log_level, log_file=tmp_file) + logger.log(level=log_level, msg='welcome') + with open(tmp_file, 'r') as f: + log_text = f.read() + match = re.fullmatch( + self.regex_time + f' - {logger_name} - {loglevl_name} - ' + f'welcome\n', log_text) + assert match is not None + logging.shutdown() + + def test_erro_format(self, capsys): + # test error level log can output file path, function name and + # line number + logger = MMLogger.create_instance('test_error', log_level='INFO') + logger.error('welcome') + lineno = sys._getframe().f_lineno - 1 + file_path = __file__ + function_name = sys._getframe().f_code.co_name + pattern = self.regex_time + r' - test_error - (.*)ERROR(.*) - '\ + f'{file_path} - {function_name} - ' \ + f'{lineno} - welcome\n' + out, _ = capsys.readouterr() + match = re.fullmatch(pattern, out) + assert match is not None + + def test_print_log(self, capsys, tmp_path): + # caplog cannot record MMLogger's logs. + # Test simple print. + print_log('welcome', logger=None) + out, _ = capsys.readouterr() + assert out == 'welcome\n' + # Test silent logger and skip print. + print_log('welcome', logger='silent') + out, _ = capsys.readouterr() + assert out == '' + logger = MMLogger.create_instance('test_print_log') + # Test using specified logger + print_log('welcome', logger=logger) + out, _ = capsys.readouterr() + match = re.fullmatch( + self.regex_time + ' - test_print_log - (.*)INFO(.*) - ' + 'welcome\n', out) + assert match is not None + # Test access logger by name. + print_log('welcome', logger='test_print_log') + out, _ = capsys.readouterr() + match = re.fullmatch( + self.regex_time + ' - test_print_log - (.*)INFO(.*) - ' + 'welcome\n', out) + assert match is not None + # Test access the latest created logger. + print_log('welcome', logger='current') + out, _ = capsys.readouterr() + match = re.fullmatch( + self.regex_time + ' - test_print_log - (.*)INFO(.*) - ' + 'welcome\n', out) + assert match is not None + # Test invalid logger type. 
+ with pytest.raises(TypeError): + print_log('welcome', logger=dict) + with pytest.raises(ValueError): + print_log('welcome', logger='unknown') diff --git a/tests/test_logging/test_loggger_buffer.py b/tests/test_logging/test_loggger_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea5ff821dd3f3c2b520b146921c5b776d15c165 --- /dev/null +++ b/tests/test_logging/test_loggger_buffer.py @@ -0,0 +1,118 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import pytest +import torch + +from mmengine import LogBuffer + + +class TestLoggerBuffer: + + def test_init(self): + # `BaseLogBuffer` is an abstract class, using `CurrentLogBuffer` to + # test `update` method + log_buffer = LogBuffer() + assert log_buffer.max_length == 1000000 + log_history, counts = log_buffer.data + assert len(log_history) == 0 + assert len(counts) == 0 + # test the length of array exceed `max_length` + logs = np.random.randint(1, 10, log_buffer.max_length + 1) + counts = np.random.randint(1, 10, log_buffer.max_length + 1) + log_buffer = LogBuffer(logs, counts) + log_history, count_history = log_buffer.data + + assert len(log_history) == log_buffer.max_length + assert len(count_history) == log_buffer.max_length + assert logs[1] == log_history[0] + assert counts[1] == count_history[0] + + # The different lengths of `log_history` and `count_history` will + # raise error + with pytest.raises(AssertionError): + LogBuffer([1, 2], [1]) + + @pytest.mark.parametrize('array_method', + [torch.tensor, np.array, lambda x: x]) + def test_update(self, array_method): + # `BaseLogBuffer` is an abstract class, using `CurrentLogBuffer` to + # test `update` method + log_buffer = LogBuffer() + log_history = array_method([1, 2, 3, 4, 5]) + count_history = array_method([5, 5, 5, 5, 5]) + for i in range(len(log_history)): + log_buffer.update(float(log_history[i]), float(count_history[i])) + + recorded_history, recorded_count = log_buffer.data + for a, b in zip(log_history, recorded_history): + assert float(a) == float(b) + for a, b in zip(count_history, recorded_count): + assert float(a) == float(b) + + # test the length of `array` exceed `max_length` + max_array = array_method([[-1] + [1] * (log_buffer.max_length - 1)]) + max_count = array_method([[-1] + [1] * (log_buffer.max_length - 1)]) + log_buffer = LogBuffer(max_array, max_count) + log_buffer.update(1) + log_history, count_history = log_buffer.data + assert log_history[0] == 1 + assert count_history[0] == 1 + assert len(log_history) == log_buffer.max_length + assert len(count_history) == log_buffer.max_length + # Update an iterable object will raise a type error, `log_val` and + # `count` should be single value + with pytest.raises(TypeError): + log_buffer.update(array_method([1, 2])) + + @pytest.mark.parametrize('statistics_method, log_buffer_type', + [(np.min, 'min'), (np.max, 'max')]) + def test_max_min(self, statistics_method, log_buffer_type): + log_history = np.random.randint(1, 5, 20) + count_history = np.ones(20) + log_buffer = LogBuffer(log_history, count_history) + assert statistics_method(log_history[-10:]) == \ + getattr(log_buffer, log_buffer_type)(10) + assert statistics_method(log_history) == \ + getattr(log_buffer, log_buffer_type)() + + def test_mean(self): + log_history = np.random.randint(1, 5, 20) + count_history = np.ones(20) + log_buffer = LogBuffer(log_history, count_history) + assert np.sum(log_history[-10:]) / \ + np.sum(count_history[-10:]) == \ + log_buffer.mean(10) + assert np.sum(log_history) / \ + 
np.sum(count_history) == \ + log_buffer.mean() + + def test_current(self): + log_history = np.random.randint(1, 5, 20) + count_history = np.ones(20) + log_buffer = LogBuffer(log_history, count_history) + assert log_history[-1] == log_buffer.current() + # test get empty array + log_buffer = LogBuffer() + with pytest.raises(ValueError): + log_buffer.current() + + def test_statistics(self): + log_history = np.array([1, 2, 3, 4, 5]) + count_history = np.array([1, 1, 1, 1, 1]) + log_buffer = LogBuffer(log_history, count_history) + assert log_buffer.statistics('mean') == 3 + assert log_buffer.statistics('min') == 1 + assert log_buffer.statistics('max') == 5 + assert log_buffer.statistics('current') == 5 + # Access unknown method will raise an error. + with pytest.raises(KeyError): + log_buffer.statistics('unknown') + + def test_register_statistics(self): + + @LogBuffer.register_statistics + def custom_statistics(self): + return -1 + + log_buffer = LogBuffer() + assert log_buffer.statistics('custom_statistics') == -1 diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4b6452c247c36407943f4a8085f6d7ce697668 --- /dev/null +++ b/tests/test_logging/test_message_hub.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import pytest + +from mmengine import MessageHub + + +class TestMessageHub: + + def test_init(self): + message_hub = MessageHub('name') + assert message_hub.instance_name == 'name' + assert len(message_hub.log_buffers) == 0 + assert len(message_hub.log_buffers) == 0 + + def test_update_log(self): + message_hub = MessageHub.create_instance() + # test create target `LogBuffer` by name + message_hub.update_log('name', 1) + log_buffer = message_hub.log_buffers['name'] + assert (log_buffer._log_history == np.array([1])).all() + # test update target `LogBuffer` by name + message_hub.update_log('name', 1) + assert (log_buffer._log_history == np.array([1, 1])).all() + # unmatched string will raise a key error + + def test_update_info(self): + message_hub = MessageHub.create_instance() + # test runtime value can be overwritten. + message_hub.update_info('key', 2) + assert message_hub.runtime_info['key'] == 2 + message_hub.update_info('key', 1) + assert message_hub.runtime_info['key'] == 1 + + def test_get_log_buffers(self): + message_hub = MessageHub.create_instance() + # Get undefined key will raise error + with pytest.raises(KeyError): + message_hub.get_log('unknown') + # test get log_buffer as wished + log_history = np.array([1, 2, 3, 4, 5]) + count = np.array([1, 1, 1, 1, 1]) + for i in range(len(log_history)): + message_hub.update_log('test_value', float(log_history[i]), + int(count[i])) + recorded_history, recorded_count = \ + message_hub.get_log('test_value').data + assert (log_history == recorded_history).all() + assert (recorded_count == count).all() + + def test_get_runtime(self): + message_hub = MessageHub.create_instance() + with pytest.raises(KeyError): + message_hub.get_info('unknown') + recorded_dict = dict(a=1, b=2) + message_hub.update_info('test_value', recorded_dict) + assert message_hub.get_info('test_value') == recorded_dict
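
A minimal usage sketch of the APIs this patch adds (MMLogger, print_log, MessageHub, LogBuffer). It is illustrative only and not part of the diff; it assumes the patch has been applied and that torch and termcolor (added to requirements/runtime.txt above) are installed. The names 'demo', 'demo_hub', and 'demo.log' are placeholders invented for the example; every function and method called below comes from the new modules.

# Usage sketch (not part of the patch): exercises the new logging APIs.
from mmengine import LogBuffer, MMLogger, MessageHub, print_log

# Create a named logger; the stream handler prints colored output, while the
# optional file handler writes plain text to `log_file`.
logger = MMLogger.create_instance('demo', log_level='INFO', log_file='demo.log')
logger.info('training started')

# print_log can route a message through a logger obtained by name, through the
# most recently created MMLogger, or fall back to the built-in print().
print_log('hello via named logger', logger='demo')
print_log('hello via latest logger', logger='current')
print_log('hello via plain print', logger=None)

# MessageHub stores scalar logs as LogBuffers and runtime info as plain values
# that are overwritten on every update.
hub = MessageHub.create_instance('demo_hub')
for it, loss in enumerate([0.9, 0.7, 0.6]):
    hub.update_log('loss', loss)   # accumulated into a LogBuffer
    hub.update_info('iter', it)    # overwritten each iteration
print(hub.get_log('loss').mean())             # global mean of recorded losses
print(hub.get_log('loss').statistics('max'))  # same methods via name lookup
print(hub.get_info('iter'))                   # -> 2

# LogBuffer can also be used standalone; `count` gives weighted means.
buf = LogBuffer()
buf.update(10.0, count=2)
buf.update(4.0, count=1)
print(buf.mean())     # (10 + 4) / (2 + 1)
print(buf.current())  # most recent value, 4.0

Note that instance names are registered globally per class, so calling create_instance twice with the same name in one process raises an AssertionError; reuse an existing instance with get_instance instead.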