Skip to content
Snippets Groups Projects
Unverified Commit 4742544b authored by RangiLyu's avatar RangiLyu Committed by GitHub
Browse files

[Feature] Support collect all registered module info. (#193)

* [Feature] Support collect all registered module info.

* update

* update

* add unit tests

* add to runner

* resolve comments
parent 2da59bb1
No related branches found
No related tags found
No related merge requests found
...@@ -5,11 +5,12 @@ from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS, ...@@ -5,11 +5,12 @@ from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS) TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
from .utils import count_registered_modules, traverse_registry_tree
__all__ = [ __all__ = [
'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS', 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS',
'DefaultScope' 'DefaultScope', 'traverse_registry_tree', 'count_registered_modules'
] ]
...@@ -256,6 +256,10 @@ class Registry: ...@@ -256,6 +256,10 @@ class Registry:
def children(self): def children(self):
return self._children return self._children
@property
def root(self):
return self._get_root_registry()
def _get_root_registry(self) -> 'Registry': def _get_root_registry(self) -> 'Registry':
"""Return the root registry.""" """Return the root registry."""
root = self root = self
......
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import os.path as osp
from typing import Optional
from mmengine.fileio import dump
from . import root
from .registry import Registry
def traverse_registry_tree(registry: Registry, verbose: bool = True) -> list:
"""Traverse the whole registry tree from any given node, and collect
information of all registered modules in this registry tree.
Args:
registry (Registry): a registry node in the registry tree.
verbose (bool): Whether to print log. Default: True
Returns:
list: Statistic results of all modules in each node of the registry
tree.
"""
root_registry = registry.root
modules_info = []
def _dfs_registry(_registry):
if isinstance(_registry, Registry):
num_modules = len(_registry.module_dict)
scope = _registry.scope
registry_info = dict(num_modules=num_modules, scope=scope)
for name, registered_class in _registry.module_dict.items():
folder = '/'.join(registered_class.__module__.split('.')[:-1])
if folder in registry_info:
registry_info[folder].append(name)
else:
registry_info[folder] = [name]
if verbose:
print(f"Find {num_modules} modules in {scope}'s "
f"'{_registry.name}' registry ")
modules_info.append(registry_info)
else:
return
for _, child in _registry.children.items():
_dfs_registry(child)
_dfs_registry(root_registry)
return modules_info
def count_registered_modules(save_path: Optional[str] = None,
verbose: bool = True) -> dict:
"""Scan all modules in MMEngine's root and child registries and dump to
json.
Args:
save_path (str, optional): Path to save the json file.
verbose (bool): Whether to print log. Default: True
Returns:
dict: Statistic results of all registered modules.
"""
registries_info = {}
# traverse all registries in MMEngine
for item in dir(root):
if not item.startswith('__'):
registry = getattr(root, item)
if isinstance(registry, Registry):
registries_info[item] = traverse_registry_tree(
registry, verbose)
scan_data = dict(
scan_date=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
registries=registries_info)
if verbose:
print('Finish registry analysis, got: ', scan_data)
if save_path is not None:
json_path = osp.join(save_path, 'modules_statistic_results.json')
dump(scan_data, json_path, indent=2)
print(f'Result has been saved to {json_path}')
return scan_data
...@@ -30,7 +30,8 @@ from mmengine.model import is_model_wrapper ...@@ -30,7 +30,8 @@ from mmengine.model import is_model_wrapper
from mmengine.optim import _ParamScheduler, build_optimizer from mmengine.optim import _ParamScheduler, build_optimizer
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
VISUALIZERS, DefaultScope) VISUALIZERS, DefaultScope,
count_registered_modules)
from mmengine.utils import (TORCH_VERSION, digit_version, from mmengine.utils import (TORCH_VERSION, digit_version,
find_latest_checkpoint, is_list_of, symlink) find_latest_checkpoint, is_list_of, symlink)
from mmengine.visualization import Visualizer from mmengine.visualization import Visualizer
...@@ -328,6 +329,12 @@ class Runner: ...@@ -328,6 +329,12 @@ class Runner:
# Since `get_instance` could return any subclass of ManagerMixin. The # Since `get_instance` could return any subclass of ManagerMixin. The
# corresponding attribute needs a type hint. # corresponding attribute needs a type hint.
self.logger = self.build_logger(log_level=log_level) self.logger = self.build_logger(log_level=log_level)
# collect information of all modules registered in the registries
registries_info = count_registered_modules(
self.work_dir if self.rank == 0 else None, verbose=False)
self.logger.debug(registries_info)
# Build `message_hub` for communication among components. # Build `message_hub` for communication among components.
# `message_hub` can store log scalars (loss, learning rate) and # `message_hub` can store log scalars (loss, learning rate) and
# runtime information (iter and epoch). Those components that do not # runtime information (iter and epoch). Those components that do not
......
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from tempfile import TemporaryDirectory
from unittest import TestCase
from mmengine.registry import (Registry, count_registered_modules, root,
traverse_registry_tree)
class TestUtils(TestCase):
def test_traverse_registry_tree(self):
# Hierarchical Registry
# DOGS
# _______|_______
# | |
# HOUNDS (hound) SAMOYEDS (samoyed)
# _______|_______ |
# | | |
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
# (little_hound) (mid_hound) (little_samoyed)
DOGS = Registry('dogs')
HOUNDS = Registry('dogs', parent=DOGS, scope='hound')
LITTLE_HOUNDS = Registry( # noqa
'dogs', parent=HOUNDS, scope='little_hound')
MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound')
SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed')
LITTLE_SAMOYEDS = Registry( # noqa
'dogs', parent=SAMOYEDS, scope='little_samoyed')
@DOGS.register_module()
class GoldenRetriever:
pass
# traversing the tree from the root
result = traverse_registry_tree(DOGS)
self.assertEqual(result[0]['num_modules'], 1)
self.assertEqual(len(result), 6)
# traversing the tree from leaf node
result_leaf = traverse_registry_tree(MID_HOUNDS)
# result from any node should be the same
self.assertEqual(result, result_leaf)
def test_count_all_registered_modules(self):
temp_dir = TemporaryDirectory()
results = count_registered_modules(temp_dir.name, verbose=True)
self.assertTrue(
osp.exists(
osp.join(temp_dir.name, 'modules_statistic_results.json')))
registries_info = results['registries']
for registry in registries_info:
self.assertTrue(hasattr(root, registry))
self.assertEqual(registries_info[registry][0]['num_modules'],
len(getattr(root, registry).module_dict))
temp_dir.cleanup()
# test not saving results
count_registered_modules(save_path=None, verbose=False)
self.assertFalse(
osp.exists(
osp.join(temp_dir.name, 'modules_statistic_results.json')))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment