From 104858414763ba6e5631d51787b1da79d931980e Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Sat, 26 Mar 2022 21:21:25 +0800 Subject: [PATCH] [Enhancement] Refine GlobalAccessble (#144) * rename global accessible and intergration get_sintance and create_instance * move ManagerMixin to utils * fix as docstring and seporate get_instance to get_instance and get_current_instance * fix lint * fix docstring, rename and move test_global_meta * fix manager's runtime error description fix manager's runtime error description * Add comments * Add comments --- docs/zh_cn/api.rst | 1 + mmengine/logging/__init__.py | 6 +- mmengine/logging/base_global_accsessible.py | 173 -------------------- mmengine/logging/logger.py | 37 +++-- mmengine/logging/message_hub.py | 8 +- mmengine/runner/runner.py | 6 +- mmengine/utils/__init__.py | 3 +- mmengine/utils/manager.py | 161 ++++++++++++++++++ mmengine/visualization/writer.py | 11 +- tests/test_logging/test_global_meta.py | 110 ------------- tests/test_logging/test_logger.py | 14 +- tests/test_logging/test_message_hub.py | 10 +- tests/test_utils/test_manager.py | 72 ++++++++ tests/test_visualizer/test_writer.py | 2 +- 14 files changed, 284 insertions(+), 330 deletions(-) delete mode 100644 mmengine/logging/base_global_accsessible.py create mode 100644 mmengine/utils/manager.py delete mode 100644 tests/test_logging/test_global_meta.py create mode 100644 tests/test_utils/test_manager.py diff --git a/docs/zh_cn/api.rst b/docs/zh_cn/api.rst index 999c6023..2a9992bb 100644 --- a/docs/zh_cn/api.rst +++ b/docs/zh_cn/api.rst @@ -36,3 +36,4 @@ Distributed Logging -------- .. automodule:: mmengine.logging + :members: diff --git a/mmengine/logging/__init__.py b/mmengine/logging/__init__.py index 3744edcc..13945401 100644 --- a/mmengine/logging/__init__.py +++ b/mmengine/logging/__init__.py @@ -1,10 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base_global_accsessible import BaseGlobalAccessible, MetaGlobalAccessible from .log_buffer import LogBuffer from .logger import MMLogger, print_log from .message_hub import MessageHub -__all__ = [ - 'LogBuffer', 'MessageHub', 'MetaGlobalAccessible', 'BaseGlobalAccessible', - 'MMLogger', 'print_log' -] +__all__ = ['LogBuffer', 'MessageHub', 'MMLogger', 'print_log'] diff --git a/mmengine/logging/base_global_accsessible.py b/mmengine/logging/base_global_accsessible.py deleted file mode 100644 index 14742648..00000000 --- a/mmengine/logging/base_global_accsessible.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -from collections import OrderedDict -from typing import Any, Optional - - -class MetaGlobalAccessible(type): - """The metaclass for global accessible class. - - The subclasses inheriting from ``MetaGlobalAccessible`` will manage their - own ``_instance_dict`` and root instances. The constructors of subclasses - must contain an optional ``name`` argument and all other arguments must - have default values. - - Examples: - >>> class SubClass1(metaclass=MetaGlobalAccessible): - >>> def __init__(self, *args, **kwargs): - >>> pass - AssertionError: <class '__main__.SubClass1'>.__init__ must have the - name argument. - >>> class SubClass2(metaclass=MetaGlobalAccessible): - >>> def __init__(self, a, name=None, **kwargs): - >>> pass - AssertionError: - In <class '__main__.SubClass2'>.__init__, Only the name argument is - allowed to have no default values. - >>> class SubClass3(metaclass=MetaGlobalAccessible): - >>> def __init__(self, name, **kwargs): - >>> pass # Right format - >>> class SubClass4(metaclass=MetaGlobalAccessible): - >>> def __init__(self, a=1, name='', **kwargs): - >>> pass # Right format - """ - - def __init__(cls, *args): - cls._instance_dict = OrderedDict() - params = inspect.getfullargspec(cls) - # `inspect.getfullargspec` returns a tuple includes `(args, varargs, - # varkw, defaults, kwonlyargs, kwonlydefaults, annotations)`. - # To make sure `cls(name='root')` can be implemented, the - # `args` and `defaults` should be checked. - params_names = params[0] if params[0] else [] - default_params = params[3] if params[3] else [] - assert 'name' in params_names, f'{cls}.__init__ must have the name ' \ - 'argument' - if len(default_params) == len(params_names) - 2 and 'name' != \ - params[0][1]: - raise AssertionError(f'In {cls}.__init__, Only the name argument ' - 'is allowed to have no default values.') - if len(default_params) < len(params_names) - 2: - raise AssertionError('Besides name, the arguments of the ' - f'{cls}.__init__ must have default values') - cls.root = cls(name='root') - super().__init__(*args) - - -class BaseGlobalAccessible(metaclass=MetaGlobalAccessible): - """``BaseGlobalAccessible`` is the base class for classes that have global - access requirements. - - The subclasses inheriting from ``BaseGlobalAccessible`` can get their - global instancees. - - Examples: - >>> class GlobalAccessible(BaseGlobalAccessible): - >>> def __init__(self, name=''): - >>> super().__init__(name) - >>> - >>> GlobalAccessible.create_instance('name') - >>> instance_1 = GlobalAccessible.get_instance('name') - >>> instance_2 = GlobalAccessible.get_instance('name') - >>> assert id(instance_1) == id(instance_2) - - Args: - name (str): Name of the instance. Defaults to ''. - """ - - def __init__(self, name: str = '', **kwargs): - self._name = name - - @classmethod - def create_instance(cls, name: str = '', **kwargs) -> Any: - """Create subclass instance by name, and subclass cannot create - instances with duplicated names. The created instance will be stored in - ``cls._instance_dict``, and can be accessed by ``get_instance``. - - Examples: - >>> instance_1 = GlobalAccessible.create_instance('name') - >>> instance_2 = GlobalAccessible.create_instance('name') - AssertionError: <class '__main__.GlobalAccessible'> cannot be - created by name twice. - >>> root_instance = GlobalAccessible.create_instance() - >>> root_instance.instance_name # get default root instance - root - - Args: - name (str): Name of instance. Defaults to ''. - - Returns: - object: Subclass instance. - """ - instance_dict = cls._instance_dict - # Create instance and fill the instance in the `instance_dict`. - if name: - assert name not in instance_dict, f'{cls} cannot be created by ' \ - f'{name} twice.' - instance = cls(name=name, **kwargs) - instance_dict[name] = instance - return instance - # Get default root instance. - else: - if kwargs: - raise ValueError('If name is not specified, create_instance ' - f'will return root {cls} and cannot accept ' - f'any arguments, but got kwargs: {kwargs}') - return cls.root - - @classmethod - def get_instance(cls, name: str = '', current: bool = False) -> Any: - """Get subclass instance by name if the name exists. if name is not - specified, this method will return latest created instance of root - instance. - - Examples - >>> instance = GlobalAccessible.create_instance('name1') - >>> instance = GlobalAccessible.get_instance('name1') - >>> instance.instance_name - name1 - >>> instance = GlobalAccessible.create_instance('name2') - >>> instance = GlobalAccessible.get_instance(current=True) - >>> instance.instance_name - name2 - >>> instance = GlobalAccessible.get_instance() - >>> instance.instance_name # get root instance - root - >>> instance = GlobalAccessible.get_instance('name3') # error - AssertionError: Cannot get <class '__main__.GlobalAccessible'> by - name: name3, please make sure you have created it - - Args: - name (str): Name of instance. Defaults to ''. - current(bool): Whether to return the latest created instance or - the root instance, if name is not spicified. Defaults to False. - - Returns: - object: Corresponding name instance, the latest instance, or root - instance. - """ - instance_dict = cls._instance_dict - # Get the instance by name. - if name: - assert name in instance_dict, \ - f'Cannot get {cls} by name: {name}, please make sure you ' \ - 'have created it' - return instance_dict[name] - # Get latest instantiated instance or root instance. - else: - if current: - current_name = next(iter(reversed(cls._instance_dict))) - assert current_name, f'Before calling {cls}.get_instance, ' \ - 'you should call create_instance.' - return cls._instance_dict[current_name] - else: - return cls.root - - @property - def instance_name(self) -> Optional[str]: - """Get the name of instance. - - Returns: - str: Name of instance. - """ - return self._name diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py index 3c2aa13e..25a14520 100644 --- a/mmengine/logging/logger.py +++ b/mmengine/logging/logger.py @@ -8,7 +8,7 @@ from typing import Optional, Union import torch.distributed as dist from termcolor import colored -from .base_global_accsessible import BaseGlobalAccessible +from mmengine.utils import ManagerMixin class MMFormatter(logging.Formatter): @@ -84,10 +84,10 @@ class MMFormatter(logging.Formatter): return result -class MMLogger(Logger, BaseGlobalAccessible): +class MMLogger(Logger, ManagerMixin): """The Logger manager which can create formatted logger and get specified logger globally. MMLogger is created and accessed in the same way as - BaseGlobalAccessible. + ManagerMixin. Args: name (str): Logger name. Defaults to ''. @@ -104,7 +104,7 @@ class MMLogger(Logger, BaseGlobalAccessible): log_level: str = 'NOTSET', file_mode: str = 'w'): Logger.__init__(self, name) - BaseGlobalAccessible.__init__(self, name) + ManagerMixin.__init__(self, name) # Get rank in DDP mode. if dist.is_available() and dist.is_initialized(): rank = dist.get_rank() @@ -137,19 +137,22 @@ class MMLogger(Logger, BaseGlobalAccessible): def print_log(msg, logger: Optional[Union[Logger, str]] = None, - level=logging.INFO): + level=logging.INFO) -> None: """Print a log message. Args: msg (str): The message to be logged. - logger (Logger or str, optional): The logger to be used. + logger (Logger or str, optional): If the type of logger is + ``logging.Logger``, we directly use logger to log messages. Some special loggers are: - - "silent": no message will be printed. - - "current": Log message via the latest created logger. - - other str: the logger obtained with `MMLogger.get_instance`. + - "silent": No message will be printed. + - "current": Use latest created logger to log message. + - other str: Instance name of logger. The corresponding logger + will log message if it has been created, otherwise ``print_log`` + will raise a `ValueError`. - None: The `print()` method will be used to print log messages. level (int): Logging level. Only available when `logger` is a Logger - object or "root". + object, "current", or a created logger instance name. """ if logger is None: print(msg) @@ -158,13 +161,17 @@ def print_log(msg, elif logger == 'silent': pass elif logger == 'current': - logger_instance = MMLogger.get_instance(current=True) + logger_instance = MMLogger.get_current_instance() logger_instance.log(level, msg) elif isinstance(logger, str): - try: - _logger = MMLogger.get_instance(logger) - _logger.log(level, msg) - except AssertionError: + # If the type of `logger` is `str`, but not with value of `current` or + # `silent`, we assume it indicates the name of the logger. If the + # corresponding logger has not been created, `print_log` will raise + # a `ValueError`. + if MMLogger.check_instance_created(logger): + logger_instance = MMLogger.get_instance(logger) + logger_instance.log(level, msg) + else: raise ValueError(f'MMLogger: {logger} has not been created!') else: raise TypeError( diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index a93d6038..75a2a4bc 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -6,14 +6,14 @@ from typing import Any, Union import numpy as np import torch +from mmengine.utils import ManagerMixin from mmengine.visualization.utils import check_type -from .base_global_accsessible import BaseGlobalAccessible from .log_buffer import LogBuffer -class MessageHub(BaseGlobalAccessible): +class MessageHub(ManagerMixin): """Message hub for component interaction. MessageHub is created and - accessed in the same way as BaseGlobalAccessible. + accessed in the same way as ManagerMixin. ``MessageHub`` will record log information and runtime information. The log information refers to the learning rate, loss, etc. of the model @@ -52,7 +52,7 @@ class MessageHub(BaseGlobalAccessible): log_dict (str): Used for batch updating :attr:`_log_buffers`. Examples: - >>> message_hub = MessageHub.create_instance() + >>> message_hub = MessageHub.get_instance('mmengine') >>> log_dict = dict(a=1, b=2, c=3) >>> message_hub.update_log_vars(log_dict) >>> # The default count of `a`, `b` and `c` is 1. diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 27960dc2..42f61ced 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -602,7 +602,7 @@ class Runner: 'logger should be MMLogger object, a dict or None, ' f'but got {logger}') - return MMLogger.create_instance(**logger) + return MMLogger.get_instance(**logger) def build_message_hub( self, @@ -632,7 +632,7 @@ class Runner: 'message_hub should be MessageHub object, a dict or None, ' f'but got {message_hub}') - return MessageHub.create_instance(**message_hub) + return MessageHub.get_instance(**message_hub) def build_writer( self, @@ -664,7 +664,7 @@ class Runner: 'writer should be ComposedWriter object, a dict or None, ' f'but got {writer}') - return ComposedWriter.create_instance(**writer) + return ComposedWriter.get_instance(**writer) def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module: """Build model. diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index 7bf023d0..5c0a56b9 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .hub import load_url +from .manager import ManagerMeta, ManagerMixin from .misc import (check_prerequisites, concat_list, deprecated_api_warning, find_latest_checkpoint, has_method, import_modules_from_strings, is_list_of, @@ -22,5 +23,5 @@ __all__ = [ 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', 'is_method_overridden', 'has_method', 'mmcv_full_available', 'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url', - 'find_latest_checkpoint' + 'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin' ] diff --git a/mmengine/utils/manager.py b/mmengine/utils/manager.py new file mode 100644 index 00000000..923ba0f0 --- /dev/null +++ b/mmengine/utils/manager.py @@ -0,0 +1,161 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +import threading +from collections import OrderedDict +from typing import Any + +_lock = threading.RLock() + + +def _accquire_lock() -> None: + """Acquire the module-level lock for serializing access to shared data. + + This should be released with _release_lock(). + """ + if _lock: + _lock.acquire() + + +def _release_lock() -> None: + """Release the module-level lock acquired by calling _accquire_lock().""" + if _lock: + _lock.release() + + +class ManagerMeta(type): + """The metaclass for global accessible class. + + The subclasses inheriting from ``ManagerMeta`` will manage their + own ``_instance_dict`` and root instances. The constructors of subclasses + must contain the ``name`` argument. + + Examples: + >>> class SubClass1(metaclass=ManagerMeta): + >>> def __init__(self, *args, **kwargs): + >>> pass + AssertionError: <class '__main__.SubClass1'>.__init__ must have the + name argument. + >>> class SubClass2(metaclass=ManagerMeta): + >>> def __init__(self, name): + >>> pass + >>> # valid format. + """ + + def __init__(cls, *args): + cls._instance_dict = OrderedDict() + params = inspect.getfullargspec(cls) + params_names = params[0] if params[0] else [] + assert 'name' in params_names, f'{cls} must have the `name` argument' + super().__init__(*args) + + +class ManagerMixin(metaclass=ManagerMeta): + """``ManagerMixin`` is the base class for classes that have global access + requirements. + + The subclasses inheriting from ``ManagerMixin`` can get their + global instances. + + Examples: + >>> class GlobalAccessible(ManagerMixin): + >>> def __init__(self, name=''): + >>> super().__init__(name) + >>> + >>> GlobalAccessible.get_instance('name') + >>> instance_1 = GlobalAccessible.get_instance('name') + >>> instance_2 = GlobalAccessible.get_instance('name') + >>> assert id(instance_1) == id(instance_2) + + Args: + name (str): Name of the instance. Defaults to ''. + """ + + def __init__(self, name: str = '', **kwargs): + self._instance_name = name + + @classmethod + def get_instance(cls, name: str, **kwargs) -> Any: + """Get subclass instance by name if the name exists. + + If corresponding name instance has not been created, ``get_instance`` + will create an instance, otherwise ``get_instance`` will return the + corresponding instance. + + Examples + >>> instance1 = GlobalAccessible.get_instance('name1') + >>> # Create name1 instance. + >>> instance.instance_name + name1 + >>> instance2 = GlobalAccessible.get_instance('name1') + >>> # Get name1 instance. + >>> assert id(instance1) == id(instance2) + + Args: + name (str): Name of instance. Defaults to ''. + + Returns: + object: Corresponding name instance, the latest instance, or root + instance. + """ + _accquire_lock() + assert isinstance(name, str), \ + f'type of name should be str, but got {type(cls)}' + instance_dict = cls._instance_dict + # Get the instance by name. + if name not in instance_dict: + instance = cls(name=name, **kwargs) + instance_dict[name] = instance + # Get latest instantiated instance or root instance. + _release_lock() + return instance_dict[name] + + @classmethod + def get_current_instance(cls): + """Get latest created instance. + + Before calling ``get_current_instance``, The subclass must have called + ``get_instance(xxx)`` at least once. + + Examples + >>> instance = GlobalAccessible.get_current_instance(current=True) + AssertionError: At least one of name and current needs to be set + >>> instance = GlobalAccessible.get_instance('name1') + >>> instance.instance_name + name1 + >>> instance = GlobalAccessible.get_current_instance(current=True) + >>> instance.instance_name + name1 + + Returns: + object: Latest created instance. + """ + _accquire_lock() + if not cls._instance_dict: + raise RuntimeError( + f'Before calling {cls.__name__}.get_instance(' + 'current=True), ' + 'you should call get_instance(name=xxx) at least once.') + name = next(iter(reversed(cls._instance_dict))) + _release_lock() + return cls._instance_dict[name] + + @classmethod + def check_instance_created(cls, name: str) -> bool: + """Check whether the name corresponding instance exists. + + Args: + name (str): Name of instance. + + Returns: + bool: Whether the name corresponding instance exists. + """ + return name in cls._instance_dict + + @property + def instance_name(self) -> str: + """Get the name of instance. + + Returns: + str: Name of instance. + """ + return self._instance_name diff --git a/mmengine/visualization/writer.py b/mmengine/visualization/writer.py index 31ab3e08..39e622fb 100644 --- a/mmengine/visualization/writer.py +++ b/mmengine/visualization/writer.py @@ -11,9 +11,8 @@ import torch from mmengine.data import BaseDataSample from mmengine.fileio import dump -from mmengine.logging import BaseGlobalAccessible from mmengine.registry import VISUALIZERS, WRITERS -from mmengine.utils import TORCH_VERSION +from mmengine.utils import TORCH_VERSION, ManagerMixin from .visualizer import Visualizer @@ -676,15 +675,15 @@ class TensorboardWriter(BaseWriter): self._tensorboard.close() -class ComposedWriter(BaseGlobalAccessible): +class ComposedWriter(ManagerMixin): """Wrapper class to compose multiple a subclass of :class:`BaseWriter` - instances. By inheriting BaseGlobalAccessible, it can be accessed anywhere - once instantiated. + instances. By inheriting ManagerMixin, it can be accessed anywhere once + instantiated. Examples: >>> from mmengine.visualization import ComposedWriter >>> import numpy as np - >>> composed_writer= ComposedWriter.create_instance( \ + >>> composed_writer= ComposedWriter.get_instance( \ 'composed_writer', writers=[dict(type='LocalWriter', \ visualizer=dict(type='DetVisualizer'), \ save_dir='temp_dir'), dict(type='WandbWriter')]) diff --git a/tests/test_logging/test_global_meta.py b/tests/test_logging/test_global_meta.py deleted file mode 100644 index 3cb7e125..00000000 --- a/tests/test_logging/test_global_meta.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest - -from mmengine.logging import BaseGlobalAccessible, MetaGlobalAccessible - - -class SubClassA(BaseGlobalAccessible): - - def __init__(self, name='', *args, **kwargs): - super().__init__(name, *args, **kwargs) - - -class SubClassB(BaseGlobalAccessible): - - def __init__(self, name='', *args, **kwargs): - super().__init__(name, *args, **kwargs) - - -class TestGlobalMeta: - - def test_init(self): - # Subclass's constructor does not contain name arguments will raise an - # error. - with pytest.raises(AssertionError): - - class SubClassNoName1(metaclass=MetaGlobalAccessible): - - def __init__(self, a, *args, **kwargs): - pass - - # The constructor of subclasses must have default values for all - # arguments except name. Since `MetaGlobalAccessible` cannot tell which - # parameter does not have ha default value, we should test invalid - # subclasses separately. - with pytest.raises(AssertionError): - - class SubClassNoDefault1(metaclass=MetaGlobalAccessible): - - def __init__(self, a, name='', *args, **kwargs): - pass - - with pytest.raises(AssertionError): - - class SubClassNoDefault2(metaclass=MetaGlobalAccessible): - - def __init__(self, a, b, name='', *args, **kwargs): - pass - - # Valid subclass. - class GlobalAccessible1(metaclass=MetaGlobalAccessible): - - def __init__(self, name): - self.name = name - - # Allow name not to be the first arguments. - - class GlobalAccessible2(metaclass=MetaGlobalAccessible): - - def __init__(self, a=1, name=''): - self.name = name - - assert GlobalAccessible1.root.name == 'root' - - -class TestBaseGlobalAccessible: - - def test_init(self): - # test get root instance. - assert BaseGlobalAccessible.root._name == 'root' - # test create instance by name. - base_cls = BaseGlobalAccessible('name') - assert base_cls._name == 'name' - - def test_create_instance(self): - # SubClass should manage their own `_instance_dict`. - SubClassA.create_instance('instance_a') - SubClassB.create_instance('instance_b') - assert SubClassB._instance_dict != SubClassA._instance_dict - - # test `message_hub` can create by name. - message_hub = SubClassA.create_instance('name1') - assert message_hub.instance_name == 'name1' - # test return root message_hub - message_hub = SubClassA.create_instance() - assert message_hub.instance_name == 'root' - # test default get root `message_hub`. - - def test_get_instance(self): - message_hub = SubClassA.get_instance() - assert message_hub.instance_name == 'root' - # test default get latest `message_hub`. - message_hub = SubClassA.create_instance('name2') - message_hub = SubClassA.get_instance(current=True) - assert message_hub.instance_name == 'name2' - message_hub.mark = -1 - # test get latest `message_hub` repeatedly. - message_hub = SubClassA.create_instance('name3') - assert message_hub.instance_name == 'name3' - message_hub = SubClassA.get_instance(current=True) - assert message_hub.instance_name == 'name3' - # test get root repeatedly. - message_hub = SubClassA.get_instance() - assert message_hub.instance_name == 'root' - # test get name1 repeatedly - message_hub = SubClassA.get_instance('name2') - assert message_hub.mark == -1 - # create_instance will raise error if `name` is not specified and - # given other arguments - with pytest.raises(ValueError): - SubClassA.create_instance(a=1) diff --git a/tests/test_logging/test_logger.py b/tests/test_logging/test_logger.py index a81078dd..8c41fe24 100644 --- a/tests/test_logging/test_logger.py +++ b/tests/test_logging/test_logger.py @@ -17,7 +17,7 @@ class TestLogger: @patch('torch.distributed.is_initialized', lambda: True) @patch('torch.distributed.is_available', lambda: True) def test_init_rank0(self, tmp_path): - logger = MMLogger.create_instance('rank0.pkg1', log_level='INFO') + logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO') assert logger.name == 'rank0.pkg1' assert logger.instance_name == 'rank0.pkg1' # Logger get from `MMLogger.get_instance` does not inherit from @@ -30,7 +30,7 @@ class TestLogger: # If `rank=0`, the `log_level` of stream_handler and file_handler # depends on the given arguments. tmp_file = tmp_path / 'tmp_file.log' - logger = MMLogger.create_instance( + logger = MMLogger.get_instance( 'rank0.pkg2', log_level='INFO', log_file=str(tmp_file)) assert isinstance(logger, logging.Logger) assert len(logger.handlers) == 2 @@ -47,7 +47,7 @@ class TestLogger: # If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`. tmp_file = tmp_path / 'tmp_file.log' log_path = tmp_path / 'rank1_tmp_file.log' - logger = MMLogger.create_instance( + logger = MMLogger.get_instance( 'rank1.pkg2', log_level='INFO', log_file=str(tmp_file)) assert len(logger.handlers) == 2 assert logger.handlers[0].level == logging.ERROR @@ -60,7 +60,7 @@ class TestLogger: def test_handler(self, capsys, tmp_path, log_level): # test stream handler can output correct format logs logger_name = f'test_stream_{str(log_level)}' - logger = MMLogger.create_instance(logger_name, log_level=log_level) + logger = MMLogger.get_instance(logger_name, log_level=log_level) logger.log(level=log_level, msg='welcome') out, _ = capsys.readouterr() # Skip match colored INFO @@ -73,7 +73,7 @@ class TestLogger: # test file_handler output plain text without color. tmp_file = tmp_path / 'tmp_file.log' logger_name = f'test_file_{log_level}' - logger = MMLogger.create_instance( + logger = MMLogger.get_instance( logger_name, log_level=log_level, log_file=tmp_file) logger.log(level=log_level, msg='welcome') with open(tmp_file, 'r') as f: @@ -87,7 +87,7 @@ class TestLogger: def test_erro_format(self, capsys): # test error level log can output file path, function name and # line number - logger = MMLogger.create_instance('test_error', log_level='INFO') + logger = MMLogger.get_instance('test_error', log_level='INFO') logger.error('welcome') lineno = sys._getframe().f_lineno - 1 file_path = __file__ @@ -109,7 +109,7 @@ class TestLogger: print_log('welcome', logger='silent') out, _ = capsys.readouterr() assert out == '' - logger = MMLogger.create_instance('test_print_log') + logger = MMLogger.get_instance('test_print_log') # Test using specified logger print_log('welcome', logger=logger) out, _ = capsys.readouterr() diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index c9ec25a1..4623c3b5 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -15,7 +15,7 @@ class TestMessageHub: assert len(message_hub.log_buffers) == 0 def test_update_log(self): - message_hub = MessageHub.create_instance() + message_hub = MessageHub.get_instance('mmengine') # test create target `LogBuffer` by name message_hub.update_log('name', 1) log_buffer = message_hub.log_buffers['name'] @@ -26,7 +26,7 @@ class TestMessageHub: # unmatched string will raise a key error def test_update_info(self): - message_hub = MessageHub.create_instance() + message_hub = MessageHub.get_instance('mmengine') # test runtime value can be overwritten. message_hub.update_info('key', 2) assert message_hub.runtime_info['key'] == 2 @@ -34,7 +34,7 @@ class TestMessageHub: assert message_hub.runtime_info['key'] == 1 def test_get_log_buffers(self): - message_hub = MessageHub.create_instance() + message_hub = MessageHub.get_instance('mmengine') # Get undefined key will raise error with pytest.raises(KeyError): message_hub.get_log('unknown') @@ -50,7 +50,7 @@ class TestMessageHub: assert (recorded_count == count).all() def test_get_runtime(self): - message_hub = MessageHub.create_instance() + message_hub = MessageHub.get_instance('mmengine') with pytest.raises(KeyError): message_hub.get_info('unknown') recorded_dict = dict(a=1, b=2) @@ -58,7 +58,7 @@ class TestMessageHub: assert message_hub.get_info('test_value') == recorded_dict def test_get_log_vars(self): - message_hub = MessageHub.create_instance() + message_hub = MessageHub.get_instance('mmengine') log_dict = dict( loss=1, loss_cls=torch.tensor(2), diff --git a/tests/test_utils/test_manager.py b/tests/test_utils/test_manager.py new file mode 100644 index 00000000..f51b397f --- /dev/null +++ b/tests/test_utils/test_manager.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest + +from mmengine.utils import ManagerMeta, ManagerMixin + + +class SubClassA(ManagerMixin): + + def __init__(self, name='', *args, **kwargs): + super().__init__(name, *args, **kwargs) + + +class SubClassB(ManagerMixin): + + def __init__(self, name='', *args, **kwargs): + super().__init__(name, *args, **kwargs) + + +class TestGlobalMeta: + + def test_init(self): + # Subclass's constructor does not contain name arguments will raise an + # error. + with pytest.raises(AssertionError): + + class SubClassNoName1(metaclass=ManagerMeta): + + def __init__(self, a, *args, **kwargs): + pass + + # Valid subclass. + class GlobalAccessible1(metaclass=ManagerMeta): + + def __init__(self, name): + self.name = name + + +class TestManagerMixin: + + def test_init(self): + # test create instance by name. + base_cls = ManagerMixin('name') + assert base_cls.instance_name == 'name' + + def test_get_instance(self): + # SubClass should manage their own `_instance_dict`. + with pytest.raises(RuntimeError): + SubClassA.get_current_instance() + SubClassA.get_instance('instance_a') + SubClassB.get_instance('instance_b') + assert SubClassB._instance_dict != SubClassA._instance_dict + + # Test `message_hub` can create by name. + message_hub = SubClassA.get_instance('name1') + assert message_hub.instance_name == 'name1' + # No arguments will raise an assertion error. + + SubClassA.get_instance('name2') + message_hub = SubClassA.get_current_instance() + message_hub.mark = -1 + assert message_hub.instance_name == 'name2' + # Test get latest `message_hub` repeatedly. + message_hub = SubClassA.get_instance('name3') + assert message_hub.instance_name == 'name3' + message_hub = SubClassA.get_current_instance() + assert message_hub.instance_name == 'name3' + # Test get name2 repeatedly. + message_hub = SubClassA.get_instance('name2') + assert message_hub.mark == -1 + # Non-string instance name will raise `AssertionError`. + with pytest.raises(AssertionError): + SubClassA.get_instance(name=1) diff --git a/tests/test_visualizer/test_writer.py b/tests/test_visualizer/test_writer.py index 5219a2a4..447a246d 100644 --- a/tests/test_visualizer/test_writer.py +++ b/tests/test_visualizer/test_writer.py @@ -332,7 +332,7 @@ class TestComposedWriter: assert len(composed_writer._writers) == 2 # test global - composed_writer = ComposedWriter.create_instance( + composed_writer = ComposedWriter.get_instance( 'composed_writer', writers=[ WandbWriter(), -- GitLab