From 5016332588950bc452fbac1c4c93b333cd7d27e0 Mon Sep 17 00:00:00 2001 From: Alex Yang <50511903+imabackstabber@users.noreply.github.com> Date: Tue, 14 Jun 2022 14:50:24 +0800 Subject: [PATCH] [Feat] support registering function (#302) --- docs/zh_cn/tutorials/registry.md | 28 ++++++++--- .../model/base_model/data_preprocessor.py | 2 +- mmengine/registry/registry.py | 48 +++++++++++-------- tests/test_registry/test_registry.py | 36 +++++++++----- 4 files changed, 75 insertions(+), 39 deletions(-) diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md index 5ed7b143..b59541b8 100644 --- a/docs/zh_cn/tutorials/registry.md +++ b/docs/zh_cn/tutorials/registry.md @@ -6,11 +6,11 @@ OpenMMLab 大多数算法库å‡ä½¿ç”¨æ³¨å†Œå™¨æ¥ç®¡ç†ä»–们的代ç 模å—, ## 什么是注册器 -MMEngine 实现的注册器å¯ä»¥çœ‹ä½œä¸€ä¸ªæ˜ 射表和模å—构建方法(build function)的组åˆã€‚æ˜ å°„è¡¨ç»´æŠ¤äº†ä¸€ä¸ªå—ç¬¦ä¸²åˆ°ç±»çš„æ˜ å°„ï¼Œä½¿å¾—ç”¨æˆ·å¯ä»¥å€ŸåŠ©å—符串查找到相应的类,例如维护å—符串 `"ResNet"` 到 `ResNet` ç±»çš„æ˜ å°„ï¼Œä½¿å¾—ç”¨æˆ·å¯ä»¥é€šè¿‡ `"ResNet"` 找到 `ResNet` 类。 -而模å—æž„å»ºæ–¹æ³•åˆ™å®šä¹‰äº†å¦‚ä½•æ ¹æ®å—ç¬¦ä¸²æŸ¥æ‰¾åˆ°å¯¹åº”çš„ç±»ï¼Œå¹¶å®šä¹‰äº†å¦‚ä½•å®žä¾‹åŒ–è¿™ä¸ªç±»ï¼Œä¾‹å¦‚æ ¹æ®è§„则通过å—符串 `"bn"` 找到 `nn.BatchNorm2d`,并且实例化 `BatchNorm2d` 模å—。 +MMEngine 实现的注册器å¯ä»¥çœ‹ä½œä¸€ä¸ªæ˜ 射表和模å—构建方法(build function)的组åˆã€‚æ˜ å°„è¡¨ç»´æŠ¤äº†ä¸€ä¸ªå—ç¬¦ä¸²åˆ°ç±»æˆ–è€…å‡½æ•°çš„æ˜ å°„ï¼Œä½¿å¾—ç”¨æˆ·å¯ä»¥å€ŸåŠ©å—符串查找到相应的类或函数,例如维护å—符串 `"ResNet"` 到 `ResNet` ç±»æˆ–å‡½æ•°çš„æ˜ å°„ï¼Œä½¿å¾—ç”¨æˆ·å¯ä»¥é€šè¿‡ `"ResNet"` 找到 `ResNet` 类或函数; +而模å—æž„å»ºæ–¹æ³•åˆ™å®šä¹‰äº†å¦‚ä½•æ ¹æ®å—ç¬¦ä¸²æŸ¥æ‰¾åˆ°å¯¹åº”çš„ç±»æˆ–å‡½æ•°ï¼Œå¹¶å®šä¹‰äº†å¦‚ä½•å®žä¾‹åŒ–è¿™ä¸ªç±»æˆ–è°ƒç”¨è¿™ä¸ªå‡½æ•°ï¼Œä¾‹å¦‚æ ¹æ®è§„则通过å—符串 `"bn"` 找到 `nn.BatchNorm2d`,并且实例化 `BatchNorm2d` 模å—。åˆæˆ–è€…æ ¹æ®è§„则通过å—符串 `"bn"` 找到 `build_batchnorm2d`,并且调用函数获得 `BatchNorm2d` 模å—。 MMEngine ä¸çš„注册器默认使用 [build_from_cfg 函数](https://mmengine.readthedocs.io/zh_CN/latest/api.html#mmengine.registry.build_from_cfg) æ¥æŸ¥æ‰¾å¹¶å®žä¾‹åŒ–å—符串对应的类。 -一个注册器管ç†çš„类通常有相似的接å£å’ŒåŠŸèƒ½ï¼Œå› æ¤è¯¥æ³¨å†Œå™¨å¯ä»¥è¢«è§†ä½œè¿™äº›ç±»çš„抽象。例如注册器 `Classifier` å¯ä»¥è¢«è§†ä½œæ‰€æœ‰åˆ†ç±»ç½‘络的抽象,管ç†äº† `ResNet`, `SEResNet` å’Œ `RegNetX` ç‰åˆ†ç±»ç½‘络的类。 +一个注册器管ç†çš„类或函数通常有相似的接å£å’ŒåŠŸèƒ½ï¼Œå› æ¤è¯¥æ³¨å†Œå™¨å¯ä»¥è¢«è§†ä½œè¿™äº›ç±»æˆ–函数的抽象。例如注册器 `Classifier` å¯ä»¥è¢«è§†ä½œæ‰€æœ‰åˆ†ç±»ç½‘络的抽象,管ç†äº† `ResNet`, `SEResNet` å’Œ `RegNetX` ç‰åˆ†ç±»ç½‘ç»œçš„ç±»ä»¥åŠ `build_ResNet`, `build_SEResNet` å’Œ `build_RegNetX` ç‰åˆ†ç±»ç½‘络的构建函数。 使用注册器管ç†åŠŸèƒ½ç›¸ä¼¼çš„模å—å¯ä»¥æ˜¾è‘—æ高代ç 的扩展性和çµæ´»æ€§ã€‚用户å¯ä»¥è·³è‡³`使用注册器æ高代ç 的扩展性`ç« èŠ‚äº†è§£æ³¨å†Œå™¨æ˜¯å¦‚ä½•æ高代ç 拓展性的。 ## 入门用法 @@ -32,10 +32,10 @@ from mmengine import Registry CONVERTERS = Registry('converter') ``` -然åŽæˆ‘们å¯ä»¥å®žçŽ°ä¸åŒçš„转æ¢å™¨ã€‚ +然åŽæˆ‘们å¯ä»¥å®žçŽ°ä¸åŒçš„转æ¢å™¨ã€‚例如,在 `converters/converter_cls.py` ä¸å®žçŽ° `Converter1` å’Œ `Converter2`,在 `converters/converter_func.py` ä¸å®žçŽ° `converter3`。 ```python -# converters/converter.py +# converters/converter_cls.py from .builder import CONVERTERS # 使用注册器管ç†æ¨¡å— @@ -53,12 +53,23 @@ class Converter2(object): self.c = c ``` -使用注册器管ç†æ¨¡å—的关键æ¥éª¤æ˜¯ï¼Œå°†å®žçŽ°çš„模å—注册到注册表 `CONVERTERS` ä¸ã€‚通过 `@CONVERTERS.register_module()` 装饰所实现的模å—,å—ç¬¦ä¸²å’Œç±»ä¹‹é—´çš„æ˜ å°„å°±å¯ä»¥ç”± `CONVERTERS` 构建和维护,我们也å¯ä»¥é€šè¿‡ `CONVERTERS.register_module(module=Converter1)` 实现åŒæ ·çš„功能。 +```python +# converters/converter_func.py +from .builder import CONVERTERS +from .converter_cls import Converter1 +@CONVERTERS.register_module() +def converter3(a, b) + return Converter1(a, b) +``` + +使用注册器管ç†æ¨¡å—的关键æ¥éª¤æ˜¯ï¼Œå°†å®žçŽ°çš„模å—注册到注册表 `CONVERTERS` ä¸ã€‚通过 `@CONVERTERS.register_module()` 装饰所实现的模å—,å—ç¬¦ä¸²å’Œç±»æˆ–å‡½æ•°ä¹‹é—´çš„æ˜ å°„å°±å¯ä»¥ç”± `CONVERTERS` 构建和维护,我们也å¯ä»¥é€šè¿‡ `CONVERTERS.register_module(module=Converter1)` 实现åŒæ ·çš„功能。 -通过注册,我们就å¯ä»¥é€šè¿‡ `CONVERTERS` 建立å—ç¬¦ä¸²ä¸Žç±»ä¹‹é—´çš„æ˜ å°„ï¼Œ +通过注册,我们就å¯ä»¥é€šè¿‡ `CONVERTERS` 建立å—ç¬¦ä¸²ä¸Žç±»æˆ–å‡½æ•°ä¹‹é—´çš„æ˜ å°„ï¼Œ ```python 'Converter1' -> <class 'Converter1'> +'Converter2' -> <class 'Converter2'> +'Converter3' -> <function 'Converter3'> ``` ```{note} @@ -72,6 +83,9 @@ class Converter2(object): # 注æ„,converter_cfg å¯ä»¥é€šè¿‡è§£æžé…置文件得到 converter_cfg = dict(type='Converter1', a=a_value, b=b_value) converter = CONVERTERS.build(converter_cfg) +converter3_cfg = dict(type='converter3', a=a_value, b=b_value) +# returns the calling result +converter3 = CONVERTERS.build(converter3_cfg) ``` 如果我们想使用 `Converter2`,仅需修改é…置。 diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 2640ce55..9fa5b48e 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -191,7 +191,7 @@ class ImgDataPreprocessor(BaseDataPreprocessor): f'RGB or gray image, but got {len(mean)}') assert len(std) == 3 or len(std) == 1, ( # type: ignore 'The length of std should be 1 or 3 to be compatible with RGB ' # type: ignore # noqa: E501 - f'or gray image, but got {len(std)}') + f'or gray image, but got {len(std)}') # type: ignore self._enable_normalize = True self.register_buffer('mean', torch.tensor(mean).view(-1, 1, 1), False) diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 4690ee4b..45f64300 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -81,7 +81,8 @@ def build_from_cfg( cfg: Union[dict, ConfigDict, Config], registry: 'Registry', default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> Any: - """Build a module from config dict. + """Build a module from config dict when it is a class configuration, or + call a function from config dict when it is a function configuration. At least one of the ``cfg`` and ``default_args`` contains the key "type" which type should be either str or class. If they all contain it, the key @@ -101,6 +102,12 @@ def build_from_cfg( >>> self.stages = stages >>> cfg = dict(type='ResNet', depth=50) >>> model = build_from_cfg(cfg, MODELS) + >>> # Returns an instantiated object + >>> @MODELS.register_module() + >>> def resnet50(): + >>> pass + >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS) + >>> # Return a result of the calling function Args: cfg (dict or ConfigDict or Config): Config dict. It should at least @@ -151,7 +158,7 @@ def build_from_cfg( ' it was registered as expected. More details can be found at' ' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501 ) - elif inspect.isclass(obj_type): + elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): obj_cls = obj_type else: raise TypeError( @@ -182,9 +189,10 @@ def build_from_cfg( class Registry: - """A registry to map strings to classes. + """A registry to map strings to classes or functions. - Registered objects could be built from registry. + Registered object could be built from registry. Meanwhile, registered + functions could be called from registry. Args: name (str): Registry name. @@ -210,6 +218,10 @@ class Registry: >>> pass >>> # build model from `MODELS` >>> resnet = MODELS.build(dict(type='ResNet')) + >>> @MODELS.register_module() + >>> def resnet50(): + >>> pass + >>> resnet = MODELS.build(dict(type='resnet50')) >>> # hierarchical registry >>> DETECTORS = Registry('detectors', parent=MODELS, scope='det') @@ -525,25 +537,25 @@ class Registry: self.children[registry.scope] = registry def _register_module(self, - module_class: Type, + module: Type, module_name: Optional[Union[str, List[str]]] = None, force: bool = False) -> None: """Register a module. Args: - module_class (type): Module class to be registered. + module (type): Module class or function to be registered. module_name (str or list of str, optional): The module name to be registered. If not specified, the class name will be used. Defaults to None. force (bool): Whether to override an existing class with the same name. Defaults to False. """ - if not inspect.isclass(module_class): - raise TypeError('module must be a class, ' - f'but got {type(module_class)}') + if not inspect.isclass(module) and not inspect.isfunction(module): + raise TypeError('module must be a class or a function, ' + f'but got {type(module)}') if module_name is None: - module_name = module_class.__name__ + module_name = module.__name__ if isinstance(module_name, str): module_name = [module_name] for name in module_name: @@ -551,7 +563,7 @@ class Registry: existed_module = self.module_dict[name] raise KeyError(f'{name} is already registered in {self.name} ' f'at {existed_module.__module__}') - self._module_dict[name] = module_class + self._module_dict[name] = module def register_module( self, @@ -569,8 +581,8 @@ class Registry: registered. If not specified, the class name will be used. force (bool): Whether to override an existing class with the same name. Default to False. - module (type, optional): Module class to be registered. Defaults to - None. + module (type, optional): Module class or function to be registered. + Defaults to None. Examples: >>> backbones = Registry('backbone') @@ -599,14 +611,12 @@ class Registry: # use it as a normal method: x.register_module(module=SomeClass) if module is not None: - self._register_module( - module_class=module, module_name=name, force=force) + self._register_module(module=module, module_name=name, force=force) return module # use it as a decorator: @x.register_module() - def _register(cls): - self._register_module( - module_class=cls, module_name=name, force=force) - return cls + def _register(module): + self._register_module(module=module, module_name=name, force=force) + return module return _register diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 9070df09..9a716984 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -57,12 +57,23 @@ class TestRegistry: def test_register_module(self): CATS = Registry('cat') - # can only decorate a class + @CATS.register_module() + def muchkin(): + pass + + assert CATS.get('muchkin') is muchkin + assert 'muchkin' in CATS + + # can only decorate a class or a function with pytest.raises(TypeError): - @CATS.register_module() - def some_method(): - pass + class Demo: + + def some_method(self): + pass + + method = Demo().some_method + CATS.register_module(name='some_method', module=method) # test `name` parameter which must be either of None, a string or a # sequence of string @@ -71,7 +82,7 @@ class TestRegistry: class BritishShorthair: pass - assert len(CATS) == 1 + assert len(CATS) == 2 assert CATS.get('BritishShorthair') is BritishShorthair # `name` is a string @@ -79,7 +90,7 @@ class TestRegistry: class Munchkin: pass - assert len(CATS) == 2 + assert len(CATS) == 3 assert CATS.get('Munchkin') is Munchkin assert 'Munchkin' in CATS @@ -90,7 +101,7 @@ class TestRegistry: assert CATS.get('Siamese') is SiameseCat assert CATS.get('Siamese2') is SiameseCat - assert len(CATS) == 4 + assert len(CATS) == 5 # `name` is an invalid type with pytest.raises( @@ -127,14 +138,15 @@ class TestRegistry: class BritishShorthair: pass - assert len(CATS) == 4 + assert len(CATS) == 5 # test `module` parameter, which is either None or a class # when the `register_module`` is called as a method rather than a # decorator, which must be a class with pytest.raises( TypeError, - match="module must be a class, but got <class 'str'>"): + match='module must be a class or a function,' + " but got <class 'str'>"): CATS.register_module(module='string') class SphynxCat: @@ -142,16 +154,16 @@ class TestRegistry: CATS.register_module(module=SphynxCat) assert CATS.get('SphynxCat') is SphynxCat - assert len(CATS) == 5 + assert len(CATS) == 6 CATS.register_module(name='Sphynx1', module=SphynxCat) assert CATS.get('Sphynx1') is SphynxCat - assert len(CATS) == 6 + assert len(CATS) == 7 CATS.register_module(name=['Sphynx2', 'Sphynx3'], module=SphynxCat) assert CATS.get('Sphynx2') is SphynxCat assert CATS.get('Sphynx3') is SphynxCat - assert len(CATS) == 8 + assert len(CATS) == 9 def _build_registry(self): """A helper function to build a Hierarchical Registry.""" -- GitLab