diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md index 71152f0509f0150d3226fc323d28a27768468ffc..15ef0148324c571231cbfdcc404197dc4c0ae666 100644 --- a/docs/zh_cn/tutorials/registry.md +++ b/docs/zh_cn/tutorials/registry.md @@ -1,242 +1,184 @@ # 注册器(Registry) OpenMMLab 的算法库支æŒäº†ä¸°å¯Œçš„算法和数æ®é›†ï¼Œå› æ¤å®žçŽ°äº†å¾ˆå¤šåŠŸèƒ½ç›¸è¿‘的模å—。例如 ResNet å’Œ SE-ResNet 的算法实现分别基于 `ResNet` å’Œ `SEResNet` 类,这些类有相似的功能和接å£ï¼Œéƒ½å±žäºŽç®—法库ä¸çš„模型组件。 -为了管ç†è¿™äº›åŠŸèƒ½ç›¸ä¼¼çš„模å—,MMEngine 实现了 [注册器](https://mmengine.readthedocs.io/zh_CN/latest/api.html#mmengine.registry.Registry)。 -OpenMMLab 大多数算法库å‡ä½¿ç”¨æ³¨å†Œå™¨æ¥ç®¡ç†ä»–们的代ç 模å—,包括 [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d),[MMClassification](https://github.com/open-mmlab/mmclassification) å’Œ [MMEditing](https://github.com/open-mmlab/mmediting) ç‰ã€‚ +为了管ç†è¿™äº›åŠŸèƒ½ç›¸ä¼¼çš„模å—,MMEngine 实现了 [注册器](mmengine.registry.Registry)。 +OpenMMLab 大多数算法库å‡ä½¿ç”¨æ³¨å†Œå™¨æ¥ç®¡ç†å®ƒä»¬çš„代ç 模å—,包括 [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d),[MMClassification](https://github.com/open-mmlab/mmclassification) å’Œ [MMEditing](https://github.com/open-mmlab/mmediting) ç‰ã€‚ ## 什么是注册器 -MMEngine 实现的注册器å¯ä»¥çœ‹ä½œä¸€ä¸ªæ˜ 射表和模å—构建方法(build function)的组åˆã€‚æ˜ å°„è¡¨ç»´æŠ¤äº†ä¸€ä¸ªå—ç¬¦ä¸²åˆ°ç±»æˆ–è€…å‡½æ•°çš„æ˜ å°„ï¼Œä½¿å¾—ç”¨æˆ·å¯ä»¥å€ŸåŠ©å—符串查找到相应的类或函数,例如维护å—符串 `"ResNet"` 到 `ResNet` ç±»æˆ–å‡½æ•°çš„æ˜ å°„ï¼Œä½¿å¾—ç”¨æˆ·å¯ä»¥é€šè¿‡ `"ResNet"` 找到 `ResNet` 类或函数; -而模å—æž„å»ºæ–¹æ³•åˆ™å®šä¹‰äº†å¦‚ä½•æ ¹æ®å—ç¬¦ä¸²æŸ¥æ‰¾åˆ°å¯¹åº”çš„ç±»æˆ–å‡½æ•°ï¼Œå¹¶å®šä¹‰äº†å¦‚ä½•å®žä¾‹åŒ–è¿™ä¸ªç±»æˆ–è°ƒç”¨è¿™ä¸ªå‡½æ•°ï¼Œä¾‹å¦‚æ ¹æ®è§„则通过å—符串 `"bn"` 找到 `nn.BatchNorm2d`,并且实例化 `BatchNorm2d` 模å—。åˆæˆ–è€…æ ¹æ®è§„则通过å—符串 `"bn"` 找到 `build_batchnorm2d`,并且调用函数获得 `BatchNorm2d` 模å—。 -MMEngine ä¸çš„注册器默认使用 [build_from_cfg 函数](https://mmengine.readthedocs.io/zh_CN/latest/api.html#mmengine.registry.build_from_cfg) æ¥æŸ¥æ‰¾å¹¶å®žä¾‹åŒ–å—符串对应的类。 +MMEngine 实现的[注册器](mmengine.registry.Registry)å¯ä»¥çœ‹ä½œä¸€ä¸ªæ˜ 射表和模å—构建方法(build function)的组åˆã€‚æ˜ å°„è¡¨ç»´æŠ¤äº†ä¸€ä¸ªå—符串到**ç±»æˆ–è€…å‡½æ•°çš„æ˜ å°„**,使得用户å¯ä»¥å€ŸåŠ©å—符串查找到相应的类或函数,例如维护å—符串 `"ResNet"` 到 `ResNet` ç±»æˆ–å‡½æ•°çš„æ˜ å°„ï¼Œä½¿å¾—ç”¨æˆ·å¯ä»¥é€šè¿‡ `"ResNet"` 找到 `ResNet` 类; +而模å—æž„å»ºæ–¹æ³•åˆ™å®šä¹‰äº†å¦‚ä½•æ ¹æ®å—符串查找到对应的类或函数以åŠå¦‚何实例化这个类或者调用这个函数,例如,通过å—符串 `"bn"` 找到 `nn.BatchNorm2d` 并实例化 `BatchNorm2d` 模å—ï¼›åˆæˆ–者通过å—符串 `"build_batchnorm2d"` 找到 `build_batchnorm2d` 函数并返回该函数的调用结果。 +MMEngine ä¸çš„注册器默认使用 [build_from_cfg](mmengine.registry.build_from_cfg) 函数æ¥æŸ¥æ‰¾å¹¶å®žä¾‹åŒ–å—符串对应的类或者函数。 -一个注册器管ç†çš„类或函数通常有相似的接å£å’ŒåŠŸèƒ½ï¼Œå› æ¤è¯¥æ³¨å†Œå™¨å¯ä»¥è¢«è§†ä½œè¿™äº›ç±»æˆ–函数的抽象。例如注册器 `Classifier` å¯ä»¥è¢«è§†ä½œæ‰€æœ‰åˆ†ç±»ç½‘络的抽象,管ç†äº† `ResNet`, `SEResNet` å’Œ `RegNetX` ç‰åˆ†ç±»ç½‘ç»œçš„ç±»ä»¥åŠ `build_ResNet`, `build_SEResNet` å’Œ `build_RegNetX` ç‰åˆ†ç±»ç½‘络的构建函数。 -使用注册器管ç†åŠŸèƒ½ç›¸ä¼¼çš„模å—å¯ä»¥æ˜¾è‘—æ高代ç 的扩展性和çµæ´»æ€§ã€‚用户å¯ä»¥è·³è‡³`使用注册器æ高代ç 的扩展性`ç« èŠ‚äº†è§£æ³¨å†Œå™¨æ˜¯å¦‚ä½•æ高代ç 拓展性的。 +一个注册器管ç†çš„类或函数通常有相似的接å£å’ŒåŠŸèƒ½ï¼Œå› æ¤è¯¥æ³¨å†Œå™¨å¯ä»¥è¢«è§†ä½œè¿™äº›ç±»æˆ–函数的抽象。例如注册器 `MODELS` å¯ä»¥è¢«è§†ä½œæ‰€æœ‰æ¨¡åž‹çš„抽象,管ç†äº† `ResNet`, `SEResNet` å’Œ `RegNetX` ç‰åˆ†ç±»ç½‘ç»œçš„ç±»ä»¥åŠ `build_ResNet`, `build_SEResNet` å’Œ `build_RegNetX` ç‰åˆ†ç±»ç½‘络的构建函数。 ## 入门用法 使用注册器管ç†ä»£ç 库ä¸çš„模å—,需è¦ä»¥ä¸‹ä¸‰ä¸ªæ¥éª¤ã€‚ 1. 创建注册器 -2. 创建一个用于实例化类的构建方法(å¯é€‰ï¼Œåœ¨å¤§å¤šæ•°æƒ…况下您å¯ä»¥åªä½¿ç”¨é»˜è®¤æ–¹æ³•ï¼‰ +2. 创建一个用于实例化类的构建方法(å¯é€‰ï¼Œåœ¨å¤§å¤šæ•°æƒ…况下å¯ä»¥åªä½¿ç”¨é»˜è®¤æ–¹æ³•ï¼‰ 3. 将模å—åŠ å…¥æ³¨å†Œå™¨ä¸ -å‡è®¾æˆ‘们è¦å®žçŽ°ä¸€ç³»åˆ—æ•°æ®é›†è½¬æ¢å™¨ï¼ˆDataset Converter),将ä¸åŒæ ¼å¼çš„æ•°æ®è½¬æ¢ä¸ºæ ‡å‡†æ•°æ®æ ¼å¼ã€‚我们希望å¯ä»¥å®žçŽ°ä»…修改é…置就能够使用ä¸åŒçš„转æ¢å™¨è€Œæ— 需修改代ç 。 +å‡è®¾æˆ‘们è¦å®žçŽ°ä¸€ç³»åˆ—激活模å—并且希望仅修改é…置就能够使用ä¸åŒçš„激活模å—è€Œæ— éœ€ä¿®æ”¹ä»£ç 。 -我们先创建一个å为 `converters` 的目录作为包,在包ä¸æˆ‘们创建一个文件æ¥å®žçŽ°æž„建器(builder), +首先创建注册器, ```python -# model/builder.py from mmengine import Registry -# 创建转æ¢å™¨çš„注册器 -CONVERTERS = Registry('converter') +# scope 表示注册器的作用域,如果ä¸è®¾ç½®ï¼Œé»˜è®¤ä¸ºåŒ…å,例如在 mmdetection ä¸ï¼Œå®ƒçš„ scope 为 mmdet +ACTIVATION = Registry('activation', scope='mmengine') ``` -然åŽæˆ‘们å¯ä»¥å®žçŽ°ä¸åŒçš„转æ¢å™¨ã€‚例如,在 `converters/converter_cls.py` ä¸å®žçŽ° `Converter1` å’Œ `Converter2`,在 `converters/converter_func.py` ä¸å®žçŽ° `converter3`。 +然åŽæˆ‘们å¯ä»¥å®žçŽ°ä¸åŒçš„激活模å—,例如 `Sigmoid`,`ReLU` å’Œ `Softmax`。 ```python -# converters/converter_cls.py -from .builder import CONVERTERS +import torch.nn as nn # 使用注册器管ç†æ¨¡å— -@CONVERTERS.register_module() -class Converter1(object): - def __init__(self, a, b): - self.a = a - self.b = b - -@CONVERTERS.register_module() -class Converter2(object): - def __init__(self, a, b, c): - self.a = a - self.b = b - self.c = c -``` - -```python -# converters/converter_func.py -from .builder import CONVERTERS -from .converter_cls import Converter1 -@CONVERTERS.register_module() -def converter3(a, b) - return Converter1(a, b) -``` - -使用注册器管ç†æ¨¡å—的关键æ¥éª¤æ˜¯ï¼Œå°†å®žçŽ°çš„模å—注册到注册表 `CONVERTERS` ä¸ã€‚通过 `@CONVERTERS.register_module()` 装饰所实现的模å—,å—ç¬¦ä¸²å’Œç±»æˆ–å‡½æ•°ä¹‹é—´çš„æ˜ å°„å°±å¯ä»¥ç”± `CONVERTERS` 构建和维护,我们也å¯ä»¥é€šè¿‡ `CONVERTERS.register_module(module=Converter1)` 实现åŒæ ·çš„功能。 +@ACTIVATION.register_module() +class Sigmoid(nn.Module): + def __init__(self): + super().__init__() -通过注册,我们就å¯ä»¥é€šè¿‡ `CONVERTERS` 建立å—ç¬¦ä¸²ä¸Žç±»æˆ–å‡½æ•°ä¹‹é—´çš„æ˜ å°„ï¼Œ + def forward(self, x): + print('call Sigmoid.forward') + return x -```python -'Converter1' -> <class 'Converter1'> -'Converter2' -> <class 'Converter2'> -'Converter3' -> <function 'Converter3'> -``` +@ACTIVATION.register_module() +class ReLU(nn.Module): + def __init__(self, inplace=False): + super().__init__() -```{note} -åªæœ‰æ¨¡å—所在的文件被导入时,注册机制æ‰ä¼šè¢«è§¦å‘,所以我们需è¦åœ¨æŸå¤„导入该文件或者使用 `custom_imports` å—段动æ€å¯¼å…¥è¯¥æ¨¡å—进而触å‘æ³¨å†Œæœºåˆ¶ï¼Œè¯¦æƒ…è§ [导入自定义 Python 模å—](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/config.html#python). -``` + def forward(self, x): + print('call ReLU.forward') + return x -模å—æˆåŠŸæ³¨å†ŒåŽï¼Œæˆ‘们å¯ä»¥é€šè¿‡é…置文件使用这个转æ¢å™¨ã€‚ +@ACTIVATION.register_module() +class Softmax(nn.Module): + def __init__(self): + super().__init__() -```python -# main.py -# 注æ„,converter_cfg å¯ä»¥é€šè¿‡è§£æžé…置文件得到 -converter_cfg = dict(type='Converter1', a=a_value, b=b_value) -converter = CONVERTERS.build(converter_cfg) -converter3_cfg = dict(type='converter3', a=a_value, b=b_value) -# returns the calling result -converter3 = CONVERTERS.build(converter3_cfg) + def forward(self, x): + print('call Softmax.forward') + return x ``` -如果我们想使用 `Converter2`,仅需修改é…置。 +使用注册器管ç†æ¨¡å—的关键æ¥éª¤æ˜¯ï¼Œå°†å®žçŽ°çš„模å—注册到注册表 `ACTIVATION` ä¸ã€‚通过 `@ACTIVATION.register_module()` 装饰所实现的模å—,å—ç¬¦ä¸²å’Œç±»æˆ–å‡½æ•°ä¹‹é—´çš„æ˜ å°„å°±å¯ä»¥ç”± `ACTIVATION` 构建和维护,我们也å¯ä»¥é€šè¿‡ `ACTIVATION.register_module(module=ReLU)` 实现åŒæ ·çš„功能。 -```python -converter_cfg = dict(type='Converter2', a=a_value, b=b_value, c=c_value) -converter = CONVERTERS.build(converter_cfg) -``` - -å‡å¦‚我们想在创建实例å‰æ£€æŸ¥è¾“å…¥å‚数的类型(或者任何其他æ“作),我们å¯ä»¥å®žçŽ°ä¸€ä¸ªæž„å»ºæ–¹æ³•å¹¶å°†å…¶ä¼ é€’ç»™æ³¨å†Œå™¨ä»Žè€Œå®žçŽ°è‡ªå®šä¹‰æž„å»ºæµç¨‹ã€‚ +通过注册,我们就å¯ä»¥é€šè¿‡ `ACTIVATION` 建立å—ç¬¦ä¸²ä¸Žç±»æˆ–å‡½æ•°ä¹‹é—´çš„æ˜ å°„ï¼Œ ```python -from mmengine import Registry - -# 创建一个构建方法 -def build_converter(cfg, registry, *args, **kwargs): - cfg_ = cfg.copy() - converter_type = cfg_.pop('type') - if converter_type not in registry: - raise KeyError(f'Unrecognized converter type {converter_type}') - else: - converter_cls = registry.get(converter_type) - - converter = converter_cls(*args, **kwargs, **cfg_) - return converter - -# 创建一个用于转æ¢å™¨çš„注册器,并将 `build_converter` ä¼ é€’ç»™ `build_func` å‚æ•° -CONVERTERS = Registry('converter', build_func=build_converter) +print(ACTIVATION.module_dict) +# { +# 'Sigmoid': __main__.Sigmoid, +# 'ReLU': __main__.ReLU, +# 'Softmax': __main__.Softmax +# } ``` ```{note} -在这个例åä¸ï¼Œæˆ‘们演示了如何使用å‚数:`build_func` 自定义构建类的实例的方法。 -该功能类似于默认的 `build_from_cfg` 方法。在大多数情况下,使用默认的方法就å¯ä»¥äº†ã€‚ +åªæœ‰æ¨¡å—所在的文件被导入时,注册机制æ‰ä¼šè¢«è§¦å‘,所以我们需è¦åœ¨æŸå¤„导入该文件或者使用 `custom_imports` å—段动æ€å¯¼å…¥è¯¥æ¨¡å—进而触å‘注册机制,详情è§[导入自定义 Python 模å—](config.md)。 ``` -## 使用注册器æ高代ç 的扩展性 - -使用注册器管ç†åŠŸèƒ½ç›¸ä¼¼çš„模å—å¯ä»¥ä¾¿åˆ©æ¨¡å—的自由组åˆä¸Žçµæ´»æ‹“展。下é¢é€šè¿‡ä¾‹å介ç»æ³¨å†Œå™¨çš„两个优点。 - -### 模å—çš„è‡ªç”±ç»„åˆ - -å‡è®¾ç”¨æˆ·å®žçŽ°äº†ä¸€ä¸ªæ¨¡å— `ConvBlock`,`ConvBlock` ä¸å®šä¹‰äº†ä¸€ä¸ªå·ç§¯å±‚和一个激活层。 +模å—æˆåŠŸæ³¨å†ŒåŽï¼Œæˆ‘们å¯ä»¥é€šè¿‡é…置文件使用这个激活模å—。 ```python -import torch.nn as nn - -class ConvBlock(nn.Module): - - def __init__(self): - self.conv = nn.Conv2d() - self.act = nn.ReLU() - - def forward(self, x): - x = self.conv(x) - x = self.act(x) - return x - -conv_blcok = ConvBlock() +import torch +input = torch.randn(2) + +act_cfg = dict(type='Sigmoid') +activation = ACTIVATION.build(act_cfg) +output = activation(input) +# call Sigmoid.forward +print(output) +# tensor([0.0159, 0.0815]) ``` -å¯ä»¥å‘现,æ¤æ—¶ ConvBlock åªæ”¯æŒ `nn.Conv2d` å’Œ `nn.ReLU` 的组åˆã€‚如果我们想è¦è®© `ConvBlock` æ›´åŠ é€šç”¨ï¼Œä¾‹å¦‚è®©å®ƒå¯ä»¥ä½¿ç”¨å…¶ä»–类型的激活层,在ä¸ä½¿ç”¨æ³¨å†Œå™¨çš„情况下,需è¦åšå¦‚下改动 +如果我们想使用 `ReLU`,仅需修改é…置。 ```python -import torch.nn as nn +act_cfg = dict(type='ReLU', inplace=True) +activation = ACTIVATION.build(act_cfg) +output = activation(input) +# call Sigmoid.forward +print(output) +# tensor([0.0159, 0.0815]) +``` -class ConvBlock(nn.Module): +如果我们希望在创建实例å‰æ£€æŸ¥è¾“å…¥å‚数的类型(或者任何其他æ“作),我们å¯ä»¥å®žçŽ°ä¸€ä¸ªæž„å»ºæ–¹æ³•å¹¶å°†å…¶ä¼ é€’ç»™æ³¨å†Œå™¨ä»Žè€Œå®žçŽ°è‡ªå®šä¹‰æž„å»ºæµç¨‹ã€‚ - def __init__(self, act_type): - self.conv = nn.Conv2d() - if act_type == 'relu': - self.act = nn.ReLU() - elif act_type == 'gelu': - self.act = nn.GELU() +创建一个构建方法, - def forward(self, x): - x = self.conv(x) - x = self.act(x) - return x +```python -conv_block = ConvBlock() +def build_activation(cfg, registry, *args, **kwargs): + cfg_ = cfg.copy() + act_type = cfg_.pop('type') + print(f'build activation: {act_type}') + act_cls = registry.get(act_type) + act = act_cls(*args, **kwargs, **cfg_) + return act ``` -å¯ä»¥å‘现,上述改动需è¦æžšä¸¾æ¨¡å—çš„å„ç§ç±»åž‹ï¼Œæ— 法çµæ´»åœ°ç»„åˆå„ç§æ¨¡å—。而如果使用注册器,该问题å¯ä»¥è½»æ¾è§£å†³ï¼Œç”¨æˆ·åªéœ€è¦åœ¨æž„建 ConvBlock 的时候设置ä¸åŒçš„ `conv_cfg` å’Œ `act_cfg` å³å¯è¾¾åˆ°ç›®çš„。 +并将 `build_activation` ä¼ é€’ç»™ `build_func` å‚æ•° ```python -import torch.nn as nn -from mmengine import MODELS +ACTIVATION = Registry('activation', build_func=build_activation, scope='mmengine') -# å°†å·ç§¯å’Œæ¿€æ´»æ¨¡å—注册到 MODELS -MODELS.register_module(module=nn.Conv2d) -MODELS.register_module(module=nn.ReLU) -MODELS.register_module(module=nn.GELU) - -class ConvBlock(nn.Module): - - def __init__(self, conv_cfg, act_cfg): - self.conv = MODELS.build(conv_cfg) - self.pool = MODELS.build(act_cfg) +@ACTIVATION.register_module() +class Tanh(nn.Module): + def __init__(self): + super().__init__() def forward(self, x): - x = self.conv(x) - x = self.act(x) + print('call Tanh.forward') return x -# 注æ„,conv_cfg å’Œ act_cfg å¯ä»¥é€šè¿‡è§£æžé…置文件得到 -conv_cfg = dict(type='Conv2d') -act_cfg = dict(type='GELU') -conv_block = ConvBlock(conv_cfg, act_cfg) +act_cfg = dict(type='Tanh') +activation = ACTIVATION.build(act_cfg) +output = activation(input) +# build activation: Tanh +# call Tanh.forward +print(output) +# tensor([0.0159, 0.0815]) ``` -### 模å—çš„çµæ´»æ‹“展 +```{note} +在这个例åä¸ï¼Œæˆ‘们演示了如何使用å‚æ•° `build_func` 自定义构建类的实例的方法。 +该功能类似于默认的 `build_from_cfg` 方法。在大多数情况下,使用默认的方法就å¯ä»¥äº†ã€‚ +``` -如果我们自定义了一个 `DeformConv2d` å·ç§¯æ¨¡å—,我们åªéœ€å°†è¯¥æ¨¡å—注册到 `MODELS`, +MMEngine 的注册器除了å¯ä»¥æ³¨å†Œç±»ï¼Œä¹Ÿå¯ä»¥æ³¨å†Œå‡½æ•°ã€‚ ```python -import torch.nn as nn -from mmengine import MODELS - -@MODELS.register_module() -class DeformConv2d(nn.Module): - pass -``` +FUNCTION = Registry('function', scope='mmengine') -å°±å¯ä»¥é€šè¿‡é…置使用该模å—。 +@FUNCTION.register_module() +def print_args(**kwargs): + print(kwargs) -```python -conv_cfg = dict(type='DeformConv2d') -act_cfg = dict(type='GELU') -conv_block = ConvBlock(conv_cfg, act_cfg) -conv = MODELS.build(cfg) +func_cfg = dict(type='print_args', a=1, b=2) +func_res = FUNCTION.build(func_cfg) ``` -å¯ä»¥çœ‹åˆ°ï¼Œæ·»åŠ 了 `DeformConv2d` 模å—并ä¸éœ€è¦å¯¹ `ConvBlock` åšä¿®æ”¹ã€‚ - -## 通过 Registry 实现模å—的跨库调用 +## 进阶用法 -MMEngine 的注册器支æŒè·¨é¡¹ç›®è°ƒç”¨ï¼Œå³å¯ä»¥åœ¨ä¸€ä¸ªé¡¹ç›®ä¸ä½¿ç”¨å¦ä¸€ä¸ªé¡¹ç›®çš„模å—。虽然跨项目调用也有其他方法的å¯ä»¥å®žçŽ°ï¼Œä½† MMEngine 注册器æ供了更为简便的方法。 +MMEngine 的注册器支æŒå±‚级注册,利用该功能å¯å®žçŽ°è·¨é¡¹ç›®è°ƒç”¨ï¼Œå³å¯ä»¥åœ¨ä¸€ä¸ªé¡¹ç›®ä¸ä½¿ç”¨å¦ä¸€ä¸ªé¡¹ç›®çš„模å—。虽然跨项目调用也有其他方法的å¯ä»¥å®žçŽ°ï¼Œä½† MMEngine 注册器æ供了更为简便的方法。 为了方便跨库调用,MMEngine æ供了 20 ä¸ªæ ¹æ³¨å†Œå™¨ï¼š - RUNNERS: Runner 的注册器 - RUNNER_CONSTRUCTORS: Runner çš„æž„é€ å™¨ -- LOOPS: 管ç†è®ç»ƒã€éªŒè¯ä»¥åŠæµ‹è¯•æµç¨‹ï¼Œå¦‚ `EpochBasedTrainRunner` -- HOOKS: é’©å,如 `CheckpointHook`, `ProfilerHook` +- LOOPS: 管ç†è®ç»ƒã€éªŒè¯ä»¥åŠæµ‹è¯•æµç¨‹ï¼Œå¦‚ `EpochBasedTrainLoop` +- HOOKS: é’©å,如 `CheckpointHook`, `ParamSchedulerHook` - DATASETS: æ•°æ®é›† -- DATA_SAMPLERS: `Dataloader` çš„ `sampler`ï¼Œç”¨äºŽé‡‡æ ·æ•°æ® +- DATA_SAMPLERS: `DataLoader` çš„ `Sampler`ï¼Œç”¨äºŽé‡‡æ ·æ•°æ® - TRANSFORMS: å„ç§æ•°æ®é¢„处ç†ï¼Œå¦‚ `Resize`, `Reshape` - MODELS: 模型的å„ç§æ¨¡å— - MODEL_WRAPPERS: 模型的包装器,如 `MMDistributedDataParallel`,用于对分布å¼æ•°æ®å¹¶è¡Œ - WEIGHT_INITIALIZERS: æƒé‡åˆå§‹åŒ–的工具 -- OPTIMIZERS: 注册了 PyTorch ä¸æ‰€æœ‰çš„ `optimizer` 以åŠè‡ªå®šä¹‰çš„ `optimizer` +- OPTIMIZERS: 注册了 PyTorch ä¸æ‰€æœ‰çš„ `Optimizer` 以åŠè‡ªå®šä¹‰çš„ `Optimizer` - OPTIM_WRAPPER: 对 Optimizer 相关æ“作的å°è£…,如 `OptimWrapper`,`AmpOptimWrapper` - OPTIM_WRAPPER_CONSTRUCTORS: optimizer wrapper çš„æž„é€ å™¨ - PARAM_SCHEDULERS: å„ç§å‚数调度器,如 `MultiStepLR` @@ -247,146 +189,119 @@ MMEngine 的注册器支æŒè·¨é¡¹ç›®è°ƒç”¨ï¼Œå³å¯ä»¥åœ¨ä¸€ä¸ªé¡¹ç›®ä¸ä½¿ç”¨ - VISBACKENDS: å˜å‚¨è®ç»ƒæ—¥å¿—çš„åŽç«¯ï¼Œå¦‚ `LocalVisBackend`, `TensorboardVisBackend` - LOG_PROCESSORS: 控制日志的统计窗å£å’Œç»Ÿè®¡æ–¹æ³•ï¼Œé»˜è®¤ä½¿ç”¨ `LogProcessor`,如有特殊需求å¯è‡ªå®šä¹‰ `LogProcessor` -下é¢æˆ‘们以 OpenMMLab å¼€æºé¡¹ç›®ä¸ºä¾‹ä»‹ç»å¦‚何跨项目调用模å—。 - ### è°ƒç”¨çˆ¶èŠ‚ç‚¹çš„æ¨¡å— -`MMEngine` ä¸å®šä¹‰äº†æ¨¡å— `Conv2d`, +`MMEngine` ä¸å®šä¹‰æ¨¡å— `RReLU`,并往 `MODELS` æ ¹æ³¨å†Œå™¨æ³¨å†Œã€‚ ```python +import torch.nn as nn from mmengine import Registry, MODELS -MODELS.register_module() -class Conv2d(nn.Module): - pass -``` - -`MMDetection` ä¸å®šä¹‰äº†æ¨¡å— `RetinaNet`, +@MODELS.register_module() +class RReLU(nn.Module): + def __init__(self, lower=0.125, upper=0.333, inplace=False): + super().__init__() -```python -from mmengine import Registry, MODELS as MMENGINE_MODELS -# parent å‚数表示当å‰èŠ‚点的父节点,通过 parent å‚数实现层级结构 -# scope å‚æ•°å¯ä»¥ç†è§£ä¸ºå½“å‰èŠ‚ç‚¹çš„æ ‡å¿—ã€‚å¦‚æžœä¸ä¼ 入该å‚数,则 scope 被推导为当å‰æ–‡ä»¶æ‰€åœ¨ -# 包的包å,这里为 mmdet -MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmdet') - -@MMDET_MODELS.register_module() -class RetinaNet(nn.Module): - pass + def forward(self, x): + print('call RReLU.forward') + return x ``` -下图是 `MMEngine`, `MMDetection` 两个项目的注册器层级结构。 - - - -我们å¯ä»¥åœ¨ `MMDetection` ä¸è°ƒç”¨ `MMEngine` ä¸çš„模å—。 +å‡è®¾æœ‰ä¸ªé¡¹ç›®å« `MMAlpha`,它也定义了 `MODELS`,并设置其父节点为 `MMEngine` çš„ `MODELS`ï¼Œè¿™æ ·å°±å»ºç«‹äº†å±‚çº§ç»“æž„ã€‚ ```python -from mmdet.models import MODELS -# 创建 RetinaNet 实例 -model = MODELS.build(cfg=dict(type='RetinaNet')) -# 也å¯ä»¥åŠ mmdet å‰ç¼€ -model = MODELS.build(cfg=dict(type='mmdet.RetinaNet')) -# 创建 Conv2d 实例 -model = MODELS.build(cfg=dict(type='mmengine.Conv2d')) -# 也å¯ä»¥ä¸åŠ mmengine å‰ç¼€ -model = MODELS.build(cfg=dict(type='Conv2d')) +from mmengine import Registry, MODELS as MMENGINE_MODELS +MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmalpha') ``` -如果ä¸åŠ å‰ç¼€ï¼Œ`build` 方法首先查找当å‰èŠ‚点是å¦å˜åœ¨è¯¥æ¨¡å—,如果å˜åœ¨åˆ™è¿”回该模å—,å¦åˆ™ä¼šç»§ç»å‘上查找父节点甚至祖先节点直到找到该模å—ï¼Œå› æ¤ï¼Œå¦‚果当å‰èŠ‚点和父节点å˜åœ¨åŒä¸€æ¨¡å—并且希望调用父节点的模å—,我们需è¦æŒ‡å®š `scope` å‰ç¼€ã€‚需è¦æ³¨æ„的是,å‘上查找父节点甚至祖先节点的**å‰æ是父节点或者祖先节点的模å—已通过æŸç§æ–¹å¼è¢«å¯¼å…¥è¿›è€Œå®Œæˆæ³¨å†Œ**。例如,在上é¢è¿™ä¸ªç¤ºä¾‹ä¸ï¼Œä¹‹æ‰€ä»¥æ²¡æœ‰æ˜¾ç¤ºå¯¼å…¥çˆ¶èŠ‚点 `mmengine` ä¸çš„ `MODELS`ï¼Œæ˜¯å› ä¸ºé€šè¿‡ `from mmdet.models import MODELS` é—´æŽ¥è§¦å‘ `mmengine.MODELS` 完æˆæ¨¡å—的注册。 +下图是 `MMEngine` å’Œ `MMAlpha` 的注册器层级结构。 -上é¢å±•ç¤ºäº†å¦‚何使用å节点注册器构建模å—,但有时候我们希望ä¸å¡«åŠ å‰ç¼€ä¹Ÿèƒ½åœ¨çˆ¶èŠ‚点注册器ä¸æž„建å节点的模å—,目的是æ供通用的代ç ,é¿å…下游算法库é‡å¤é€ è½®å,该如何实现呢? +<div align="center"> + <img src="https://user-images.githubusercontent.com/58739961/185307159-26dc5771-df77-4d03-9203-9c4c3197befa.png"/> +</div> -å‡è®¾ MMEngine ä¸æœ‰ä¸€ä¸ª `build_model` 函数,该方法用于构建模型。 +å¯ä»¥è°ƒç”¨ [count_registered_modules](mmengine.registry.count_registered_modules) 函数打å°å·²æ³¨å†Œåˆ° MMEngine 的模å—以åŠå±‚级结构。 ```python -from mmengine.registry import MODELS - -def build_model(cfg): - model = MODELS.build(cfg) +from mmengine.registry import count_registered_modules +count_registered_modules() ``` -如果我们希望在 MMDetection ä¸è°ƒç”¨è¯¥å‡½æ•°æž„建 MMDetection 注册的模å—,那么我们需è¦å…ˆèŽ·å–一个 scope_name 为 'mmdet' çš„ [DefaultScope](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.registry.DefaultScope) 实例,该实例全局唯一。 +在 `MMAlpha` ä¸å®šä¹‰æ¨¡å— `LogSoftmax`,并往 `MMAlpha` çš„ `MODELS` 注册。 ```python -from mmengine import build_model -import mmdet.models # 通过 import çš„æ–¹å¼å°† mmdet ä¸çš„模å—导入注册器进而完æˆæ³¨å†Œ +@MODELS.register_module() +class LogSoftmax(nn.Module): + def __init__(self, dim=None): + super().__init__() -default_scope = DefaultScope.get_instance('my_experiment', scope_name='mmdet') -model = build_model(cfg=dict(type='RetinaNet')) + def forward(self, x): + print('call LogSoftmax.forward') + return x ``` -èŽ·å– `DefaultScope` 实例的目的是使 Registry çš„ build 方法会将 DefaultScope å称(mmdet)注册器节点作为注册器的起点,æ‰èƒ½åœ¨é…ç½®ä¸ä¸å¡«åŠ mmdet å‰ç¼€çš„情况下在 MMDetection 的注册器节点ä¸æ‰¾åˆ° RetinaNet 模å—,如若ä¸ç„¶ï¼Œç¨‹åºä¼šæŠ¥æ‰¾ä¸åˆ° RetinaNet 错误。 - -### è°ƒç”¨å…„å¼ŸèŠ‚ç‚¹çš„æ¨¡å— - -除了å¯ä»¥è°ƒç”¨çˆ¶èŠ‚点的模å—,也å¯ä»¥è°ƒç”¨å…„弟节点的模å—。 - -`MMClassification` ä¸å®šä¹‰äº†æ¨¡å— `ResNet`, +在 `MMAlpha` ä¸ä½¿ç”¨é…置调用 `LogSoftmax` ```python -from mmengine.registry import Registry, MODELS -MODELS = Registry('model', parent=MMENGINE_MODELS) - -@MODELS.register_module() -class ResNet(nn.Module): - pass +model = MODELS.build(cfg=dict(type='LogSoftmax')) ``` -下图是 `MMEngine`, `MMDetection`, `MMClassification` 三个项目的注册器层级结构。 - - - -我们å¯ä»¥åœ¨ `MMDetection` ä¸è°ƒç”¨ `MMClassification` 定义的模å—, +也å¯ä»¥åœ¨ `MMAlpha` ä¸è°ƒç”¨çˆ¶èŠ‚点 `MMEngine` 的模å—。 ```python -from mmdet.models import MODELS -model = MODELS.build(cfg=dict(type='mmcls.ResNet')) +model = MODELS.build(cfg=dict(type='RReLU', lower=0.2)) +# 也å¯ä»¥åŠ scope +model = MODELS.build(cfg=dict(type='mmengine.RReLU')) ``` -也å¯ä»¥åœ¨ `MMClassification` ä¸è°ƒç”¨ `MMDetection` 定义的模å—。 +如果ä¸åŠ å‰ç¼€ï¼Œ`build` 方法首先查找当å‰èŠ‚点是å¦å˜åœ¨è¯¥æ¨¡å—,如果å˜åœ¨åˆ™è¿”回该模å—,å¦åˆ™ä¼šç»§ç»å‘上查找父节点甚至祖先节点直到找到该模å—ï¼Œå› æ¤ï¼Œå¦‚果当å‰èŠ‚点和父节点å˜åœ¨åŒä¸€æ¨¡å—并且希望调用父节点的模å—,我们需è¦æŒ‡å®š `scope` å‰ç¼€ã€‚ ```python -from mmcls.models import MODELS -model = MODELS.build(cfg=dict(type='mmdet.RetinaNet')) +import torch +input = torch.randn(2) +output = model(input) +# call RReLU.forward +print(output) +# tensor([-1.5774, -0.5850]) ``` -调用éžæœ¬èŠ‚点或父节点的模å—需è¦åœ¨ `type` ä¸æŒ‡å®š `scope` å‰ç¼€ã€‚ - -注册器除了支æŒä¸¤å±‚结构,三层甚至更多层结构也是支æŒçš„。 +### è°ƒç”¨å…„å¼ŸèŠ‚ç‚¹çš„æ¨¡å— -å‡è®¾æˆ‘们新建了一个项目 `DetPlus`,它的 `MODELS` 注册器继承自 `MMDetection` çš„ `MODELS`,并且它会用到 `MMClassification` ä¸çš„ `ResNet` 模å—。 +除了å¯ä»¥è°ƒç”¨çˆ¶èŠ‚点的模å—,也å¯ä»¥è°ƒç”¨å…„弟节点的模å—。 -`DetPlus` ä¸å®šä¹‰äº†æ¨¡å— `MetaNet`, +å‡è®¾æœ‰å¦ä¸€ä¸ªé¡¹ç›®å« `MMBeta`,它和 `MMAlpha` ä¸€æ ·ï¼Œå®šä¹‰äº† `MODELS` 以åŠè®¾ç½®å…¶çˆ¶èŠ‚点为 `MMEngine` çš„ `MODELS`。 ```python -from mmengine.registry import Registry -from mmdet.model import MODELS as MMDET_MODELS -MODELS = Registry('model', parent=MMDET_MODELS, scope='det_plus') - -@MODELS.register_module() -class MetaNet(nn.Module): - pass +from mmengine import Registry, MODELS as MMENGINE_MODELS +MODELS = Registry('model', parent=MMENGINE_MODELS, scope='mmbeta') ``` -下图是 `MMEngine`, `MMDetection`, `MMClassification` ä»¥åŠ `DetPlus` 四个项目的注册器层级结构。 +下图是 MMEngine,MMAlpha å’Œ MMBeta 的注册器层级结构。 - +<div align="center"> + <img src="https://user-images.githubusercontent.com/58739961/185307738-9ddbce2d-f8b5-40c4-bf8f-603830ccc0dc.png"/> +</div> -我们å¯ä»¥åœ¨ `DetPlus` ä¸è°ƒç”¨ `MMDetection` 或者 `MMClassification` ä¸çš„模å—, +在 `MMBeta` ä¸è°ƒç”¨å…„弟节点 `MMAlpha` 的模å—, ```python -from detplus.model import MODELS -# å¯ä»¥ä¸æä¾› mmdet å‰ç¼€ï¼Œå¦‚果在 detplus 找ä¸åˆ°åˆ™ä¼šå‘上在 mmdet ä¸æŸ¥æ‰¾ -model = MODELS.build(cfg=dict(type='mmdet.RetinaNet')) -# 调用兄弟节点的模å—需æä¾› mmcls å‰ç¼€ï¼Œä½†ä¹Ÿå¯ä»¥è®¾ç½® default_scope å‚æ•° -model = MODELS.build(cfg=dict(type='mmcls.ResNet')) +model = MODELS.build(cfg=dict(type='mmalpha.LogSoftmax')) +output = model(input) +# call LogSoftmax.forward +print(output) +# tensor([-1.5774, -0.5850]) ``` -也å¯ä»¥åœ¨ `MMClassification` ä¸è°ƒç”¨ `DetPlus` 的模å—。 +调用兄弟节点的模å—需è¦åœ¨ `type` ä¸æŒ‡å®š `scope` å‰ç¼€ï¼Œæ‰€ä»¥ä¸Šé¢çš„é…置需è¦åŠ å‰ç¼€ `mmalpha`。 + +如果需è¦è°ƒç”¨å…„弟节点的数个模å—,æ¯ä¸ªæ¨¡å—éƒ½åŠ å‰ç¼€ï¼Œè¿™éœ€è¦åšå¤§é‡çš„修改。于是 `MMEngine` 引入了 [DefaultScope](mmengine.registry.DefaultScope),`Registry` 借助它å¯ä»¥å¾ˆæ–¹ä¾¿åœ°æ”¯æŒä¸´æ—¶åˆ‡æ¢å½“å‰èŠ‚点为指定的节点。 + +如果需è¦ä¸´æ—¶åˆ‡æ¢å½“å‰èŠ‚点为指定的节点,åªéœ€åœ¨ `cfg` 设置 `_scope_` 为指定节点的作用域。 ```python -from mmcls.models import MODELS -# 需è¦æ³¨æ„å‰ç¼€çš„顺åºï¼Œ'detplus.mmdet.ResNet' 是ä¸æ£ç¡®çš„ -model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet')) +model = MODELS.build(cfg=dict(type='LogSoftmax', _scope_='mmalpha')) +output = model(input) +# call LogSoftmax.forward +print(output) +# tensor([-1.5774, -0.5850]) ``` diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 286e07d021dceb2d646e1859c30106ee8fef39a9..0e0af070fab7ed2a67275b92297abe30d4dca0fa 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -190,45 +190,45 @@ class Registry: scope (str): The target scope. Examples: - >>> from mmengine.registry import Registry, DefaultScope, MODELS - >>> import time - >>> # External Registry - >>> MMDET_MODELS = Registry('mmdet_model', scope='mmdet', - >>> parent=MODELS) - >>> MMCLS_MODELS = Registry('mmcls_model', scope='mmcls', - >>> parent=MODELS) - >>> # Local Registry - >>> CUSTOM_MODELS = Registry('custom_model', scope='custom', - >>> parent=MODELS) - >>> - >>> # Initiate DefaultScope - >>> DefaultScope.get_instance(f'scope_{time.time()}', - >>> scope_name='custom') - >>> # Check default scope - >>> DefaultScope.get_current_instance().scope_name - custom - >>> # Switch to mmcls scope and get `MMCLS_MODELS` registry. - >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as registry: # noqa: E501 - >>> DefaultScope.get_current_instance().scope_name - mmcls - >>> registry.scope - mmcls - >>> # Nested switch scope - >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmdet') as mmdet_registry: # noqa: E501 - >>> DefaultScope.get_current_instance().scope_name - mmdet - >>> mmdet_registry.scope - mmdet - >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as mmcls_registry: # noqa: E501 - >>> DefaultScope.get_current_instance().scope_name - mmcls - >>> mmcls_registry.scope - mmcls - >>> - >>> # Check switch back to original scope. - >>> DefaultScope.get_current_instance().scope_name - custom - """ + >>> from mmengine.registry import Registry, DefaultScope, MODELS + >>> import time + >>> # External Registry + >>> MMDET_MODELS = Registry('mmdet_model', scope='mmdet', + >>> parent=MODELS) + >>> MMCLS_MODELS = Registry('mmcls_model', scope='mmcls', + >>> parent=MODELS) + >>> # Local Registry + >>> CUSTOM_MODELS = Registry('custom_model', scope='custom', + >>> parent=MODELS) + >>> + >>> # Initiate DefaultScope + >>> DefaultScope.get_instance(f'scope_{time.time()}', + >>> scope_name='custom') + >>> # Check default scope + >>> DefaultScope.get_current_instance().scope_name + custom + >>> # Switch to mmcls scope and get `MMCLS_MODELS` registry. + >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as registry: + >>> DefaultScope.get_current_instance().scope_name + mmcls + >>> registry.scope + mmcls + >>> # Nested switch scope + >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmdet') as mmdet_registry: + >>> DefaultScope.get_current_instance().scope_name + mmdet + >>> mmdet_registry.scope + mmdet + >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as mmcls_registry: + >>> DefaultScope.get_current_instance().scope_name + mmcls + >>> mmcls_registry.scope + mmcls + >>> + >>> # Check switch back to original scope. + >>> DefaultScope.get_current_instance().scope_name + custom + """ # noqa: E501 from ..logging import print_log # Switch to the given scope temporarily. If the corresponding registry diff --git a/mmengine/registry/utils.py b/mmengine/registry/utils.py index 3ace5f5f7d396a4a1a72a4eb8d701f179cd3c58b..5184e2ebedbba5dd15060da61cffea09709e52b7 100644 --- a/mmengine/registry/utils.py +++ b/mmengine/registry/utils.py @@ -54,10 +54,20 @@ def count_registered_modules(save_path: Optional[str] = None, Args: save_path (str, optional): Path to save the json file. - verbose (bool): Whether to print log. Default: True + verbose (bool): Whether to print log. Defaults to True. + Returns: dict: Statistic results of all registered modules. """ + # import modules to trigger registering + import mmengine.dataset + import mmengine.evaluator + import mmengine.hooks + import mmengine.model + import mmengine.optim + import mmengine.runner + import mmengine.visualization # noqa: F401 + registries_info = {} # traverse all registries in MMEngine for item in dir(root): diff --git a/tests/test_registry/test_registry_utils.py b/tests/test_registry/test_registry_utils.py index ce903f0b9045b9ad9d65aa6b882cf887a7cd2a33..35670435d7574982b0101df78e8f2ee7a7e511f9 100644 --- a/tests/test_registry/test_registry_utils.py +++ b/tests/test_registry/test_registry_utils.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp from tempfile import TemporaryDirectory -from unittest import TestCase +from unittest import TestCase, skipIf from mmengine.registry import (Registry, count_registered_modules, root, traverse_registry_tree) +from mmengine.utils import is_installed class TestUtils(TestCase): @@ -42,6 +43,7 @@ class TestUtils(TestCase): # result from any node should be the same self.assertEqual(result, result_leaf) + @skipIf(not is_installed('torch'), 'tests requires torch') def test_count_all_registered_modules(self): temp_dir = TemporaryDirectory() results = count_registered_modules(temp_dir.name, verbose=True)