From 4742544b2564cbd1e297dc036021180c0b94bfe6 Mon Sep 17 00:00:00 2001 From: RangiLyu <lyuchqi@gmail.com> Date: Thu, 5 May 2022 11:59:51 +0800 Subject: [PATCH] [Feature] Support collect all registered module info. (#193) * [Feature] Support collect all registered module info. * update * update * add unit tests * add to runner * resolve comments --- mmengine/registry/__init__.py | 3 +- mmengine/registry/registry.py | 4 ++ mmengine/registry/utils.py | 78 ++++++++++++++++++++++ mmengine/runner/runner.py | 9 ++- tests/test_registry/test_registry_utils.py | 62 +++++++++++++++++ 5 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 mmengine/registry/utils.py create mode 100644 tests/test_registry/test_registry_utils.py diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index 56c65b80..5d05ac24 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -5,11 +5,12 @@ from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS) +from .utils import count_registered_modules, traverse_registry_tree __all__ = [ 'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS', - 'DefaultScope' + 'DefaultScope', 'traverse_registry_tree', 'count_registered_modules' ] diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index bb064045..37734da4 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -256,6 +256,10 @@ class Registry: def children(self): return self._children + @property + def root(self): + return self._get_root_registry() + def _get_root_registry(self) -> 'Registry': """Return the root registry.""" root = self diff --git a/mmengine/registry/utils.py b/mmengine/registry/utils.py new file mode 100644 index 00000000..3ace5f5f --- /dev/null +++ b/mmengine/registry/utils.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import os.path as osp +from typing import Optional + +from mmengine.fileio import dump +from . import root +from .registry import Registry + + +def traverse_registry_tree(registry: Registry, verbose: bool = True) -> list: + """Traverse the whole registry tree from any given node, and collect + information of all registered modules in this registry tree. + + Args: + registry (Registry): a registry node in the registry tree. + verbose (bool): Whether to print log. Default: True + + Returns: + list: Statistic results of all modules in each node of the registry + tree. + """ + root_registry = registry.root + modules_info = [] + + def _dfs_registry(_registry): + if isinstance(_registry, Registry): + num_modules = len(_registry.module_dict) + scope = _registry.scope + registry_info = dict(num_modules=num_modules, scope=scope) + for name, registered_class in _registry.module_dict.items(): + folder = '/'.join(registered_class.__module__.split('.')[:-1]) + if folder in registry_info: + registry_info[folder].append(name) + else: + registry_info[folder] = [name] + if verbose: + print(f"Find {num_modules} modules in {scope}'s " + f"'{_registry.name}' registry ") + modules_info.append(registry_info) + else: + return + for _, child in _registry.children.items(): + _dfs_registry(child) + + _dfs_registry(root_registry) + return modules_info + + +def count_registered_modules(save_path: Optional[str] = None, + verbose: bool = True) -> dict: + """Scan all modules in MMEngine's root and child registries and dump to + json. + + Args: + save_path (str, optional): Path to save the json file. + verbose (bool): Whether to print log. Default: True + Returns: + dict: Statistic results of all registered modules. + """ + registries_info = {} + # traverse all registries in MMEngine + for item in dir(root): + if not item.startswith('__'): + registry = getattr(root, item) + if isinstance(registry, Registry): + registries_info[item] = traverse_registry_tree( + registry, verbose) + scan_data = dict( + scan_date=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + registries=registries_info) + if verbose: + print('Finish registry analysis, got: ', scan_data) + if save_path is not None: + json_path = osp.join(save_path, 'modules_statistic_results.json') + dump(scan_data, json_path, indent=2) + print(f'Result has been saved to {json_path}') + return scan_data diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 182eefa1..98cddde6 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -30,7 +30,8 @@ from mmengine.model import is_model_wrapper from mmengine.optim import _ParamScheduler, build_optimizer from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, - VISUALIZERS, DefaultScope) + VISUALIZERS, DefaultScope, + count_registered_modules) from mmengine.utils import (TORCH_VERSION, digit_version, find_latest_checkpoint, is_list_of, symlink) from mmengine.visualization import Visualizer @@ -328,6 +329,12 @@ class Runner: # Since `get_instance` could return any subclass of ManagerMixin. The # corresponding attribute needs a type hint. self.logger = self.build_logger(log_level=log_level) + + # collect information of all modules registered in the registries + registries_info = count_registered_modules( + self.work_dir if self.rank == 0 else None, verbose=False) + self.logger.debug(registries_info) + # Build `message_hub` for communication among components. # `message_hub` can store log scalars (loss, learning rate) and # runtime information (iter and epoch). Those components that do not diff --git a/tests/test_registry/test_registry_utils.py b/tests/test_registry/test_registry_utils.py new file mode 100644 index 00000000..ce903f0b --- /dev/null +++ b/tests/test_registry/test_registry_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from tempfile import TemporaryDirectory +from unittest import TestCase + +from mmengine.registry import (Registry, count_registered_modules, root, + traverse_registry_tree) + + +class TestUtils(TestCase): + + def test_traverse_registry_tree(self): + # Hierarchical Registry + # DOGS + # _______|_______ + # | | + # HOUNDS (hound) SAMOYEDS (samoyed) + # _______|_______ | + # | | | + # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS + # (little_hound) (mid_hound) (little_samoyed) + DOGS = Registry('dogs') + HOUNDS = Registry('dogs', parent=DOGS, scope='hound') + LITTLE_HOUNDS = Registry( # noqa + 'dogs', parent=HOUNDS, scope='little_hound') + MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound') + SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed') + LITTLE_SAMOYEDS = Registry( # noqa + 'dogs', parent=SAMOYEDS, scope='little_samoyed') + + @DOGS.register_module() + class GoldenRetriever: + pass + + # traversing the tree from the root + result = traverse_registry_tree(DOGS) + self.assertEqual(result[0]['num_modules'], 1) + self.assertEqual(len(result), 6) + + # traversing the tree from leaf node + result_leaf = traverse_registry_tree(MID_HOUNDS) + # result from any node should be the same + self.assertEqual(result, result_leaf) + + def test_count_all_registered_modules(self): + temp_dir = TemporaryDirectory() + results = count_registered_modules(temp_dir.name, verbose=True) + self.assertTrue( + osp.exists( + osp.join(temp_dir.name, 'modules_statistic_results.json'))) + registries_info = results['registries'] + for registry in registries_info: + self.assertTrue(hasattr(root, registry)) + self.assertEqual(registries_info[registry][0]['num_modules'], + len(getattr(root, registry).module_dict)) + temp_dir.cleanup() + + # test not saving results + count_registered_modules(save_path=None, verbose=False) + self.assertFalse( + osp.exists( + osp.join(temp_dir.name, 'modules_statistic_results.json'))) -- GitLab