From 4742544b2564cbd1e297dc036021180c0b94bfe6 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Thu, 5 May 2022 11:59:51 +0800
Subject: [PATCH] [Feature] Support collect all registered module info. (#193)

* [Feature] Support collect all registered module info.

* update

* update

* add unit tests

* add to runner

* resolve comments
---
 mmengine/registry/__init__.py              |  3 +-
 mmengine/registry/registry.py              |  4 ++
 mmengine/registry/utils.py                 | 78 ++++++++++++++++++++++
 mmengine/runner/runner.py                  |  9 ++-
 tests/test_registry/test_registry_utils.py | 62 +++++++++++++++++
 5 files changed, 154 insertions(+), 2 deletions(-)
 create mode 100644 mmengine/registry/utils.py
 create mode 100644 tests/test_registry/test_registry_utils.py

diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py
index 56c65b80..5d05ac24 100644
--- a/mmengine/registry/__init__.py
+++ b/mmengine/registry/__init__.py
@@ -5,11 +5,12 @@ from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
                    MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
                    PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
                    TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
+from .utils import count_registered_modules, traverse_registry_tree
 
 __all__ = [
     'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
     'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
     'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
     'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS',
-    'DefaultScope'
+    'DefaultScope', 'traverse_registry_tree', 'count_registered_modules'
 ]
diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py
index bb064045..37734da4 100644
--- a/mmengine/registry/registry.py
+++ b/mmengine/registry/registry.py
@@ -256,6 +256,10 @@ class Registry:
     def children(self):
         return self._children
 
+    @property
+    def root(self):
+        return self._get_root_registry()
+
     def _get_root_registry(self) -> 'Registry':
         """Return the root registry."""
         root = self
diff --git a/mmengine/registry/utils.py b/mmengine/registry/utils.py
new file mode 100644
index 00000000..3ace5f5f
--- /dev/null
+++ b/mmengine/registry/utils.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import os.path as osp
+from typing import Optional
+
+from mmengine.fileio import dump
+from . import root
+from .registry import Registry
+
+
+def traverse_registry_tree(registry: Registry, verbose: bool = True) -> list:
+    """Traverse the whole registry tree from any given node, and collect
+    information of all registered modules in this registry tree.
+
+    Args:
+        registry (Registry): a registry node in the registry tree.
+        verbose (bool): Whether to print log. Default: True
+
+    Returns:
+        list: Statistic results of all modules in each node of the registry
+        tree.
+    """
+    root_registry = registry.root
+    modules_info = []
+
+    def _dfs_registry(_registry):
+        if isinstance(_registry, Registry):
+            num_modules = len(_registry.module_dict)
+            scope = _registry.scope
+            registry_info = dict(num_modules=num_modules, scope=scope)
+            for name, registered_class in _registry.module_dict.items():
+                folder = '/'.join(registered_class.__module__.split('.')[:-1])
+                if folder in registry_info:
+                    registry_info[folder].append(name)
+                else:
+                    registry_info[folder] = [name]
+            if verbose:
+                print(f"Find {num_modules} modules in {scope}'s "
+                      f"'{_registry.name}' registry ")
+            modules_info.append(registry_info)
+        else:
+            return
+        for _, child in _registry.children.items():
+            _dfs_registry(child)
+
+    _dfs_registry(root_registry)
+    return modules_info
+
+
+def count_registered_modules(save_path: Optional[str] = None,
+                             verbose: bool = True) -> dict:
+    """Scan all modules in MMEngine's root and child registries and dump to
+    json.
+
+    Args:
+        save_path (str, optional): Path to save the json file.
+        verbose (bool): Whether to print log. Default: True
+    Returns:
+        dict: Statistic results of all registered modules.
+    """
+    registries_info = {}
+    # traverse all registries in MMEngine
+    for item in dir(root):
+        if not item.startswith('__'):
+            registry = getattr(root, item)
+            if isinstance(registry, Registry):
+                registries_info[item] = traverse_registry_tree(
+                    registry, verbose)
+    scan_data = dict(
+        scan_date=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+        registries=registries_info)
+    if verbose:
+        print('Finish registry analysis, got: ', scan_data)
+    if save_path is not None:
+        json_path = osp.join(save_path, 'modules_statistic_results.json')
+        dump(scan_data, json_path, indent=2)
+        print(f'Result has been saved to {json_path}')
+    return scan_data
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 182eefa1..98cddde6 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -30,7 +30,8 @@ from mmengine.model import is_model_wrapper
 from mmengine.optim import _ParamScheduler, build_optimizer
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
                                MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
-                               VISUALIZERS, DefaultScope)
+                               VISUALIZERS, DefaultScope,
+                               count_registered_modules)
 from mmengine.utils import (TORCH_VERSION, digit_version,
                             find_latest_checkpoint, is_list_of, symlink)
 from mmengine.visualization import Visualizer
@@ -328,6 +329,12 @@ class Runner:
         # Since `get_instance` could return any subclass of ManagerMixin. The
         # corresponding attribute needs a type hint.
         self.logger = self.build_logger(log_level=log_level)
+
+        # collect information of all modules registered in the registries
+        registries_info = count_registered_modules(
+            self.work_dir if self.rank == 0 else None, verbose=False)
+        self.logger.debug(registries_info)
+
         # Build `message_hub` for communication among components.
         # `message_hub` can store log scalars (loss, learning rate) and
         # runtime information (iter and epoch). Those components that do not
diff --git a/tests/test_registry/test_registry_utils.py b/tests/test_registry/test_registry_utils.py
new file mode 100644
index 00000000..ce903f0b
--- /dev/null
+++ b/tests/test_registry/test_registry_utils.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from tempfile import TemporaryDirectory
+from unittest import TestCase
+
+from mmengine.registry import (Registry, count_registered_modules, root,
+                               traverse_registry_tree)
+
+
+class TestUtils(TestCase):
+
+    def test_traverse_registry_tree(self):
+        #        Hierarchical Registry
+        #                           DOGS
+        #                      _______|_______
+        #                     |               |
+        #            HOUNDS (hound)          SAMOYEDS (samoyed)
+        #           _______|_______                |
+        #          |               |               |
+        #     LITTLE_HOUNDS    MID_HOUNDS   LITTLE_SAMOYEDS
+        #     (little_hound)   (mid_hound)  (little_samoyed)
+        DOGS = Registry('dogs')
+        HOUNDS = Registry('dogs', parent=DOGS, scope='hound')
+        LITTLE_HOUNDS = Registry(  # noqa
+            'dogs', parent=HOUNDS, scope='little_hound')
+        MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound')
+        SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed')
+        LITTLE_SAMOYEDS = Registry(  # noqa
+            'dogs', parent=SAMOYEDS, scope='little_samoyed')
+
+        @DOGS.register_module()
+        class GoldenRetriever:
+            pass
+
+        # traversing the tree from the root
+        result = traverse_registry_tree(DOGS)
+        self.assertEqual(result[0]['num_modules'], 1)
+        self.assertEqual(len(result), 6)
+
+        # traversing the tree from leaf node
+        result_leaf = traverse_registry_tree(MID_HOUNDS)
+        # result from any node should be the same
+        self.assertEqual(result, result_leaf)
+
+    def test_count_all_registered_modules(self):
+        temp_dir = TemporaryDirectory()
+        results = count_registered_modules(temp_dir.name, verbose=True)
+        self.assertTrue(
+            osp.exists(
+                osp.join(temp_dir.name, 'modules_statistic_results.json')))
+        registries_info = results['registries']
+        for registry in registries_info:
+            self.assertTrue(hasattr(root, registry))
+            self.assertEqual(registries_info[registry][0]['num_modules'],
+                             len(getattr(root, registry).module_dict))
+        temp_dir.cleanup()
+
+        # test not saving results
+        count_registered_modules(save_path=None, verbose=False)
+        self.assertFalse(
+            osp.exists(
+                osp.join(temp_dir.name, 'modules_statistic_results.json')))
-- 
GitLab