From 5016332588950bc452fbac1c4c93b333cd7d27e0 Mon Sep 17 00:00:00 2001
From: Alex Yang <50511903+imabackstabber@users.noreply.github.com>
Date: Tue, 14 Jun 2022 14:50:24 +0800
Subject: [PATCH] [Feat] support registering function (#302)

---
 docs/zh_cn/tutorials/registry.md              | 28 ++++++++---
 .../model/base_model/data_preprocessor.py     |  2 +-
 mmengine/registry/registry.py                 | 48 +++++++++++--------
 tests/test_registry/test_registry.py          | 36 +++++++++-----
 4 files changed, 75 insertions(+), 39 deletions(-)

diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md
index 5ed7b143..b59541b8 100644
--- a/docs/zh_cn/tutorials/registry.md
+++ b/docs/zh_cn/tutorials/registry.md
@@ -6,11 +6,11 @@ OpenMMLab 大多数算法库均使用注册器来管理他们的代码模块,
 
 ## 什么是注册器
 
-MMEngine 实现的注册器可以看作一个映射表和模块构建方法(build function)的组合。映射表维护了一个字符串到类的映射,使得用户可以借助字符串查找到相应的类,例如维护字符串 `"ResNet"` 到 `ResNet` 类的映射,使得用户可以通过 `"ResNet"` 找到 `ResNet` 类。
-而模块构建方法则定义了如何根据字符串查找到对应的类,并定义了如何实例化这个类,例如根据规则通过字符串 `"bn"` 找到 `nn.BatchNorm2d`,并且实例化 `BatchNorm2d` 模块。
+MMEngine 实现的注册器可以看作一个映射表和模块构建方法(build function)的组合。映射表维护了一个字符串到类或者函数的映射,使得用户可以借助字符串查找到相应的类或函数,例如维护字符串 `"ResNet"` 到 `ResNet` 类或函数的映射,使得用户可以通过 `"ResNet"` 找到 `ResNet` 类或函数;
+而模块构建方法则定义了如何根据字符串查找到对应的类或函数,并定义了如何实例化这个类或调用这个函数,例如根据规则通过字符串 `"bn"` 找到 `nn.BatchNorm2d`,并且实例化 `BatchNorm2d` 模块。又或者根据规则通过字符串 `"bn"` 找到 `build_batchnorm2d`,并且调用函数获得 `BatchNorm2d` 模块。
 MMEngine 中的注册器默认使用 [build_from_cfg 函数](https://mmengine.readthedocs.io/zh_CN/latest/api.html#mmengine.registry.build_from_cfg) 来查找并实例化字符串对应的类。
 
-一个注册器管理的类通常有相似的接口和功能,因此该注册器可以被视作这些类的抽象。例如注册器 `Classifier` 可以被视作所有分类网络的抽象,管理了 `ResNet`, `SEResNet` 和 `RegNetX` 等分类网络的类。
+一个注册器管理的类或函数通常有相似的接口和功能,因此该注册器可以被视作这些类或函数的抽象。例如注册器 `Classifier` 可以被视作所有分类网络的抽象,管理了 `ResNet`, `SEResNet` 和 `RegNetX` 等分类网络的类以及 `build_ResNet`,  `build_SEResNet` 和 `build_RegNetX` 等分类网络的构建函数。
 使用注册器管理功能相似的模块可以显著提高代码的扩展性和灵活性。用户可以跳至`使用注册器提高代码的扩展性`章节了解注册器是如何提高代码拓展性的。
 
 ## 入门用法
@@ -32,10 +32,10 @@ from mmengine import Registry
 CONVERTERS = Registry('converter')
 ```
 
-然后我们可以实现不同的转换器。
+然后我们可以实现不同的转换器。例如,在 `converters/converter_cls.py` 中实现 `Converter1` 和 `Converter2`,在 `converters/converter_func.py` 中实现 `converter3`。
 
 ```python
-# converters/converter.py
+# converters/converter_cls.py
 from .builder import CONVERTERS
 
 # 使用注册器管理模块
@@ -53,12 +53,23 @@ class Converter2(object):
         self.c = c
 ```
 
-使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串和类之间的映射就可以由 `CONVERTERS` 构建和维护,我们也可以通过 `CONVERTERS.register_module(module=Converter1)` 实现同样的功能。
+```python
+# converters/converter_func.py
+from .builder import CONVERTERS
+from .converter_cls import Converter1
+@CONVERTERS.register_module()
+def converter3(a, b)
+    return Converter1(a, b)
+```
+
+使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串和类或函数之间的映射就可以由 `CONVERTERS` 构建和维护,我们也可以通过 `CONVERTERS.register_module(module=Converter1)` 实现同样的功能。
 
-通过注册,我们就可以通过 `CONVERTERS` 建立字符串与类之间的映射,
+通过注册,我们就可以通过 `CONVERTERS` 建立字符串与类或函数之间的映射,
 
 ```python
 'Converter1' -> <class 'Converter1'>
+'Converter2' -> <class 'Converter2'>
+'Converter3' -> <function 'Converter3'>
 ```
 
 ```{note}
@@ -72,6 +83,9 @@ class Converter2(object):
 # 注意,converter_cfg 可以通过解析配置文件得到
 converter_cfg = dict(type='Converter1', a=a_value, b=b_value)
 converter = CONVERTERS.build(converter_cfg)
+converter3_cfg = dict(type='converter3', a=a_value, b=b_value)
+# returns the calling result
+converter3 = CONVERTERS.build(converter3_cfg)
 ```
 
 如果我们想使用 `Converter2`,仅需修改配置。
diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
index 2640ce55..9fa5b48e 100644
--- a/mmengine/model/base_model/data_preprocessor.py
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -191,7 +191,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
                 f'RGB or gray image, but got {len(mean)}')
             assert len(std) == 3 or len(std) == 1, (  # type: ignore
                 'The length of std should be 1 or 3 to be compatible with RGB '  # type: ignore # noqa: E501
-                f'or gray image, but got {len(std)}')
+                f'or gray image, but got {len(std)}')  # type: ignore
             self._enable_normalize = True
             self.register_buffer('mean',
                                  torch.tensor(mean).view(-1, 1, 1), False)
diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py
index 4690ee4b..45f64300 100644
--- a/mmengine/registry/registry.py
+++ b/mmengine/registry/registry.py
@@ -81,7 +81,8 @@ def build_from_cfg(
         cfg: Union[dict, ConfigDict, Config],
         registry: 'Registry',
         default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any:
-    """Build a module from config dict.
+    """Build a module from config dict when it is a class configuration, or
+    call a function from config dict when it is a function configuration.
 
     At least one of the ``cfg`` and ``default_args`` contains the key "type"
     which type should be either str or class. If they all contain it, the key
@@ -101,6 +102,12 @@ def build_from_cfg(
         >>>         self.stages = stages
         >>> cfg = dict(type='ResNet', depth=50)
         >>> model = build_from_cfg(cfg, MODELS)
+        >>> # Returns an instantiated object
+        >>> @MODELS.register_module()
+        >>> def resnet50():
+        >>>     pass
+        >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
+        >>> # Return a result of the calling function
 
     Args:
         cfg (dict or ConfigDict or Config): Config dict. It should at least
@@ -151,7 +158,7 @@ def build_from_cfg(
                 ' it was registered as expected. More details can be found at'
                 ' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules'  # noqa: E501
             )
-    elif inspect.isclass(obj_type):
+    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
         obj_cls = obj_type
     else:
         raise TypeError(
@@ -182,9 +189,10 @@ def build_from_cfg(
 
 
 class Registry:
-    """A registry to map strings to classes.
+    """A registry to map strings to classes or functions.
 
-    Registered objects could be built from registry.
+    Registered object could be built from registry. Meanwhile, registered
+    functions could be called from registry.
 
     Args:
         name (str): Registry name.
@@ -210,6 +218,10 @@ class Registry:
         >>>     pass
         >>> # build model from `MODELS`
         >>> resnet = MODELS.build(dict(type='ResNet'))
+        >>> @MODELS.register_module()
+        >>> def resnet50():
+        >>>     pass
+        >>> resnet = MODELS.build(dict(type='resnet50'))
 
         >>> # hierarchical registry
         >>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
@@ -525,25 +537,25 @@ class Registry:
         self.children[registry.scope] = registry
 
     def _register_module(self,
-                         module_class: Type,
+                         module: Type,
                          module_name: Optional[Union[str, List[str]]] = None,
                          force: bool = False) -> None:
         """Register a module.
 
         Args:
-            module_class (type): Module class to be registered.
+            module (type): Module class or function to be registered.
             module_name (str or list of str, optional): The module name to be
                 registered. If not specified, the class name will be used.
                 Defaults to None.
             force (bool): Whether to override an existing class with the same
                 name. Defaults to False.
         """
-        if not inspect.isclass(module_class):
-            raise TypeError('module must be a class, '
-                            f'but got {type(module_class)}')
+        if not inspect.isclass(module) and not inspect.isfunction(module):
+            raise TypeError('module must be a class or a function, '
+                            f'but got {type(module)}')
 
         if module_name is None:
-            module_name = module_class.__name__
+            module_name = module.__name__
         if isinstance(module_name, str):
             module_name = [module_name]
         for name in module_name:
@@ -551,7 +563,7 @@ class Registry:
                 existed_module = self.module_dict[name]
                 raise KeyError(f'{name} is already registered in {self.name} '
                                f'at {existed_module.__module__}')
-            self._module_dict[name] = module_class
+            self._module_dict[name] = module
 
     def register_module(
             self,
@@ -569,8 +581,8 @@ class Registry:
                 registered. If not specified, the class name will be used.
             force (bool): Whether to override an existing class with the same
                 name. Default to False.
-            module (type, optional): Module class to be registered. Defaults to
-                None.
+            module (type, optional): Module class or function to be registered.
+                Defaults to None.
 
         Examples:
             >>> backbones = Registry('backbone')
@@ -599,14 +611,12 @@ class Registry:
 
         # use it as a normal method: x.register_module(module=SomeClass)
         if module is not None:
-            self._register_module(
-                module_class=module, module_name=name, force=force)
+            self._register_module(module=module, module_name=name, force=force)
             return module
 
         # use it as a decorator: @x.register_module()
-        def _register(cls):
-            self._register_module(
-                module_class=cls, module_name=name, force=force)
-            return cls
+        def _register(module):
+            self._register_module(module=module, module_name=name, force=force)
+            return module
 
         return _register
diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py
index 9070df09..9a716984 100644
--- a/tests/test_registry/test_registry.py
+++ b/tests/test_registry/test_registry.py
@@ -57,12 +57,23 @@ class TestRegistry:
     def test_register_module(self):
         CATS = Registry('cat')
 
-        # can only decorate a class
+        @CATS.register_module()
+        def muchkin():
+            pass
+
+        assert CATS.get('muchkin') is muchkin
+        assert 'muchkin' in CATS
+
+        # can only decorate a class or a function
         with pytest.raises(TypeError):
 
-            @CATS.register_module()
-            def some_method():
-                pass
+            class Demo:
+
+                def some_method(self):
+                    pass
+
+            method = Demo().some_method
+            CATS.register_module(name='some_method', module=method)
 
         # test `name` parameter which must be either of None, a string or a
         # sequence of string
@@ -71,7 +82,7 @@ class TestRegistry:
         class BritishShorthair:
             pass
 
-        assert len(CATS) == 1
+        assert len(CATS) == 2
         assert CATS.get('BritishShorthair') is BritishShorthair
 
         # `name` is a string
@@ -79,7 +90,7 @@ class TestRegistry:
         class Munchkin:
             pass
 
-        assert len(CATS) == 2
+        assert len(CATS) == 3
         assert CATS.get('Munchkin') is Munchkin
         assert 'Munchkin' in CATS
 
@@ -90,7 +101,7 @@ class TestRegistry:
 
         assert CATS.get('Siamese') is SiameseCat
         assert CATS.get('Siamese2') is SiameseCat
-        assert len(CATS) == 4
+        assert len(CATS) == 5
 
         # `name` is an invalid type
         with pytest.raises(
@@ -127,14 +138,15 @@ class TestRegistry:
         class BritishShorthair:
             pass
 
-        assert len(CATS) == 4
+        assert len(CATS) == 5
 
         # test `module` parameter, which is either None or a class
         # when the `register_module`` is called as a method rather than a
         # decorator, which must be a class
         with pytest.raises(
                 TypeError,
-                match="module must be a class, but got <class 'str'>"):
+                match='module must be a class or a function,'
+                " but got <class 'str'>"):
             CATS.register_module(module='string')
 
         class SphynxCat:
@@ -142,16 +154,16 @@ class TestRegistry:
 
         CATS.register_module(module=SphynxCat)
         assert CATS.get('SphynxCat') is SphynxCat
-        assert len(CATS) == 5
+        assert len(CATS) == 6
 
         CATS.register_module(name='Sphynx1', module=SphynxCat)
         assert CATS.get('Sphynx1') is SphynxCat
-        assert len(CATS) == 6
+        assert len(CATS) == 7
 
         CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat)
         assert CATS.get('Sphynx2') is SphynxCat
         assert CATS.get('Sphynx3') is SphynxCat
-        assert len(CATS) == 8
+        assert len(CATS) == 9
 
     def _build_registry(self):
         """A helper function to build a Hierarchical Registry."""
-- 
GitLab