diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md
index ced73f25162fa6984e4dd14e39f6b29686c65511..1e0ff3a2f5419f0aac1672167fa26fee073e22ec 100644
--- a/docs/zh_cn/tutorials/registry.md
+++ b/docs/zh_cn/tutorials/registry.md
@@ -262,7 +262,7 @@ class RetinaNet(nn.Module):

-We can call modules in `MMEngine` from `MMDetection`.
+We can call the modules of `MMEngine` from `MMDetection`.

 ```python
 from mmdet.models import MODELS
@@ -278,6 +278,29 @@ model = MODELS.build(cfg=dict(type='Conv2d'))

 Without a prefix, the `build` method first checks whether the module exists in the current node. If it does, the module is returned; otherwise the search continues upward through the parent node and even ancestor nodes until the module is found. Therefore, if the current node and its parent node both register a module with the same name and we want the parent's module, we must specify the `scope` prefix. Note that searching upward through parent or ancestor nodes **requires that the modules of the parent or ancestor nodes have already been imported in some way and thus registered**. For example, in the example above, `MODELS` of the parent node `mmengine` is not explicitly imported because `from mmdet.models import MODELS` indirectly triggers the registration of `mmengine.MODELS`.

+The above shows how to build modules with a child registry. Sometimes, however, we want to build a child node's modules through the parent registry without adding a prefix, so that generic code can be shared and downstream libraries do not have to reinvent the wheel. How can this be done?
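The scope-prefixed lookup and the prefix-free fallback to parent nodes described above can be sketched with a toy registry. This is a simplified, hypothetical illustration, not MMEngine's actual implementation; `ToyRegistry` and its methods are invented for this example.

```python
# Hypothetical sketch of hierarchical registry lookup (not mmengine code).
# A child registry first searches its own modules; on a miss it falls back
# to the parent, while an explicit 'scope.' prefix targets a node directly.
class ToyRegistry:
    def __init__(self, scope, parent=None):
        self.scope = scope
        self.parent = parent
        self.children = {}
        self.module_dict = {}
        if parent is not None:
            parent.children[scope] = self

    def register(self, cls):
        self.module_dict[cls.__name__] = cls
        return cls

    def get(self, key):
        if '.' in key:  # explicit scope prefix, e.g. 'mmdet.RetinaNet'
            scope, name = key.split('.', 1)
            root = self
            while root.parent is not None:  # climb to the root node
                root = root.parent
            node = root if root.scope == scope else root.children.get(scope)
            return node.module_dict.get(name) if node else None
        node = self
        while node is not None:  # search the current node, then ancestors
            if key in node.module_dict:
                return node.module_dict[key]
            node = node.parent
        return None

engine_models = ToyRegistry('mmengine')
det_models = ToyRegistry('mmdet', parent=engine_models)

@engine_models.register
class Conv2d: ...

@det_models.register
class RetinaNet: ...

print(det_models.get('Conv2d').__name__)    # 'Conv2d', found in the parent
print(det_models.get('RetinaNet').__name__)  # 'RetinaNet', found locally
```

Note that, as in the real registry, the prefix-free fallback only works because the parent registry object exists (here it was created explicitly; in MMEngine it must have been imported somewhere).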
+
+Suppose MMEngine has a `build_model` function used to build models.
+
+```python
+from mmengine.registry import MODELS
+
+def build_model(cfg):
+    model = MODELS.build(cfg)
+    return model
+```
+
+If we want to call this function in MMDetection to build a module registered in MMDetection, we first need to get a [DefaultScope](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.registry.DefaultScope) instance whose scope_name is 'mmdet'. The instance is globally unique.
+
+```python
+from mmengine import build_model
+from mmengine.registry import DefaultScope
+import mmdet.models  # importing mmdet registers its modules into the registry
+
+default_scope = DefaultScope.get_instance('my_experiment', scope_name='mmdet')
+model = build_model(cfg=dict(type='RetinaNet'))
+```
+
+Getting a `DefaultScope` instance makes the `build` method of Registry take the registry node named by the DefaultScope (mmdet) as the starting point of the search, so that the RetinaNet module can be found in MMDetection's registry node without adding the mmdet prefix in the config; otherwise the program raises an error that RetinaNet cannot be found.
+
 ### Calling modules of sibling nodes

 Besides modules of the parent node, modules of sibling nodes can also be called.
@@ -311,16 +334,7 @@
 from mmcls.models import MODELS

 model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
 ```

-Calling a module of another node requires specifying the `scope` prefix in `type`. If you do not want to specify it, you can create a global variable `default_scope` with `scope_name` set to 'mmdet'; `Registry` will take the `registry` corresponding to `scope_name` as the current `Registry` and call its `build` method.
-
-```python
-from mmengine.registry import DefaultScope, MODELS
-
-# call RetinaNet registered in mmdet
-default_scope = DefaultScope.get_instance(
-    'my_experiment', scope_name='mmdet')
-model = MODELS.build(cfg=dict(type='RetinaNet'))
-```
+Calling a module of a non-current node or of the parent node requires specifying the `scope` prefix in `type`.

 Registries support not only two-level structures but also three or even more levels.
@@ -358,10 +372,4 @@
 model = MODELS.build(cfg=dict(type='mmcls.ResNet'))

 from mmcls.models import MODELS
 # note the order of the prefixes: 'detplus.mmdet.ResNet' is incorrect
 model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
-
-# To build models from detplus by default, you can set default_scope
-from mmengine.registry import DefaultScope
-default_scope = DefaultScope.get_instance(
-    'my_experiment', scope_name='detplus')
-model = MODELS.build(cfg=dict(type='MetaNet', default_scope='detplus'))
 ```
diff --git a/docs/zh_cn/tutorials/visualization.md b/docs/zh_cn/tutorials/visualization.md
new file mode 100644
index 0000000000000000000000000000000000000000..4aa7d6ecba24628995f448c36f6149a77a6107c8
--- /dev/null
+++ b/docs/zh_cn/tutorials/visualization.md
@@ -0,0 +1,300 @@
+# Visualization
+
+## Overview
+
+Visualization gives an intuitive explanation of the training and testing process of deep learning models. In the OpenMMLab libraries, we expect the visualization module to meet the following requirements:
+
+- provide rich out-of-the-box features that cover most computer vision visualization tasks
+- high extensibility: visualization needs vary, and custom requirements should be achievable through simple extension
+- the ability to visualize at any point of the training and testing pipeline
+- a unified visualization interface across OpenMMLab libraries, which eases understanding and maintenance
+
+Based on these requirements, OpenMMLab 2.0 introduces the visualizer object Visualizer and several visualization storage backends (VisBackend) such as `LocalVisBackend`, `WandbVisBackend` and `TensorboardVisBackend`. Visualization here covers not only image data but also configs, scalars, model graphs and other data.
+
+- For convenience, the interfaces provided by Visualizer implement both drawing and storage. The storage backends (VisBackend) are internal attributes of Visualizer and are called by it when needed to save data to different backends
+- Since drawn results are often saved to several backends, a Visualizer can be configured with multiple VisBackends; when the user calls a storage interface of Visualizer, it iterates over the VisBackends and calls their storage interfaces
+
+The UML diagram of the two is as follows
+
+<div align="center">
+  <img src="https://user-images.githubusercontent.com/17425982/163327736-f7cb3b16-ef07-46bc-982a-3cc7495e6c82.png" >
+</div>
+
+## Visualizer
+
+### Interfaces
+
+The visualizer object Visualizer provides all the public interfaces, which can be divided into 3 categories, as follows
+
+**(1) Drawing interfaces**
+
+- [draw_bboxes](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_bboxes) draws one or more bounding boxes
+- [draw_points](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_points) draws one or more points
+- [draw_texts](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_texts) draws one or more text boxes
+- [draw_lines](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_lines) draws one or more line segments
+- [draw_circles](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_circles) draws one or more circles
+- [draw_polygons](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_polygons) draws one or more polygons
+- [draw_binary_masks](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_binary_masks) draws one or more binary masks
+- [draw_featmap](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_featmap) draws feature maps, a static method
+
+All of the interfaces above except `draw_featmap` can be chained, because that method may change the image size after it is called. To avoid confusing users, `draw_featmap` is defined as a static method.
+
+When users want to draw bounding boxes first and then draw texts and line segments on top of them, this can be done through chained calls:
+
+```python
+visualizer.set_image(image)
+visualizer.draw_bboxes(...).draw_texts(...).draw_lines(...)
+visualizer.show()  # show the drawn results
+```
+
+Feature map visualization is a common need; users can visualize feature maps by calling `draw_featmap`, whose parameters are defined as:
+
+```python
+@staticmethod
+def draw_featmap(featmap: torch.Tensor,  # input must be in CHW format
+                 overlaid_image: Optional[np.ndarray] = None,  # if an image is also given, the feature map is overlaid on it
+                 channel_reduction: Optional[str] = 'squeeze_mean',  # strategy for squeezing multiple channels into one
+                 topk: int = 10,  # optionally show the topk feature maps with the highest activation
+                 arrangement: Tuple[int, int] = (5, 2),  # layout when multiple channels are expanded into multiple images
+                 resize_shape: Optional[tuple] = None,  # resize_shape can be specified to rescale the feature map
+                 alpha: float = 0.5) -> np.ndarray:  # blending ratio between the image and the feature map
+```
+
+Feature map visualization offers many options; batch input is currently not supported. Its behavior can be summarized as follows
+
+- The input tensor usually has multiple channels; the channel_reduction parameter can squeeze them into a single channel, which is then overlaid on the image
+  - `squeeze_mean` squeezes the C dimension into one channel with the mean function; the output shape becomes (1, H, W)
+  - `select_max` first sums the input over the spatial dimensions so the shape becomes (C, ), then selects the channel with the largest value
+  - `None` means no squeezing; in this case the topk feature maps with the highest activation can be selected through the topk parameter
+
+- When the channel_reduction parameter is None, the topk parameter takes effect: the topk channels ranked by activation are selected and overlaid on the image, and the display layout is specified through the arrangement parameter
+  - If topk is not -1, the topk channels ranked by activation are shown
+  - If topk = -1, the channel number C must be 1 or 3, meaning the input is an image; otherwise an error prompts the user to set `channel_reduction` to squeeze the channels.
+
+- Since input feature maps are usually very small, the function accepts a `resize_shape` parameter to upsample the feature map before visualization.
+
+**(2) Storage interfaces**
+
+- [add_config](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_config) writes a config to a specific storage backend
+- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_graph) writes a model graph to a specific storage backend
+- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_image) writes an image to a specific storage backend
+- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalar) writes a scalar to a specific storage backend
+- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalars) writes multiple scalars to a specific storage backend at once
+- [add_datasample](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_datasample) the abstract interface for downstream libraries to draw datasample data
+
+Interfaces whose names start with the add prefix are storage interfaces. datasample is the unified abstract data interface designed for all downstream libraries in the OpenMMLab 2.0 architecture, and the `add_datasample` interface handles this data format directly: visualizing prediction results, Dataset or DataLoader outputs, intermediate prediction results and so on can all be done by calling the `add_datasample` interface overridden by the downstream library.
+
+All downstream libraries must inherit from Visualizer and implement the `add_datasample` interface. Taking MMDetection as an example, it should, through this interface, implement the visualization of all predefined detection tasks, e.g. drawing and storing the results of object detection, instance segmentation and panoptic segmentation.
+
+**(3) Other functional interfaces**
+
+- [set_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.set_image) sets the original image data; the default input image format is RGB
+- [get_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_image) gets the drawn image data in Numpy format; the default output format is RGB
+- [show](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.show) shows the visualization
+- [get_backend](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_backend) gets a specific storage backend by name
+- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.close) closes all opened resources, including the VisBackends
+
+### Usage examples
+
+**(1) Get the visualizer anywhere**
+
+To make sure the visualizer object Visualizer can be accessed anywhere, it inherits from the `ManagerMixin` class and becomes a globally unique object. When initializing `Visualizer`, users must call `visualizer.get_instance()` so that the instance is globally unique. Once instantiated, the visualizer can later be obtained anywhere in the code via `Visualizer.get_current_instance()`.
+
+Taking MMDetection as an example, suppose the `DetLocalVisualizer` class inherits from `Visualizer` and implements the `add_datasample` interface. The config is written as:
+
+```python
+vis_backends = [dict(type='LocalVisBackend')]
+visualizer = dict(
+    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+```
+```python
+# get_instance() is called internally to create the globally unique instance
+VISUALIZERS.build(cfg.visualizer)
+```
+
+After instantiation with the code above, `get_current_instance` can be called anywhere to get the visualizer
+
+```python
+# get the visualizer anywhere in the code
+visualizer = Visualizer.get_current_instance()
+```
+
+If users directly use the Runner of MMEngine or of a downstream library, no extra instantiation is needed, because the globally unique visualizer is created automatically in Runner's initialization function.
+
+**(2) Write data to a specific backend**
+
+After getting the visualizer, the `add_xxx` interfaces can be called to write all kinds of data to specific backends
+
+```python
+# draw the datasample and save it to the local storage backend
+visualizer.add_datasample('demo_image', image, gt_sample, pred_sample, step=1)
+# show it in a local window directly, without storing
+visualizer.add_datasample('demo_image', image, gt_sample, pred_sample, show=True)
+
+# write an image
+visualizer.add_image('demo_image', image, step=1)
+
+# write model accuracy values
+visualizer.add_scalar('mAP', 0.9, step=1)
+visualizer.add_scalars({'loss': 1.2, 'acc': 0.8}, step=1)
+
+# write a config
+visualizer.add_config(cfg)
+
+# write a model graph
+visualizer.add_graph(model, data_batch)
+```
+
+**(3) Feature map visualization**
+
+Squeeze or select feature maps through the `channel_reduction` parameter and show them in a local window
+
+```python
+featmap = ...  # tensor of CHW shape
+
+# squeeze
+feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean')
+visualizer.show(feat_img)
+
+# select the channel with the highest activation
+feat_img = visualizer.draw_featmap(featmap, channel_reduction='select_max')
+visualizer.show(feat_img)
+```
+
+Overlay on an image
+
+```python
+featmap = ...  # tensor of CHW shape
+img = ...  # if featmap and img differ in spatial size, featmap is interpolated internally
+
+# squeeze
+feat_img = visualizer.draw_featmap(featmap, img, channel_reduction='squeeze_mean')
+visualizer.show(feat_img)
+
+# select the channel with the highest activation
+feat_img = visualizer.draw_featmap(featmap, img, channel_reduction='select_max')
+visualizer.show(feat_img)
+```
+
+Select a given number of channels through the `topk` parameter and show them in a local window
+
+```python
+featmap = ...  # tensor of CHW shape
+
+# topk, shown in a 2-row, 5-column layout
+feat_img = visualizer.draw_featmap(featmap, channel_reduction=None, topk=10, arrangement=(2, 5))
+visualizer.show(feat_img)
+
+# topk, shown in a 5-row, 2-column layout
+feat_img = visualizer.draw_featmap(featmap, channel_reduction=None, topk=10, arrangement=(5, 2))
+visualizer.show(feat_img)
+```
+
+Rescale the displayed feature map through `resize_shape`
+
+```python
+featmap = ...  # tensor of CHW shape
+
+# squeeze
+feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean', resize_shape=(224, 224))
+visualizer.show(feat_img)
+```
+
+Store a feature map to a visualization backend
+
+```python
+featmap = ...  # tensor of CHW shape
+
+# squeeze
+feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean', resize_shape=(224, 224))
+# store
+visualizer.add_image('feat_image', feat_img)
+```
+
+**(4) Remote window display**
+
+Users can save data with Wandb, Tensorboard or a custom backend capable of remote display, and then view it in the browser. Taking Wandb as an example, a typical config is:
+
+```python
+vis_backends = [dict(type='WandbVisBackend')]
+visualizer = dict(
+    type='DetWandbVisualizer', vis_backends=vis_backends, name='visualizer')
+```
+
+The usage is exactly the same as above. Note in particular that since the data drawn by Wandb is incompatible with the `LocalVisBackend` backend, when `vis_backends` contains multiple visualization storage backends, only `WandbVisBackend` takes effect.
+
+## Visualization storage backends (VisBackend)
+
+After drawing, the drawn data can be stored to multiple visualization storage backends. To unify the calling interfaces, MMEngine provides the abstract base class `BaseVisBackend` and some common VisBackends such as `LocalVisBackend`, `WandbVisBackend` and `TensorboardVisBackend`.
+
+### Interfaces
+
+BaseVisBackend defines the interface specification for external calls; the main interfaces and properties are:
+
+- [add_config](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_config) writes a config to the specific storage backend
+- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_graph) writes a model graph to the specific backend
+- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_image) writes an image to the specific backend
+- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_scalar) writes a scalar to the specific backend
+- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_scalars) writes multiple scalars to the specific backend at once
+- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.close) closes the opened resources
+- [experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.experiment) the backend writer object, e.g. the Wandb object or the Tensorboard object
+
+`BaseVisBackend` defines the 5 common data-writing interfaces. Some write backends are very powerful, e.g. Wandb, which can also write tables, videos and so on; for such needs users can directly get the experiment object and call the native API of the backend object itself. `LocalVisBackend`, `WandbVisBackend`, `TensorboardVisBackend` and the like all inherit from `BaseVisBackend` and implement the corresponding storage according to their own characteristics.
+
+### Usage examples
+
+Normally users do not need to operate on VisBackend objects; only when the current visualization storage cannot meet their needs will users want to operate on a storage backend directly. Taking Wandb as an example, it provides very rich storage formats, e.g. interfaces for storing tables, weights, etc. So that all backends can share a unified interface, we do not provide such interfaces; instead, users can get the Wandb object directly for custom storage.
+
+```python
+vis_backends = [dict(type='WandbVisBackend')]
+visualizer = dict(
+    type='DetWandbVisualizer', vis_backends=vis_backends, name='visualizer')
+```
+
+```python
+# get_instance() is called internally to create the globally unique instance
+VISUALIZERS.build(cfg.visualizer)
+# get the visualizer anywhere in the code
+visualizer = Visualizer.get_current_instance()
+
+# extend the add features, e.g. draw a table with the Wandb object
+wandb = visualizer.get_backend('WandbVisBackend').experiment
+val_table = wandb.Table(data=my_data, columns=column_names)
+wandb.log({'my_val_table': val_table})
+```
+
+A visualizer object can be connected to any number of VisBackends. To make it easy to get any VisBackend, when the name parameter is not specified, a backend can be obtained by its class name
+
+```python
+vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
+visualizer = dict(
+    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+```
+
+```python
+# get_instance() is called internally to create the globally unique instance
+VISUALIZERS.build(cfg.visualizer)
+# get the visualizer anywhere in the code
+visualizer = Visualizer.get_current_instance()
+
+local_vis_backend = visualizer.get_backend('LocalVisBackend')
+wandb_vis_backend = visualizer.get_backend('WandbVisBackend')
+```
+
+When multiple VisBackends of the same class exist, users must specify unique name parameters and can then get a backend by its name string
+
+```python
+vis_backends = [dict(type='LocalVisBackend', name='local_vis_backend_1'), dict(type='LocalVisBackend', name='local_vis_backend_2')]
+visualizer = dict(
+    type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+```
+
+```python
+# get_instance() is called internally to create the globally unique instance
+VISUALIZERS.build(cfg.visualizer)
+# get the visualizer anywhere in the code
+visualizer = Visualizer.get_current_instance()
+
+local_vis_backend_1 = visualizer.get_backend('local_vis_backend_1')
+local_vis_backend_2 = visualizer.get_backend('local_vis_backend_2')
+```
diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py
index 49995334c6ea6fa78464bc543482a56e5db0e47c..5e96bcd871c00c441cf835fb984565515d8cf3d8 100644
--- a/mmengine/hooks/hook.py
+++ b/mmengine/hooks/hook.py
@@ -358,11 +358,11 @@ class Hook:
         """
         return (runner.epoch + 1) % n == 0 if n > 0 else False

-    def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
+    def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
         """Test whether current inner iteration can be evenly divided by n.

         Args:
-            inner_iter (int): Current inner_iter of the training, validation
+            batch_idx (int): Current batch index of the training, validation
                 or testing loop.
             n (int): Whether current inner iteration can be evenly divided
                 by n.
@@ -371,7 +371,7 @@ class Hook:
             bool: Whether current inner iteration can be evenly divided
             by n.
         """
-        return (inner_iter + 1) % n == 0 if n > 0 else False
+        return (batch_idx + 1) % n == 0 if n > 0 else False

     def every_n_iters(self, runner, n: int) -> bool:
         """Test whether current iteration can be evenly divided by n.
@@ -395,7 +395,6 @@ class Hook:
             dataloader (Dataloader): The dataloader of the training,
                 validation or testing process.
             batch_idx (int): The index of the current batch in the loop.
-
         Returns:
             bool: Whether reaches the end of current epoch or not.
""" @@ -418,10 +417,10 @@ class Hook: Args: runner (Runner): The runner of the training, validation or testing process. + mode (str): Current mode of runner. Defaults to 'train'. Returns: bool: Whether current iteration is the last iteration. - mode (str): Current mode of runner. Defaults to 'train'. """ if mode == 'train': return runner.iter + 1 == runner.train_loop.max_iters diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index 8791dc96d86696a42eb0ec4773ca166dcf33f2a9..56c9cee927842f67e63f0b276d3409a124f9fa32 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -18,11 +18,25 @@ class IterTimerHook(Hook): priority = 'NORMAL' + def __init__(self): + self.time_sec_tot = 0 + self.start_iter = 0 + + def before_run(self, runner) -> None: + """Synchronize the number of iterations with the runner. + + Args: + runner: The runner of the training, validation or testing + process. + """ + self.start_iter = runner.iter + def _before_epoch(self, runner, mode: str = 'train') -> None: - """Record time flag before start a epoch. + """Record timestamp before start an epoch. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the training validation and + testing process. mode (str): Current mode of runner. Defaults to 'train'. """ self.t = time.time() @@ -32,16 +46,18 @@ class IterTimerHook(Hook): batch_idx: int, data_batch: DATA_BATCH = None, mode: str = 'train') -> None: - """Logging time for loading data and update the time flag. + """Calculating time for loading data and updating "data_time" + ``HistoryBuffer`` of ``runner.message_hub``. Args: - runner (Runner): The runner of the training process. + runner (Runner): The runner of the training, validation and + testing process. batch_idx (int): The index of the current batch in the loop. data_batch (Sequence[dict], optional): Data from dataloader. Defaults to None. mode (str): Current mode of runner. 
                 Defaults to 'train'.
         """
-        # TODO: update for new logging system
+        # Update data loading time in `runner.message_hub`.
         runner.message_hub.update_scalar(f'{mode}/data_time',
                                          time.time() - self.t)
@@ -52,10 +68,12 @@ class IterTimerHook(Hook):
                    outputs:
                    Optional[Union[dict, Sequence[BaseDataElement]]] = None,
                    mode: str = 'train') -> None:
-        """Logging time for a iteration and update the time flag.
+        """Calculate the time for one iteration and update the "time"
+        ``HistoryBuffer`` of ``runner.message_hub``.

         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the training, validation or
+                testing process.
             batch_idx (int): The index of the current batch in the loop.
             data_batch (Sequence[dict], optional): Data from dataloader.
                 Defaults to None.
                 to None.
             mode (str): Current mode of runner. Defaults to 'train'.
         """
-        # TODO: update for new logging system
-
-        runner.message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
+        # Update iteration time in `runner.message_hub`.
+        message_hub = runner.message_hub
+        message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
         self.t = time.time()
+        window_size = runner.log_processor.window_size
+        # Calculate eta every `window_size` iterations. Since the test and val
+        # loops do not update runner.iter, use `every_n_inner_iters` to check
+        # the interval.
+        if self.every_n_inner_iters(batch_idx, window_size):
+            iter_time = message_hub.get_scalar(f'{mode}/time').mean(
+                window_size)
+            if mode == 'train':
+                self.time_sec_tot += iter_time * window_size
+                # Calculate average iteration time.
+                time_sec_avg = self.time_sec_tot / (
+                    runner.iter - self.start_iter + 1)
+                # Calculate eta.
+                eta_sec = time_sec_avg * (
+                    runner.train_loop.max_iters - runner.iter - 1)
+                runner.message_hub.update_info('eta', eta_sec)
+            else:
+                if mode == 'val':
+                    cur_dataloader = runner.val_loop.dataloader
+                else:
+                    cur_dataloader = runner.test_loop.dataloader
+
+                eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
+                runner.message_hub.update_info('eta', eta_sec)
diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index cd56624429d43eeaed52d4c78196ee3c52a6c9fa..65d5d8a4d523f206020733bb25bfad7c1e68e26a 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -1,14 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import datetime
 import os
 import os.path as osp
-from collections import OrderedDict
 from pathlib import Path
 from typing import Optional, Sequence, Union

-import torch
-
+from mmengine.data import BaseDataElement
 from mmengine.fileio import FileClient
 from mmengine.hooks import Hook
 from mmengine.registry import HOOKS
@@ -19,33 +15,20 @@ DATA_BATCH = Optional[Sequence[dict]]

 @HOOKS.register_module()
 class LoggerHook(Hook):
-    """In this logger hook, the information will be printed on the terminal and
-    saved in JSON file, tensorboard, wandb .etc.
+    """Collect logs from different components of ``Runner`` and write them to
+    terminal, JSON file, tensorboard, wandb, etc.
+
+    ``LoggerHook`` is used to record logs formatted by ``LogProcessor`` during
+    the training/validation/testing phase. It is used to control the following
+    behaviors:
+
+    - The frequency of log updates in terminal, local files, tensorboard,
+      wandb, etc.
+    - The frequency of showing experiment information in terminal.
+    - The work directory to save logs.

     Args:
-        by_epoch (bool): Whether ``EpochBasedLoop`` is used.
-            Defaults to True.
         interval (int): Logging interval (every k iterations).
             Defaults to 10.
-        custom_keys (dict, optional): Defines the keys in the log and which
-            kinds of statistic methods should be used to log them.
-
-            - ``custom_keys`` contains multiple string-dict pairs. In each
-              string-dict pair, the string defines a key name in the log and
-              the dict is a config defines the statistic methods and
-              corresponding arguments used to log the value. For example,
-              ``dict(loss=dict(method_name='mean', log_name='global_loss',
-              window_size='global'))`` which means the log key ``loss`` will
-              be counted as global mean and additionally logged as
-              ``global_loss``. If ``log_name`` is not defined in config dict,
-              the original logged key will be overwritten.
-            - The key in ``LoggerHook.fixed_smooth_keys`` cannot be
-              overwritten because ``time`` and ``iter_time`` will be used to
-              calculate estimated time of arrival. If you want to recount the
-              time, you should set ``log_name`` in corresponding values.
-            - For those statistic methods with the ``window_size`` argument,
-              if ``by_epoch`` is set to False, ``windows_size`` should not be
-              `epoch` to statistics log value by epoch.
         ignore_last (bool): Ignore the log of last iterations in each epoch if
             the number of remaining iterations is less than :attr:`interval`.
             Defaults to True.
@@ -70,64 +53,24 @@ class LoggerHook(Hook):
            Defaults to None.

     Examples:
-        >>> # `log_name` is defined, `loss_mean_window` will be an additional
-        >>> # record.
-        >>> logger_hook_cfg = dict(by_epoch=True,
-        >>>                        custom_keys=dict(
-        >>>                            loss=dict(
-        >>>                                log_name='loss_mean_window',
-        >>>                                method_name='mean',
-        >>>                                window_size=10)))
-        >>> # `log_name` is not defined. `loss` will be overwritten by
-        >>> # `global_mean` statistics.
-        >>> logger_hook_cfg = dict(by_epoch=True,
-        >>>                        custom_keys=dict(
-        >>>                            loss=dict(
-        >>>                                method_name='mean',
-        >>>                                window_size='global')))
-        >>> # `time` cannot be overwritten, `global_time` will be an additional
-        >>> # record.
-        >>> logger_hook_cfg = dict(by_epoch=True,
-        >>>                        custom_keys=dict(
-        >>>                            time=dict(
-        >>>                                log_name='global_time',
-        >>>                                method='mean',
-        >>>                                window_size='global')))
-        >>> # Record loss with different statistics methods.
-        >>> logger_hook_cfg = dict(by_epoch=True,
-        >>>                        custom_keys=dict(loss=[
-        >>>                            dict(log_name='loss_mean_window',
-        >>>                                 method_name='mean',
-        >>>                                 window_size=10),
-        >>>                            dict(method_name='mean',
-        >>>                                 window_size='global')]))
+        >>> # The simplest LoggerHook config.
+        >>> logger_hook_cfg = dict(interval=20)
     """
-    # eta will be calculated by time. `time` and `data_time` should not be
-    # overwritten.
-    fixed_smooth_keys = ('time', 'data_time')
     priority = 'BELOW_NORMAL'

     def __init__(
         self,
-        by_epoch: bool = True,
         interval: int = 10,
-        custom_keys: Optional[dict] = None,
         ignore_last: bool = True,
         interval_exp_name: int = 1000,
         out_dir: Optional[Union[str, Path]] = None,
         out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
-        keep_local=True,
-        file_client_args=None,
+        keep_local: bool = True,
+        file_client_args: Optional[dict] = None,
    ):
-        self._inner_iter = 0
-        self.by_epoch = by_epoch
         self.interval = interval
-        self.custom_keys = custom_keys if custom_keys is not None else dict()
         self.ignore_last = ignore_last
-
-        self.time_sec_tot = 0
         self.interval_exp_name = interval_exp_name
-        self._check_custom_keys()

         if out_dir is None and file_client_args is not None:
             raise ValueError(
@@ -165,14 +108,15 @@ class LoggerHook(Hook):
         self.json_log_path = osp.join(runner.work_dir,
                                       f'{runner.timestamp}.log.json')
-        self.start_iter = runner.iter
+        self.yaml_log_path = osp.join(runner.work_dir,
+                                      f'{runner.timestamp}.log.yaml')

     def after_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None,
                          outputs: Optional[dict] = None) -> None:
-        """Record training logs.
+        """Record training logs after a training iteration.

         Args:
             runner (Runner): The runner of the training process.
@@ -182,33 +126,90 @@
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
-        self._inner_iter = batch_idx
-        if runner.meta is not None and 'exp_name' in runner.meta:
-            if (self.every_n_iters(runner, self.interval_exp_name)) or (
-                    self.by_epoch and self.end_of_epoch(
-                        runner.train_loop.dataloader, batch_idx)):
-                exp_info = f'Exp name: {runner.meta["exp_name"]}'
-                runner.logger.info(exp_info)
-        if self.by_epoch and self.every_n_inner_iters(batch_idx,
-                                                      self.interval):
-            self._log_train(runner)
-        elif not self.by_epoch and self.every_n_iters(runner, self.interval):
-            self._log_train(runner)
-        elif self.end_of_epoch(runner.train_loop.dataloader,
-                               batch_idx) and not self.ignore_last:
+        # Print experiment name every n iterations.
+        if self.every_n_iters(runner,
+                              self.interval_exp_name) or (self.end_of_epoch(
+                                  runner.train_dataloader, batch_idx)):
+            exp_info = f'Exp name: {runner.experiment_name}'
+            runner.logger.info(exp_info)
+        if self.every_n_inner_iters(batch_idx, self.interval):
+            tag, log_str = runner.log_processor.get_log_after_iter(
+                runner, batch_idx, 'train')
+        elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
+              and not self.ignore_last):
             # `runner.max_iters` may not be divisible by `self.interval`. if
             # `self.ignore_last==True`, the log of remaining iterations will
             # be recorded (Epoch [4][1000/1007], the logs of 998-1007
             # iterations will be recorded).
-            self._log_train(runner)
+            tag, log_str = runner.log_processor.get_log_after_iter(
+                runner, batch_idx, 'train')
+        else:
+            return
+        runner.logger.info(log_str)
+        # TODO compatible with visualizer.
+        runner.visualizer.add_scalars(tag, step=runner.iter + 1)
+
+    def after_val_iter(
+            self,
+            runner,
+            batch_idx: int,
+            data_batch: DATA_BATCH = None,
+            outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
+        """Record validation logs after a validation iteration.
+
+        Args:
+            runner (Runner): The runner of the validation process.
+            batch_idx (int): The index of the current batch in the validation
+                loop.
+            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
+                Data from dataloader. Defaults to None.
+            outputs (sequence, optional): Outputs from model. Defaults to None.
+        """
+        if self.every_n_inner_iters(batch_idx, self.interval):
+            tag, log_str = runner.log_processor.get_log_after_iter(
+                runner, batch_idx, 'val')
+            runner.logger.info(log_str)
+
+    def after_test_iter(
+            self,
+            runner,
+            batch_idx: int,
+            data_batch: DATA_BATCH = None,
+            outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
+        """Record testing logs after a testing iteration.
+
+        Args:
+            runner (Runner): The runner of the testing process.
+            batch_idx (int): The index of the current batch in the test loop.
+            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
+                Data from dataloader. Defaults to None.
+            outputs (sequence, optional): Outputs from model. Defaults to None.
+        """
+        if self.every_n_inner_iters(batch_idx, self.interval):
+            tag, log_str = runner.log_processor.get_log_after_iter(
+                runner, batch_idx, 'test')
+            runner.logger.info(log_str)

     def after_val_epoch(self, runner) -> None:
-        """Record validation logs.
+        """Record validation logs after a validation epoch.

         Args:
             runner (Runner): The runner of the training process.
         """
-        self._log_val(runner)
+        tag, log_str = runner.log_processor.get_log_after_epoch(
+            runner, len(runner.val_dataloader), 'val')
+        runner.logger.info(log_str)
+        # TODO compatible with visualizer.
+        runner.visualizer.add_scalars(tag, step=runner.iter + 1)
+
+    def after_test_epoch(self, runner) -> None:
+        """Record testing logs after a test epoch.
+
+        Args:
+            runner (Runner): The runner of the testing process.
+ """ + tag, log_str = runner.log_processor.get_log_after_epoch( + runner, len(runner.val_dataloader), 'test') + runner.logger.info(log_str) def after_run(self, runner) -> None: """Copy logs to ``self.out_dir`` if ``self.out_dir is not None`` @@ -233,278 +234,3 @@ class LoggerHook(Hook): os.remove(local_filepath) runner.logger.info((f'{local_filepath} was removed due to the ' '`self.keep_local=False`')) - - def _log_train(self, runner) -> None: - """Collect and record training logs which start named with "train/*". - - Args: - runner (Runner): The runner of the training process. - """ - tag = self._collect_info(runner, 'train') - # The training log default defines `lr`, `momentum`, `time` and - # `data_time`. `log_tag` will pop these keys and loop other keys to - # `log_str`. - log_tag = copy.deepcopy(tag) - cur_iter = self._get_iter(runner, inner_iter=True) - cur_epoch = self._get_epoch(runner, 'train') - - # Record learning rate and momentum. - lr_str_list = [] - momentum_str_list = [] - for key, value in tag.items(): - if key.startswith('lr'): - log_tag.pop(key) - lr_str_list.append(f'{key}: {value:.3e}') - lr_str = ' '.join(lr_str_list) - for key, value in tag.items(): - if key.startswith('momentum'): - log_tag.pop(key) - momentum_str_list.append(f'{key}: {value:.3e}') - momentum_str = ' '.join(momentum_str_list) - lr_momentum_str = f'{lr_str} {momentum_str}' - # by epoch: Epoch [4][100/1000] - # by iter: Iter [100/100000] - if self.by_epoch: - log_str = f'Epoch [{cur_epoch}]' \ - f'[{cur_iter}/{len(runner.train_loop.dataloader)}] ' - else: - log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}] ' - log_str += f'{lr_momentum_str}, ' - # Calculate eta time. 
- self.time_sec_tot += (tag['time'] * self.interval) - time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1) - eta_sec = time_sec_avg * ( - runner.train_loop.max_iters - runner.iter - 1) - eta_str = str(datetime.timedelta(seconds=int(eta_sec))) - log_str += f'eta: {eta_str}, ' - log_str += f'time: {tag["time"]:.3f}, ' \ - f'data_time: {tag["data_time"]:.3f}, ' - # Pop recorded keys - log_tag.pop('time') - log_tag.pop('data_time') - # statistic memory - if torch.cuda.is_available(): - log_str += f'memory: {self._get_max_memory(runner)}, ' - # Loop left keys to fill `log_str`. - log_items = [] - for name, val in log_tag.items(): - if isinstance(val, float): - val = f'{val:.4f}' - log_items.append(f'{name}: {val}') - log_str += ', '.join(log_items) - runner.logger.info(log_str) - # Write logs to local, tensorboad, and wandb. - runner.visualizer.add_scalars( - tag, step=runner.iter + 1, file_path=self.json_log_path) - - def _log_val(self, runner) -> None: - """Collect and record training logs which start named with "val/*". - - Args: - runner (Runner): The runner of the training process. - """ - tag = self._collect_info(runner, 'val') - # Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501 - eval_iter = len(runner.val_loop.dataloader) - cur_iter = self._get_iter(runner) - cur_epoch = self._get_epoch(runner, 'val') - # val/test time - # here 1000 is the length of the val dataloader - # by epoch: Epoch[val] [4][1000] - # by iter: Iter[val] [1000] - if self.by_epoch: - # runner.epoch += 1 has been done before val workflow - log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}] ' - else: - log_str = f'Iter(val) [{eval_iter}] ' - - log_items = [] - for name, val in tag.items(): - if isinstance(val, float): - val = f'{val:.4f}' - log_items.append(f'{name}: {val}') - log_str += ', '.join(log_items) - runner.logger.info(log_str) - # Write tag. 
- runner.visualizer.add_scalars( - tag, step=cur_iter, file_path=self.json_log_path) - - def _get_window_size(self, runner, window_size: Union[int, str]) \ - -> int: - """Parse window_size specified in ``self.custom_keys`` to int value. - - Args: - runner (Runner): The runner of the training process. - window_size (int or str): Smoothing scale of logs. - - Returns: - int: Smoothing window for statistical methods. - """ - if isinstance(window_size, int): - assert window_size == self.interval, \ - 'The value of windows size must equal to LoggerHook.interval' - return window_size - elif window_size == 'epoch': - return self._inner_iter + 1 - elif window_size == 'global': - return runner.iter + 1 - else: - raise ValueError('window_size should be int, epoch or global, but ' - f'got invalid {window_size}') - - def _collect_info(self, runner, mode: str) -> dict: - """Collect log information to a dict according to mode. - - Args: - runner (Runner): The runner of the training process. - mode (str): 'train' or 'val', which means the prefix attached by - runner. - - Returns: - dict: Statistical values of logs. - """ - tag = OrderedDict() - log_buffers = runner.message_hub.log_scalars - mode_log_buffers = OrderedDict() - # Filter log_buffers which starts with `mode`. - for prefix_key, log_buffer in log_buffers.items(): - if prefix_key.startswith(mode): - key = prefix_key.split('/')[-1] - mode_log_buffers[key] = log_buffer - # Ensure all metric and lr values are latest. - for key in mode_log_buffers: - # Update the latest learning rate and smoothed time logs. - if key in self.fixed_smooth_keys or key.startswith('loss'): - tag[key] = mode_log_buffers[key].mean(self.interval) - else: - tag[key] = mode_log_buffers[key].current() - # Update custom keys. 
- if mode == 'train': - for log_key, log_cfg in self.custom_keys.items(): - self._parse_custom_keys(runner, log_key, - copy.deepcopy(log_cfg), - mode_log_buffers, tag) - return tag - - def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict, - log_buffers: OrderedDict, tag: OrderedDict) -> None: - """Statistics logs in log_buffers according to custom_keys. - - Args: - runner (Runner): The runner of the training process. - log_key (str): log key specified in ``self.custom_keys`` - log_cfg (dict): A config dict for describing the logging - statistics method. - log_buffers (OrderedDict): All logs for the corresponding phase. - tag (OrderedDict): A dict which defines all statistic values of - logs. - """ - if isinstance(log_cfg, list): - log_names = set() - for cfg in log_cfg: - log_name = cfg.get('log_name', None) - if log_name in log_names: - raise KeyError(f'{cfg["log_name"]} cannot be redefined in ' - 'log_key') - if log_name is not None: - log_names.add(log_name) - self._parse_custom_keys(runner, log_key, cfg, log_buffers, tag) - assert len(log_names) == len(log_cfg) - 1, \ - f'{log_key} cannot be overwritten multiple times, please ' \ - f'check only one key does not contain `log_name` in {log_cfg}.' - elif isinstance(log_cfg, dict): - if 'window_size' in log_cfg: - log_cfg['window_size'] = \ - self._get_window_size(runner, log_cfg['window_size']) - if 'log_name' in log_cfg: - name = log_cfg.pop('log_name') - else: - name = log_key - tag[name] = log_buffers[log_key].statistics(**log_cfg).item() - else: - raise ValueError('The structure of `LoggerHook.custom key` is ' - 'wrong, please make sure the type of each key is ' - 'dict or list.') - - def _get_max_memory(self, runner) -> int: - """Returns the maximum GPU memory occupied by tensors in megabytes (MB) - for a given device. - - Args: - runner (Runner): The runner of the training process. - - Returns: - The maximum GPU memory occupied by tensors in megabytes for a given - device. 
- """ - device = getattr(runner.model, 'output_device', None) - mem = torch.cuda.max_memory_allocated(device=device) - mem_mb = torch.tensor([int(mem) // (1024 * 1024)], - dtype=torch.int, - device=device) - torch.cuda.reset_peak_memory_stats() - return int(mem_mb.item()) - - def _check_custom_keys(self) -> None: - """Check the legality of ``self.custom_keys``. - - If ``self.by_epoch==False``, ``window_size`` should not be "epoch". The - key of ``self.fixed_smooth_keys`` cannot be overwritten. - """ - - def _check_window_size(item): - if not self.by_epoch: - assert item['window_size'] != 'epoch', \ - 'window_size cannot be epoch if LoggerHook.by_epoch is ' \ - 'False.' - - def _check_fixed_keys(key, item): - if key in self.fixed_smooth_keys: - assert 'log_name' in item, f'{key} cannot be overwritten by ' \ - 'custom keys!' - - for key, value in self.custom_keys.items(): - if isinstance(value, Sequence): - [(_check_window_size(item), _check_fixed_keys(key, item)) - for item in value] - - else: - _check_window_size(value) - _check_fixed_keys(key, value) - - def _get_epoch(self, runner, mode: str) -> int: - """Get epoch according to mode. - - Args: - runner (Runner): The runner of the training process. - mode (str): Train or val. - - Returns: - int: The current epoch. - """ - if mode == 'train': - epoch = runner.epoch + 1 - elif mode == 'val': - # normal val mode - # runner.epoch += 1 has been done before val workflow - epoch = runner.epoch - else: - raise ValueError(f"runner mode should be 'train' or 'val', " - f'but got {runner.mode}') - return epoch - - def _get_iter(self, runner, inner_iter=False) -> int: - """Get the current training iteration step. - Args: - runner (Runner): The runner of the training process. - inner_iter (bool): Whether to return the inner iter of an epoch. - Defaults to False. - - Returns: - int: The current global iter or inner iter. 
- """ - if self.by_epoch and inner_iter: - current_iter = self._inner_iter + 1 - else: - current_iter = runner.iter + 1 - return current_iter diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index 9107dbf02500e24d471271fd99a7fc1b29ad12fe..510e31c622ab1bd530bc40581042428afe17a781 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -84,6 +84,9 @@ class OptimizerHook(Hook): we keep ``outputs`` here. Defaults to None. """ runner.optimizer.zero_grad() + runner.message_hub.update_scalar( + 'train/lr', runner.optimizer.param_groups[0]['lr']) + if self.detect_anomalous_params: self.detect_anomalous_parameters(runner.outputs['loss'], runner) runner.outputs['loss'].backward() diff --git a/mmengine/logging/__init__.py b/mmengine/logging/__init__.py index ba5533c2363f49c749d0dba87f49496e3861ed80..eeac7ff1bdb55004c572d39273db3b2a9d2e645d 100644 --- a/mmengine/logging/__init__.py +++ b/mmengine/logging/__init__.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .history_buffer import HistoryBuffer +from .log_processor import LogProcessor from .logger import MMLogger, print_log from .message_hub import MessageHub -__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log'] +__all__ = [ + 'HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log', 'LogProcessor' +] diff --git a/mmengine/logging/log_processor.py b/mmengine/logging/log_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..3619bf4b7983d0cc1d15dec25f30362c1a838a30 --- /dev/null +++ b/mmengine/logging/log_processor.py @@ -0,0 +1,409 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import datetime +from collections import OrderedDict +from typing import List, Optional, Tuple + +import torch + + +class LogProcessor: + """A log processor used to format log information collected from + ``runner.message_hub.log_scalars``. 
+
+ ``LogProcessor`` instance is built by runner and will format
+ ``runner.message_hub.log_scalars`` to ``tag`` and ``log_str``, which can
+ be directly used by ``LoggerHook`` and ``MMLogger``. Besides, the
+ argument ``custom_cfg`` of constructor can control the statistics method
+ of logs.
+
+ Args:
+ window_size (int): Default smoothing interval of logs. Defaults to 10.
+ by_epoch (bool): Whether to format logs with epoch style. Defaults to
+ True.
+ custom_cfg (list[dict], optional): Contains multiple log config dicts,
+ in which key means the data source name of log and value means the
+ statistic method and corresponding arguments used to count the
+ data source. Defaults to None.
+ - If custom_cfg is None, all logs will be formatted via default
+ methods, such as smoothing loss by default window_size. If
+ custom_cfg is defined as a list of config dicts, for example:
+ [dict(data_src='loss', method_name='mean',
+ log_name='global_loss', window_size='global')]. It means the log
+ item ``loss`` will be
+ counted as global mean and additionally logged as ``global_loss``
+ (defined by ``log_name``). If ``log_name`` is not defined in
+ config dict, the original logged key will be overwritten.
+
+ - The original log item cannot be overwritten twice. Here is
+ an error example:
+ [dict(data_src='loss', method_name='mean', window_size='global'),
+ dict(data_src='loss', method_name='mean', window_size='epoch')].
+ Both log config dicts in custom_cfg lack the ``log_name`` key,
+ which means the loss item would be overwritten twice.
+
+ - For statistic methods with the ``window_size`` argument, if
+ ``by_epoch`` is set to False, ``window_size`` should not be
+ ``epoch``, since logs cannot be counted by epoch in an
+ iteration-based training loop.
+
+ Examples:
+ >>> # `log_name` is defined, `loss_large_window` will be an additional
+ >>> # record.
+ >>> log_processor = dict(
+ >>> window_size=10,
+ >>> by_epoch=True,
+ >>> custom_cfg=[dict(data_src='loss',
+ >>> log_name='loss_large_window',
+ >>> method_name='mean',
+ >>> window_size=100)])
+ >>> # `log_name` is not defined. `loss` will be overwritten.
+ >>> log_processor = dict(
+ >>> window_size=10,
+ >>> by_epoch=True,
+ >>> custom_cfg=[dict(data_src='loss',
+ >>> method_name='mean',
+ >>> window_size=100)])
+ >>> # Record loss with different statistics methods.
+ >>> log_processor = dict(
+ >>> window_size=10,
+ >>> by_epoch=True,
+ >>> custom_cfg=[dict(data_src='loss',
+ >>> log_name='loss_large_window',
+ >>> method_name='mean',
+ >>> window_size=100),
+ >>> dict(data_src='loss',
+ >>> method_name='mean',
+ >>> window_size=100)])
+ >>> # Overwriting the loss item twice will raise an error.
+ >>> log_processor = dict(
+ >>> window_size=10,
+ >>> by_epoch=True,
+ >>> custom_cfg=[dict(data_src='loss',
+ >>> method_name='mean',
+ >>> window_size=100),
+ >>> dict(data_src='loss',
+ >>> method_name='max',
+ >>> window_size=100)])
+ AssertionError
+ """
+
+ def __init__(self,
+ window_size=10,
+ by_epoch=True,
+ custom_cfg: Optional[List[dict]] = None):
+ self.window_size = window_size
+ self.by_epoch = by_epoch
+ self.custom_cfg = custom_cfg if custom_cfg else []
+ self._check_custom_cfg()
+
+ def get_log_after_iter(self, runner, batch_idx: int,
+ mode: str) -> Tuple[dict, str]:
+ """Format log string after training, validation or testing iteration.
+
+ Args:
+ runner (Runner): The runner of training phase.
+ batch_idx (int): The index of the current batch in the current
+ loop.
+ mode (str): Current mode of runner, train, test or val.
+
+ Returns:
+ Tuple(dict, str): Formatted log dict/string which will be
+ recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
+ """ + assert mode in ['train', 'test', 'val'] + current_loop = self._get_cur_loop(runner, mode) + cur_iter = self._get_iter(runner, batch_idx=batch_idx) + # Overwrite ``window_size`` defined in ``custom_cfg`` to int value. + custom_cfg_copy = self._parse_windows_size(runner, batch_idx) + # tag is used to write log information to different backends. + tag = self._collect_scalars(custom_cfg_copy, runner, mode) + # `log_tag` will pop 'lr' and loop other keys to `log_str`. + log_tag = copy.deepcopy(tag) + # Record learning rate. + lr_str_list = [] + for key, value in tag.items(): + if key.startswith('lr'): + log_tag.pop(key) + lr_str_list.append(f'{key}: {value:.3e}') + lr_str = ' '.join(lr_str_list) + # Format log header. + # by_epoch == True + # train/val: Epoch [5][5/10] ... + # test: Epoch [5/10] + # by_epoch == False + # train: Epoch [5/10000] ... (divided by `max_iter`) + # val/test: Epoch [5/2000] ... (divided by length of dataloader) + if self.by_epoch: + if mode in ['train', 'val']: + cur_epoch = self._get_epoch(runner, mode) + log_str = (f'Epoch({mode}) [{cur_epoch}]' + f'[{cur_iter}/{len(current_loop.dataloader)}] ') + else: + log_str = (f'Epoch({mode}) ' + f'[{cur_iter}/{len(current_loop.dataloader)}] ') + else: + if mode == 'train': + log_str = (f'Iter({mode}) ' + f'[{cur_iter}/{runner.train_loop.max_iters}] ') + else: + log_str = (f'Iter({mode}) [{batch_idx+1}' + f'/{len(current_loop.dataloader)}] ') + # Concatenate lr, momentum string with log header. + log_str += f'{lr_str} ' + # If IterTimerHook used in runner, eta, time, and data_time should be + # recorded. 
+ if (all(item in tag for item in ['time', 'data_time'])
+ and 'eta' in runner.message_hub.runtime_info):
+ eta = runner.message_hub.get_info('eta')
+ eta_str = str(datetime.timedelta(seconds=int(eta)))
+ log_str += f'eta: {eta_str} '
+ log_str += (f'time: {tag["time"]:.3f} '
+ f'data_time: {tag["data_time"]:.3f} ')
+ # Pop recorded keys
+ log_tag.pop('time')
+ log_tag.pop('data_time')
+
+ # If cuda is available, the max memory occupied should be calculated.
+ if torch.cuda.is_available():
+ log_str += f'memory: {self._get_max_memory(runner)} '
+ # Loop left keys to fill `log_str`.
+ if mode in ('train', 'val'):
+ log_items = []
+ for name, val in log_tag.items():
+ # Keys in `log_tag` have had their mode prefix stripped,
+ # so match the bare `loss` prefix here.
+ if mode == 'val' and not name.startswith('loss'):
+ continue
+ if isinstance(val, float):
+ val = f'{val:.4f}'
+ log_items.append(f'{name}: {val}')
+ log_str += ' '.join(log_items)
+ return tag, log_str
+
+ def get_log_after_epoch(self, runner, batch_idx: int,
+ mode: str) -> Tuple[dict, str]:
+ """Format log string after validation or testing epoch.
+
+ Args:
+ runner (Runner): The runner of training phase.
+ batch_idx (int): The index of the current batch in the current
+ loop.
+ mode (str): Current mode of runner.
+
+ Returns:
+ Tuple(dict, str): Formatted log dict/string which will be
+ recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
+ """
+ assert mode in [
+ 'test', 'val'
+ ], ('`get_log_after_epoch` only accepts val or test mode, but got '
+ f'{mode}')
+ cur_loop = self._get_cur_loop(runner, mode)
+ dataloader_len = len(cur_loop.dataloader)
+
+ custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
+ # tag is used to write log information to different backends.
+ tag = self._collect_scalars(custom_cfg_copy, runner, mode)
+ # validation log string needs cur epoch/iteration and max
+ # epochs/iterations. test log string only needs length of test
+ # dataloader.
+ cur_iter = self._get_iter(runner, batch_idx) + if self.by_epoch: + if mode == 'val': + cur_epoch = self._get_epoch(runner, mode) + log_str = (f'Epoch({mode}) [{cur_epoch}][{dataloader_len}/' + f'{dataloader_len}] ') + else: + log_str = ( + f'Epoch({mode}) [{dataloader_len}/{dataloader_len}] ') + + else: + if mode == 'train': + log_str = (f'Iter({mode}) [{cur_iter}/' + f'{runner.train_loop.max_iters}] ') + else: + log_str = ( + f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ') + log_items = [] + for name, val in tag.items(): + if name in ('time', 'data_time'): + continue + if isinstance(val, float): + val = f'{val:.4f}' + log_items.append(f'{name}: {val}') + log_str += ' '.join(log_items) + return tag, log_str + + def _collect_scalars(self, custom_cfg: List[dict], runner, + mode: str) -> dict: + """Collect log information to compose a dict according to mode. + + Args: + custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int + ``window_size``. + runner (Runner): The runner of the training process. + mode (str): 'train' or 'val', which means the prefix attached by + runner. + + Returns: + dict: Statistical values of logs. + """ + tag = OrderedDict() + # history_scalars of train/val/test phase. + history_scalars = runner.message_hub.log_scalars + # corresponding mode history_scalars + mode_history_scalars = OrderedDict() + # extract log scalars and remove prefix to `mode_history_scalars` + # according to mode. + for prefix_key, log_buffer in history_scalars.items(): + if prefix_key.startswith(mode): + key = prefix_key.split('/')[-1] + mode_history_scalars[key] = log_buffer + for key in mode_history_scalars: + # Update the latest learning rate and smoothed time logs. + if key.startswith('loss'): + tag[key] = mode_history_scalars[key].mean(self.window_size) + else: + # Default statistic method is current. + tag[key] = mode_history_scalars[key].current() + # Update custom keys. 
+ for log_cfg in custom_cfg:
+ data_src = log_cfg.pop('data_src')
+ if 'log_name' in log_cfg:
+ log_name = log_cfg.pop('log_name')
+ else:
+ log_name = data_src
+ # log item in custom_cfg could only exist in train or val
+ # mode.
+ if data_src in mode_history_scalars:
+ tag[log_name] = mode_history_scalars[data_src].statistics(
+ **log_cfg)
+ return tag
+
+ def _check_custom_cfg(self) -> None:
+ """Check the legality of ``self.custom_cfg``."""
+
+ def _check_window_size():
+ for log_cfg in self.custom_cfg:
+ if not self.by_epoch:
+ assert log_cfg.get('window_size') != 'epoch', \
+ 'window_size cannot be epoch if ' \
+ 'LogProcessor.by_epoch is False.'
+
+ def _check_repeated_log_name():
+ check_dict = dict()
+ # The `log_name` of the same data_src should not be repeated.
+ # If `log_name` is not specified, `data_src` will be overwritten.
+ # But only allowed to be overwritten once.
+ for log_cfg in self.custom_cfg:
+ assert 'data_src' in log_cfg
+ data_src = log_cfg['data_src']
+ log_name = log_cfg.get('log_name', data_src)
+ check_dict.setdefault(data_src,
+ dict(log_names=set(), log_counts=0))
+ check_dict[data_src]['log_names'].add(log_name)
+ check_dict[data_src]['log_counts'] += 1
+ assert (len(
+ check_dict[data_src]
+ ['log_names']) == check_dict[data_src]['log_counts']), (
+ f'If you want to compute statistics of {data_src} with '
+ 'multiple methods, please check that `log_name` is '
+ f'unique and that {data_src} will not be overwritten '
+ 'twice. See more information in the docstring of '
+ '`LogProcessor`.')
+
+ _check_repeated_log_name()
+ _check_window_size()
+
+ def _parse_windows_size(self, runner, batch_idx: int) -> list:
+ """Parse window_size defined in custom_cfg to int value.
+
+ Args:
+ runner (Runner): The runner of the training process.
+ batch_idx (int): The iteration index of current dataloader.
+ """ + custom_cfg_copy = copy.deepcopy(self.custom_cfg) + for log_cfg in custom_cfg_copy: + window_size = log_cfg.get('window_size', None) + if window_size is None or isinstance(window_size, int): + continue + elif window_size == 'epoch': + log_cfg['window_size'] = batch_idx + 1 + elif window_size == 'global': + log_cfg['window_size'] = runner.iter + 1 + else: + raise TypeError( + 'window_size should be int, epoch or global, but got ' + f'invalid {window_size}') + return custom_cfg_copy + + def _get_max_memory(self, runner) -> int: + """Returns the maximum GPU memory occupied by tensors in megabytes (MB) + for a given device. + + Args: + runner (Runner): The runner of the training process. + + Returns: + The maximum GPU memory occupied by tensors in megabytes for a given + device. + """ + device = getattr(runner.model, 'output_device', None) + mem = torch.cuda.max_memory_allocated(device=device) + mem_mb = torch.tensor([int(mem) // (1024 * 1024)], + dtype=torch.int, + device=device) + torch.cuda.reset_peak_memory_stats() + return int(mem_mb.item()) + + def _get_iter(self, runner, batch_idx: int = None) -> int: + """Get current training iteration step. + + Args: + runner (Runner): The runner of the training process. + batch_idx (int, optional): The interaction index of current + dataloader. Defaults to None. + + Returns: + int: The current global iter or inner iter. + """ + if self.by_epoch and batch_idx: + current_iter = batch_idx + 1 + else: + current_iter = runner.iter + 1 + return current_iter + + def _get_epoch(self, runner, mode: str) -> int: + """Get current epoch according to mode. + + Args: + runner (Runner): The runner of the training/validation process. + mode (str): Current mode of runner, "train" or "val". + + Returns: + int: The current epoch. 
+ """ + if mode == 'train': + epoch = runner.epoch + 1 + elif mode == 'val': + # normal val mode + # runner.epoch += 1 has been done before validation + epoch = runner.epoch + else: + raise ValueError( + f"runner mode should be 'train' or 'val', but got {mode}") + return epoch + + def _get_cur_loop(self, runner, mode: str): + """Get current loop according to mode. + + Args: + runner (Runner): The runner of the training/validation/testing + process. + mode (str): Current mode of runner, "train", "val" or test. + + Returns: + BaseLoop: Current loop of runner. + """ + # returns type hint will occur circular import + if mode == 'train': + return runner.train_loop + elif mode == 'val': + return runner.val_loop + else: + return runner.test_loop diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py index 3ae26524f3f5a906b5e5bdf1eb39799cffacd52a..6066449f3528142a28dc6df72158ed16a59d900d 100644 --- a/mmengine/logging/logger.py +++ b/mmengine/logging/logger.py @@ -32,15 +32,15 @@ class MMFormatter(logging.Formatter): info_prefix = self._get_prefix('INFO', color) debug_prefix = self._get_prefix('DEBUG', color) # Config output format. 
- self.err_format = f'%(asctime)s - %(name)s - {error_prefix} - ' \ - f'%(pathname)s - %(funcName)s - %(lineno)d - ' \ - '%(message)s' - self.warn_format = f'%(asctime)s - %(name)s - {warn_prefix} - %(' \ - 'message)s' - self.info_format = f'%(asctime)s - %(name)s - {info_prefix} - %(' \ - 'message)s' - self.debug_format = f'%(asctime)s - %(name)s - {debug_prefix} - %(' \ - 'message)s' + self.err_format = (f'%(asctime)s - %(name)s - {error_prefix} - ' + '%(pathname)s - %(funcName)s - %(lineno)d - ' + '%(message)s') + self.warn_format = (f'%(asctime)s - %(name)s - {warn_prefix} - %(' + 'message)s') + self.info_format = (f'%(asctime)s - %(name)s - {info_prefix} - %(' + 'message)s') + self.debug_format = (f'%(asctime)s - %(name)s - {debug_prefix} - %(' + 'message)s') def _get_prefix(self, level: str, color: bool) -> str: """Get the prefix of the target log level. diff --git a/mmengine/optim/scheduler/__init__.py b/mmengine/optim/scheduler/__init__.py index f7ea1d57400ef829d2a71502c9a709a711311f03..733ca752836c230d810f6b7ea94fab9839ed2a61 100644 --- a/mmengine/optim/scheduler/__init__.py +++ b/mmengine/optim/scheduler/__init__.py @@ -1,14 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR, - LinearLR, MultiStepLR, StepLR) + LinearLR, MultiStepLR, PolyLR, StepLR) from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum, ExponentialMomentum, LinearMomentum, - MultiStepMomentum, StepMomentum) + MultiStepMomentum, PolyMomentum, StepMomentum) from .param_scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, ExponentialParamScheduler, LinearParamScheduler, - MultiStepParamScheduler, StepParamScheduler, - _ParamScheduler) + MultiStepParamScheduler, PolyParamScheduler, + StepParamScheduler, _ParamScheduler) __all__ = [ 'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', @@ -16,5 +16,6 @@ __all__ = [ 'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler', 'ExponentialParamScheduler', 'LinearParamScheduler', - 'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler' + 'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler', + 'PolyParamScheduler', 'PolyLR', 'PolyMomentum' ] diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py index 514b8b035c80e8891faffa7a0ef577edfe7edc38..3c774a67b33e6cc6de7340f16edf8f0a7ab7beb0 100644 --- a/mmengine/optim/scheduler/lr_scheduler.py +++ b/mmengine/optim/scheduler/lr_scheduler.py @@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS from .param_scheduler import (INF, ConstantParamScheduler, CosineAnnealingParamScheduler, ExponentialParamScheduler, LinearParamScheduler, - MultiStepParamScheduler, StepParamScheduler) + MultiStepParamScheduler, PolyParamScheduler, + StepParamScheduler) @PARAM_SCHEDULERS.register_module() @@ -294,3 +295,49 @@ class StepLR(StepParamScheduler): last_step=last_step, by_epoch=by_epoch, verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class PolyLR(PolyParamScheduler): + """Decays the learning rate of each parameter group 
in a polynomial decay + scheme. + + Notice that such decay can happen simultaneously with other changes to the + parameter value from outside this scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + eta_min (float): Minimum learning rate at the end of scheduling. + Defaults to 0. + power (float): The power of the polynomial. Defaults to 1.0. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + eta_min: float = 0, + power: float = 1, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='lr', + eta_min=eta_min, + power=power, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py index cc882c3b423df50da7e352ad25e8f357da05138d..fa357eb1e3ab41e307a937315af1c9da53311392 100644 --- a/mmengine/optim/scheduler/momentum_scheduler.py +++ b/mmengine/optim/scheduler/momentum_scheduler.py @@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS from .param_scheduler import (INF, ConstantParamScheduler, CosineAnnealingParamScheduler, ExponentialParamScheduler, LinearParamScheduler, - MultiStepParamScheduler, StepParamScheduler) + MultiStepParamScheduler, PolyParamScheduler, + StepParamScheduler) @PARAM_SCHEDULERS.register_module() @@ -294,3 +295,49 @@ class StepMomentum(StepParamScheduler): last_step=last_step, by_epoch=by_epoch, verbose=verbose) + + 
+@PARAM_SCHEDULERS.register_module() +class PolyMomentum(PolyParamScheduler): + """Decays the momentum of each parameter group in a polynomial decay + scheme. + + Notice that such decay can happen simultaneously with other changes to the + parameter value from outside this scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + eta_min (float): Minimum momentum at the end of scheduling. + Defaults to 0. + power (float): The power of the polynomial. Defaults to 1.0. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + eta_min: float = 0, + power: float = 1, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='momentum', + eta_min=eta_min, + power=power, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py index bbec0556b0db7a8739656473ebca39dd66f34261..f40507e54b0debcc947751a479be5fae38cbaa84 100644 --- a/mmengine/optim/scheduler/param_scheduler.py +++ b/mmengine/optim/scheduler/param_scheduler.py @@ -534,6 +534,7 @@ class LinearParamScheduler(_ParamScheduler): Notice that such decay can happen simultaneously with other changes to the parameter value from outside this scheduler. + Args: optimizer (Optimizer): Wrapped optimizer. 
start_factor (float): The number we multiply parameter value in the @@ -598,3 +599,64 @@ class LinearParamScheduler(_ParamScheduler): (self.end_factor - self.start_factor))) for group in self.optimizer.param_groups ] + + +@PARAM_SCHEDULERS.register_module() +class PolyParamScheduler(_ParamScheduler): + """Decays the parameter value of each parameter group in a polynomial decay + scheme. + + Notice that such decay can happen simultaneously with other changes to the + parameter value from outside this scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + eta_min (float): Minimum parameter value at the end of scheduling. + Defaults to 0. + power (float): The power of the polynomial. Defaults to 1.0. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. 
+ """ + + def __init__(self, + optimizer: Optimizer, + param_name: str, + eta_min: float = 0, + power: float = 1.0, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + + self.eta_min = eta_min + self.power = power + self.total_iters = end - begin - 1 + + super().__init__( + optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + def _get_value(self): + + if self.last_step == 0: + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + + return [(group[self.param_name] - self.eta_min) * + (1 - 1 / (self.total_iters - self.last_step + 1))**self.power + + self.eta_min for group in self.optimizer.param_groups] diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 02cae1eccdf43e82ebea80d32e34c8d51a116986..5b1dd4c202a8b2207d44adbbdd5d470cb5dda3c4 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -25,7 +25,7 @@ from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only, sync_random_seed) from mmengine.evaluator import Evaluator from mmengine.hooks import Hook -from mmengine.logging import MessageHub, MMLogger +from mmengine.logging import LogProcessor, MessageHub, MMLogger from mmengine.model import is_model_wrapper from mmengine.optim import _ParamScheduler, build_optimizer from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, @@ -127,6 +127,8 @@ class Runner: non-distributed environment will be launched. env_cfg (dict): A dict used for setting environment. Defaults to dict(dist_cfg=dict(backend='nccl')). + log_processor (dict, optional): A processor to format logs. Defaults to + None. log_level (int or str): The log level of MMLogger handlers. Defaults to 'INFO'. 
visualizer (Visualizer or dict, optional): A Visualizer object or a @@ -151,43 +153,44 @@ class Runner: Examples: >>> from mmengine import Runner >>> cfg = dict( - model=dict(type='ToyModel'), - work_dir='path/of/work_dir', - train_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=1, - num_workers=0), - val_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=1, - num_workers=0), - test_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=1, - num_workers=0), - optimizer=dict(type='SGD', lr=0.01), - param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), - val_evaluator=dict(type='ToyEvaluator'), - test_evaluator=dict(type='ToyEvaluator'), - train_cfg=dict(by_epoch=True, max_epochs=3), - val_cfg=dict(interval=1), - test_cfg=dict(), - custom_hooks=[], - default_hooks=dict( - timer=dict(type='IterTimerHook'), - checkpoint=dict(type='CheckpointHook', interval=1), - logger=dict(type='LoggerHook'), - optimizer=dict(type='OptimizerHook', grad_clip=False), - param_scheduler=dict(type='ParamSchedulerHook')), - launcher='none', - env_cfg=dict(dist_cfg=dict(backend='nccl')), - visualizer=dict(type='Visualizer', - vis_backends=[dict(type='LocalVisBackend', - save_dir='temp_dir')]) - ) + >>> model=dict(type='ToyModel'), + >>> work_dir='path/of/work_dir', + >>> train_dataloader=dict( + >>> dataset=dict(type='ToyDataset'), + >>> sampler=dict(type='DefaultSampler', shuffle=True), + >>> batch_size=1, + >>> num_workers=0), + >>> val_dataloader=dict( + >>> dataset=dict(type='ToyDataset'), + >>> sampler=dict(type='DefaultSampler', shuffle=False), + >>> batch_size=1, + >>> num_workers=0), + >>> test_dataloader=dict( + >>> dataset=dict(type='ToyDataset'), + >>> sampler=dict(type='DefaultSampler', shuffle=False), + >>> batch_size=1, + >>> num_workers=0), + >>> 
optimizer=dict(type='SGD', lr=0.01), + >>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), + >>> val_evaluator=dict(type='ToyEvaluator'), + >>> test_evaluator=dict(type='ToyEvaluator'), + >>> train_cfg=dict(by_epoch=True, max_epochs=3), + >>> val_cfg=dict(interval=1), + >>> test_cfg=dict(), + >>> custom_hooks=[], + >>> default_hooks=dict( + >>> timer=dict(type='IterTimerHook'), + >>> checkpoint=dict(type='CheckpointHook', interval=1), + >>> logger=dict(type='LoggerHook'), + >>> optimizer=dict(type='OptimizerHook', grad_clip=False), + >>> param_scheduler=dict(type='ParamSchedulerHook')), + >>> launcher='none', + >>> env_cfg=dict(dist_cfg=dict(backend='nccl')), + >>> log_processor=dict(window_size=20), + >>> visualizer=dict(type='Visualizer', + >>> vis_backends=[dict(type='LocalVisBackend', + >>> save_dir='temp_dir')]) + >>> ) >>> runner = Runner.from_cfg(cfg) >>> runner.train() >>> runner.test() @@ -217,6 +220,7 @@ class Runner: resume: bool = False, launcher: str = 'none', env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')), + log_processor: Optional[Dict] = None, log_level: str = 'INFO', visualizer: Optional[Union[Visualizer, Dict]] = None, default_scope: Optional[str] = None, @@ -309,14 +313,16 @@ class Runner: self._experiment_name = f'{filename_no_ext}_{self._timestamp}' else: self._experiment_name = self.timestamp - # Used to reset registries location. See :meth:`Registry.build` for # more details. self.default_scope = DefaultScope.get_instance( self._experiment_name, scope_name=default_scope) - + # Build log processor to format message. + log_processor = dict() if log_processor is None else log_processor + self.log_processor = LogProcessor(**log_processor) + # Since `get_instance` could return any subclass of ManagerMixin. The + # corresponding attribute needs a type hint. self.logger = self.build_logger(log_level=log_level) - # Build `message_hub` for communication among components. 
# `message_hub` can store log scalars (loss, learning rate) and # runtime information (iter and epoch). Those components that do not @@ -387,6 +393,7 @@ class Runner: resume=cfg.get('resume', False), launcher=cfg.get('launcher', 'none'), env_cfg=cfg.get('env_cfg'), # type: ignore + log_processor=cfg.get('log_processor'), log_level=cfg.get('log_level', 'INFO'), visualizer=cfg.get('visualizer'), default_scope=cfg.get('default_scope'), diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py index af149f2f1fd1937ce66400278b4dfa9fe1b4e155..8d3dfb9d5ae9a249648b9e3f4862cebacce1aa23 100644 --- a/tests/test_hook/test_iter_timer_hook.py +++ b/tests/test_hook/test_iter_timer_hook.py @@ -1,29 +1,70 @@ # Copyright (c) OpenMMLab. All rights reserved. -from unittest.mock import Mock +from unittest import TestCase +from unittest.mock import MagicMock, Mock, patch from mmengine.hooks import IterTimerHook +from mmengine.logging import MessageHub -class TestIterTimerHook: +def time_patch(): + if not hasattr(time_patch, 'time'): + time_patch.time = 0 + else: + time_patch.time += 1 + return time_patch.time + + +class TestIterTimerHook(TestCase): + + def setUp(self) -> None: + self.hook = IterTimerHook() + + def test_init(self): + assert self.hook.time_sec_tot == 0 + assert self.hook.start_iter == 0 + + def test_before_run(self): + runner = MagicMock() + runner.iter = 1 + self.hook.before_run(runner) + assert self.hook.start_iter == 1 def test_before_epoch(self): - hook = IterTimerHook() runner = Mock() - hook._before_epoch(runner) - assert isinstance(hook.t, float) + self.hook._before_epoch(runner) + assert isinstance(self.hook.t, float) + @patch('time.time', MagicMock(return_value=1)) def test_before_iter(self): - hook = IterTimerHook() - runner = Mock() + runner = MagicMock() runner.log_buffer = dict() - hook._before_epoch(runner) - hook._before_iter(runner, 0) - runner.message_hub.update_scalar.assert_called() + self.hook._before_epoch(runner) 
+ for mode in ('train', 'val', 'test'): + self.hook._before_iter(runner, batch_idx=1, mode=mode) + runner.message_hub.update_scalar.assert_called_with( + f'{mode}/data_time', 0) + @patch('time.time', time_patch) def test_after_iter(self): - hook = IterTimerHook() - runner = Mock() + runner = MagicMock() runner.log_buffer = dict() - hook._before_epoch(runner) - hook._after_iter(runner, 0) + runner.log_processor.window_size = 10 + runner.train_loop.max_iters = 100 + runner.iter = 0 + runner.test_loop.dataloader = [0] * 20 + runner.val_loop.dataloader = [0] * 20 + self.hook._before_epoch(runner) + self.hook.before_run(runner) + self.hook._after_iter(runner, batch_idx=1) runner.message_hub.update_scalar.assert_called() + runner.message_hub.get_log.assert_not_called() + runner.message_hub.update_info.assert_not_called() + runner.message_hub = MessageHub.get_instance('test_iter_timer_hook') + runner.iter = 9 + # eta = (100 - 10) / 1 + self.hook._after_iter(runner, batch_idx=89) + assert runner.message_hub.get_info('eta') == 90 + self.hook._after_iter(runner, batch_idx=9, mode='val') + assert runner.message_hub.get_info('eta') == 10 + self.hook._after_iter(runner, batch_idx=19, mode='test') + assert runner.message_hub.get_info('eta') == 0 diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index b2b617991288135ee89cd8f6d72336d67fab9825..138ce2a1f031c549483bade8e4d2e4114ae469fa 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -1,13 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import datetime -import logging import os.path as osp -import sys -from collections import OrderedDict -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -import torch from mmengine.fileio.file_client import HardDiskBackend from mmengine.hooks import LoggerHook @@ -17,11 +12,8 @@ class TestLoggerHook: def test_init(self): logger_hook = LoggerHook(out_dir='tmp.txt') - assert logger_hook.by_epoch assert logger_hook.interval == 10 - assert not logger_hook.custom_keys assert logger_hook.ignore_last - assert logger_hook.time_sec_tot == 0 assert logger_hook.interval_exp_name == 1000 assert logger_hook.out_suffix == ('.log.json', '.log', '.py') assert logger_hook.keep_local @@ -30,22 +22,7 @@ class TestLoggerHook: # out_dir should be None or string or tuple of string. with pytest.raises(TypeError): LoggerHook(out_dir=1) - # time cannot be overwritten. - with pytest.raises(AssertionError): - LoggerHook(custom_keys=dict(time=dict(method='max'))) - LoggerHook( - custom_keys=dict(time=[ - dict(method='max', log_name='time_max'), - dict(method='min', log_name='time_min') - ])) - # Epoch window_size cannot be used when `LoggerHook.by_epoch=False` - with pytest.raises(AssertionError): - LoggerHook( - by_epoch=False, - custom_keys=dict( - time=dict( - method='max', log_name='time_max', - window_size='epoch'))) + with pytest.raises(ValueError): LoggerHook(file_client_args=dict(enable_mc=True)) @@ -60,19 +37,22 @@ class TestLoggerHook: assert logger_hook.out_dir == osp.join('out_dir', 'work_dir') assert logger_hook.json_log_path == osp.join('work_dir', 'timestamp.log.json') - assert logger_hook.start_iter == runner.iter def test_after_run(self, tmp_path): + # Test out_dir = tmp_path / 'out_dir' out_dir.mkdir() work_dir = tmp_path / 'work_dir' work_dir.mkdir() work_dir_json = work_dir / 'tmp.log.json' - json_f = open(work_dir_json, 'w') - json_f.close() runner = MagicMock() runner.work_dir = work_dir - + # Test without out_dir. 
+ logger_hook = LoggerHook() + logger_hook.after_run(runner) + # Test with out_dir and make sure json file has been moved to out_dir. + json_f = open(work_dir_json, 'w') + json_f.close() logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False) logger_hook.out_dir = str(out_dir) logger_hook.after_run(runner) @@ -83,276 +63,83 @@ class TestLoggerHook: def test_after_train_iter(self): # Test LoggerHook by iter. runner = MagicMock() - runner.iter = 10 - batch_idx = 5 - logger_hook = LoggerHook(by_epoch=False) - logger_hook._log_train = MagicMock() - logger_hook.after_train_iter(runner, batch_idx=batch_idx) + runner.log_processor.get_log_after_iter = MagicMock( + return_value=(dict(), 'log_str')) + logger_hook = LoggerHook() + logger_hook.after_train_iter(runner, batch_idx=5) # `cur_iter=10+1`, which cannot be exact division by # `logger_hook.interval` - logger_hook._log_train.assert_not_called() - runner.iter = 9 - logger_hook.after_train_iter(runner, batch_idx=batch_idx) - logger_hook._log_train.assert_called() + runner.log_processor.get_log_after_iter.assert_not_called() + logger_hook.after_train_iter(runner, batch_idx=9) + runner.log_processor.get_log_after_iter.assert_called() # Test LoggerHook by epoch. - logger_hook = LoggerHook(by_epoch=True) - logger_hook._log_train = MagicMock() - # Only `runner.inner_iter` will work. - runner.iter = 9 - batch_idx = 10 - logger_hook.after_train_iter(runner, batch_idx=batch_idx) - logger_hook._log_train.assert_not_called() - batch_idx = 9 - logger_hook.after_train_iter(runner, batch_idx=batch_idx) - logger_hook._log_train.assert_called() + logger_hook = LoggerHook() + runner = MagicMock() + runner.log_processor.get_log_after_iter = MagicMock( + return_value=(dict(), 'log_str')) + # Only `batch_idx` will work. 
+ logger_hook.after_train_iter(runner, batch_idx=10) + runner.log_processor.get_log_after_iter.assert_not_called() + logger_hook.after_train_iter(runner, batch_idx=9) + runner.log_processor.get_log_after_iter.assert_called() # Test end of the epoch. - logger_hook = LoggerHook(by_epoch=True, ignore_last=False) - logger_hook._log_train = MagicMock() - runner.train_loop.dataloader = [0] * 5 - batch_idx = 4 - logger_hook.after_train_iter(runner, batch_idx=batch_idx) - logger_hook._log_train.assert_called() + runner = MagicMock() + runner.log_processor.get_log_after_iter = MagicMock( + return_value=(dict(), 'log_str')) + logger_hook = LoggerHook(ignore_last=False) + runner.train_dataloader = [0] * 5 + logger_hook.after_train_iter(runner, batch_idx=4) + runner.log_processor.get_log_after_iter.assert_called() # Test print exp_name + runner = MagicMock() + runner.log_processor.get_log_after_iter = MagicMock( + return_value=(dict(), 'log_str')) runner.meta = dict(exp_name='retinanet') - logger_hook = LoggerHook() runner.logger = MagicMock() - logger_hook._log_train = MagicMock() - logger_hook.after_train_iter(runner, batch_idx=batch_idx) - runner.logger.info.assert_called_with( - f'Exp name: {runner.meta["exp_name"]}') + logger_hook = LoggerHook() + logger_hook.after_train_iter(runner, batch_idx=999) + runner.logger.info.assert_called() def test_after_val_epoch(self): logger_hook = LoggerHook() runner = MagicMock() - logger_hook._log_val = MagicMock() + runner.log_processor.get_log_after_epoch = MagicMock( + return_value=(dict(), 'string')) logger_hook.after_val_epoch(runner) - logger_hook._log_val.assert_called() - - @pytest.mark.parametrize('by_epoch', [True, False]) - def test_log_train(self, by_epoch, capsys): - runner = self._setup_runner() - runner.meta = dict(exp_name='retinanet') - # Prepare LoggerHook - logger_hook = LoggerHook(by_epoch=by_epoch) - logger_hook._inner_iter = 1 - logger_hook.writer = MagicMock() - logger_hook.time_sec_tot = 1000 - 
logger_hook.start_iter = 0 - logger_hook._get_max_memory = MagicMock(return_value='100') - logger_hook.json_log_path = 'tmp.json' - - # Prepare training information. - train_infos = dict( - lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0) - logger_hook._collect_info = MagicMock(return_value=train_infos) - logger_hook._log_train(runner) - # Verify that the correct variables have been written. - runner.visualizer.add_scalars.assert_called_with( - train_infos, step=11, file_path='tmp.json') - # Verify that the correct context have been logged. - out, _ = capsys.readouterr() - time_avg = logger_hook.time_sec_tot / ( - runner.iter + 1 - logger_hook.start_iter) - eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1) - eta_str = str(datetime.timedelta(seconds=int(eta_second))) - if by_epoch: - if torch.cuda.is_available(): - log_str = 'Epoch [2][2/5] ' \ - f"lr: {train_infos['lr']:.3e} " \ - f"momentum: {train_infos['momentum']:.3e}, " \ - f'eta: {eta_str}, ' \ - f"time: {train_infos['time']:.3f}, " \ - f"data_time: {train_infos['data_time']:.3f}, " \ - f'memory: 100, ' \ - f"loss_cls: {train_infos['loss_cls']:.4f}\n" - else: - log_str = 'Epoch [2][2/5] ' \ - f"lr: {train_infos['lr']:.3e} " \ - f"momentum: {train_infos['momentum']:.3e}, " \ - f'eta: {eta_str}, ' \ - f"time: {train_infos['time']:.3f}, " \ - f"data_time: {train_infos['data_time']:.3f}, " \ - f"loss_cls: {train_infos['loss_cls']:.4f}\n" - assert out == log_str - else: - if torch.cuda.is_available(): - log_str = 'Iter [11/50] ' \ - f"lr: {train_infos['lr']:.3e} " \ - f"momentum: {train_infos['momentum']:.3e}, " \ - f'eta: {eta_str}, ' \ - f"time: {train_infos['time']:.3f}, " \ - f"data_time: {train_infos['data_time']:.3f}, " \ - f'memory: 100, ' \ - f"loss_cls: {train_infos['loss_cls']:.4f}\n" - else: - log_str = 'Iter [11/50] ' \ - f"lr: {train_infos['lr']:.3e} " \ - f"momentum: {train_infos['momentum']:.3e}, " \ - f'eta: {eta_str}, ' \ - f"time: {train_infos['time']:.3f}, " \ - 
f"data_time: {train_infos['data_time']:.3f}, " \ - f"loss_cls: {train_infos['loss_cls']:.4f}\n" - assert out == log_str - - @pytest.mark.parametrize('by_epoch', [True, False]) - def test_log_val(self, by_epoch, capsys): - runner = self._setup_runner() - # Prepare LoggerHook. - logger_hook = LoggerHook(by_epoch=by_epoch) - logger_hook.json_log_path = 'tmp.json' - metric = dict(accuracy=0.9, data_time=1.0) - logger_hook._collect_info = MagicMock(return_value=metric) - logger_hook._log_val(runner) - # Verify that the correct context have been logged. - out, _ = capsys.readouterr() - runner.visualizer.add_scalars.assert_called_with( - metric, step=11, file_path='tmp.json') - if by_epoch: - assert out == 'Epoch(val) [1][5] accuracy: 0.9000, ' \ - 'data_time: 1.0000\n' + runner.log_processor.get_log_after_epoch.assert_called() + runner.logger.info.assert_called() + runner.visualizer.add_scalars.assert_called() - else: - assert out == 'Iter(val) [5] accuracy: 0.9000, ' \ - 'data_time: 1.0000\n' - - def test_get_window_size(self): - runner = self._setup_runner() - logger_hook = LoggerHook() - logger_hook._inner_iter = 1 - # Test get window size by name. - assert logger_hook._get_window_size(runner, 'epoch') == 2 - assert logger_hook._get_window_size(runner, 'global') == 11 - assert logger_hook._get_window_size(runner, 10) == 10 - # Window size must equal to `logger_hook.interval`. 
- with pytest.raises(AssertionError): - logger_hook._get_window_size(runner, 20) - - with pytest.raises(ValueError): - logger_hook._get_window_size(runner, 'unknwon') - - def test_parse_custom_keys(self): - tag = OrderedDict() - runner = self._setup_runner() - log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock()) - cfg_dict = dict( - lr=dict(method='min'), - loss=[ - dict(method='min', window_size='global'), - dict(method='max', log_name='loss_max') - ]) - logger_hook = LoggerHook() - for log_key, log_cfg in cfg_dict.items(): - logger_hook._parse_custom_keys(runner, log_key, log_cfg, - log_buffers, tag) - assert list(tag) == ['lr', 'loss', 'loss_max'] - assert log_buffers['lr'].min.assert_called - assert log_buffers['loss'].min.assert_called - assert log_buffers['loss'].max.assert_called - assert log_buffers['loss'].mean.assert_called - # `log_name` Cannot be repeated. - with pytest.raises(KeyError): - cfg_dict = dict(loss=[ - dict(method='min', window_size='global'), - dict(method='max', log_name='loss_max'), - dict(method='mean', log_name='loss_max') - ]) - logger_hook.custom_keys = cfg_dict - for log_key, log_cfg in cfg_dict.items(): - logger_hook._parse_custom_keys(runner, log_key, log_cfg, - log_buffers, tag) - # `log_key` cannot be overwritten multiple times. - with pytest.raises(AssertionError): - cfg_dict = dict(loss=[ - dict(method='min', window_size='global'), - dict(method='max'), - ]) - logger_hook.custom_keys = cfg_dict - for log_key, log_cfg in cfg_dict.items(): - logger_hook._parse_custom_keys(runner, log_key, log_cfg, - log_buffers, tag) - - def test_collect_info(self): - runner = self._setup_runner() - logger_hook = LoggerHook( - custom_keys=dict(time=dict(method='max', log_name='time_max'))) - logger_hook._parse_custom_keys = MagicMock() - # Collect with prefix. 
- log_buffers = { - 'train/time': MagicMock(), - 'lr': MagicMock(), - 'train/loss_cls': MagicMock(), - 'val/metric': MagicMock() - } - runner.message_hub.log_scalars = log_buffers - tag = logger_hook._collect_info(runner, mode='train') - # Test parse custom_keys - logger_hook._parse_custom_keys.assert_called() - # Test training key in tag. - assert list(tag.keys()) == ['time', 'loss_cls'] - # Test statistics lr with `current`, loss and time with 'mean' - log_buffers['train/time'].mean.assert_called() - log_buffers['train/loss_cls'].mean.assert_called() - log_buffers['train/loss_cls'].current.assert_not_called() - - tag = logger_hook._collect_info(runner, mode='val') - assert list(tag.keys()) == ['metric'] - log_buffers['val/metric'].current.assert_called() - - @patch('torch.cuda.max_memory_allocated', MagicMock()) - @patch('torch.cuda.reset_peak_memory_stats', MagicMock()) - def test_get_max_memory(self): + def test_after_test_epoch(self): logger_hook = LoggerHook() runner = MagicMock() - runner.world_size = 1 - runner.model = torch.nn.Linear(1, 1) - logger_hook._get_max_memory(runner) - torch.cuda.max_memory_allocated.assert_called() - torch.cuda.reset_peak_memory_stats.assert_called() + runner.log_processor.get_log_after_epoch = MagicMock( + return_value=(dict(), 'log_str')) + logger_hook.after_test_epoch(runner) + runner.log_processor.get_log_after_epoch.assert_called() + runner.logger.info.assert_called() - def test_get_iter(self): - runner = self._setup_runner() + def test_after_val_iter(self): logger_hook = LoggerHook() - logger_hook._inner_iter = 1 - # Get global iter when `inner_iter=False` - iter = logger_hook._get_iter(runner) - assert iter == 11 - # Get inner iter - iter = logger_hook._get_iter(runner, inner_iter=True) - assert iter == 2 - # Still get global iter when `logger_hook.by_epoch==False` - logger_hook.by_epoch = False - iter = logger_hook._get_iter(runner, inner_iter=True) - assert iter == 11 - - def test_get_epoch(self): - runner = 
self._setup_runner() + runner = MagicMock() + runner.iter = 0 + runner.log_processor.get_log_after_iter = MagicMock( + return_value=(dict(), 'log_str')) + logger_hook.after_val_iter(runner, 1) + runner.log_processor.get_log_after_iter.assert_not_called() + logger_hook.after_val_iter(runner, 9) + runner.log_processor.get_log_after_iter.assert_called() + + def test_after_test_iter(self): logger_hook = LoggerHook() - epoch = logger_hook._get_epoch(runner, 'train') - assert epoch == 2 - epoch = logger_hook._get_epoch(runner, 'val') - assert epoch == 1 - with pytest.raises(ValueError): - logger_hook._get_epoch(runner, 'test') - - def _setup_runner(self): runner = MagicMock() - runner.epoch = 1 - runner.train_loop.dataloader = [0] * 5 - runner.val_loop.dataloader = [0] * 5 - runner.test_loop.dataloader = [0] * 5 - runner.iter = 10 - runner.train_loop.max_iters = 50 - logger = logging.getLogger() - logger.setLevel(logging.INFO) - for handler in logger.handlers: - if not isinstance(handler, logging.StreamHandler): - continue - else: - logger.addHandler(logging.StreamHandler(stream=sys.stdout)) - runner.logger = logger - runner.message_hub = MagicMock() - runner.composed_wirter = MagicMock() - return runner + runner.iter = 0 + runner.log_processor.get_log_after_iter = MagicMock( + return_value=(dict(), 'log_str')) + logger_hook.after_test_iter(runner, 1) + runner.log_processor.get_log_after_iter.assert_not_called() + logger_hook.after_test_iter(runner, 9) + runner.log_processor.get_log_after_iter.assert_called() diff --git a/tests/test_hook/test_optimizer_hook.py b/tests/test_hook/test_optimizer_hook.py index 5d04ca3fe5515a577b2b18157cd5bc1459947880..dc11ee0fa476710e9bfc4fc1c2affd4185e83a99 100644 --- a/tests/test_hook/test_optimizer_hook.py +++ b/tests/test_hook/test_optimizer_hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from unittest.mock import Mock +from unittest.mock import MagicMock, Mock import torch from torch import nn @@ -45,7 +45,7 @@ class TestOptimizerHook: model = Model() x = torch.rand(1, 1, 3, 3) - dummy_runner = Mock() + dummy_runner = MagicMock() dummy_runner.optimizer.zero_grad = Mock(return_value=None) dummy_runner.optimizer.step = Mock(return_value=None) dummy_runner.model = model diff --git a/tests/test_logging/test_log_processor.py b/tests/test_logging/test_log_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..b10cac481dddc004c7ed8c7c39a3274493622cef --- /dev/null +++ b/tests/test_logging/test_log_processor.py @@ -0,0 +1,242 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from mmengine.logging import LogProcessor, MessageHub, MMLogger + + +class TestLogProcessor: + + def test_init(self): + log_processor = LogProcessor( + window_size=10, by_epoch=True, custom_cfg=None) + assert log_processor.by_epoch + assert log_processor.window_size == 10 + assert log_processor.custom_cfg == [] + + def test_check_custom_cfg(self): + # ``by_epoch==False`` and `window_size='epoch'` in log config will + # raise AssertionError. + custom_cfg = [dict(data_src='loss', window_size='epoch')] + with pytest.raises(AssertionError): + LogProcessor(by_epoch=False, custom_cfg=custom_cfg) + # Duplicate log_name will raise AssertionError. + custom_cfg = [ + dict(data_src='loss', log_name='loss_1'), + dict(data_src='loss', log_name='loss_1') + ] + with pytest.raises(AssertionError): + LogProcessor(custom_cfg=custom_cfg) + # Overwrite loss item twice will raise AssertionError. 
+ custom_cfg = [dict(data_src='loss'), dict(data_src='loss')] + with pytest.raises(AssertionError): + LogProcessor(custom_cfg=custom_cfg) + + custom_cfg = [ + dict(data_src='loss_cls', window_size=100, method_name='min'), + dict(data_src='loss', log_name='loss_min', method_name='max'), + dict(data_src='loss', log_name='loss_max', method_name='max') + ] + LogProcessor(custom_cfg=custom_cfg) + + def test_parse_windows_size(self): + log_processor = LogProcessor() + # Test parse 'epoch' window_size. + log_processor.custom_cfg = [ + dict(data_src='loss_cls', window_size='epoch') + ] + custom_cfg = log_processor._parse_windows_size(self.runner, 1) + assert custom_cfg[0]['window_size'] == 2 + + # Test parse 'global' window_size. + log_processor.custom_cfg = [ + dict(data_src='loss_cls', window_size='global') + ] + custom_cfg = log_processor._parse_windows_size(self.runner, 1) + assert custom_cfg[0]['window_size'] == 11 + + # Test parse int window_size + log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=100)] + custom_cfg = log_processor._parse_windows_size(self.runner, 1) + assert custom_cfg[0]['window_size'] == 100 + + # Invalid type window_size will raise TypeError. + log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=[])] + with pytest.raises(TypeError): + log_processor._parse_windows_size(self.runner, 1) + + @pytest.mark.parametrize('by_epoch,mode', + ([True, 'train'], [False, 'train'], [True, 'val'], + [False, 'val'], [True, 'test'], [False, 'test'])) + def test_get_log_after_iter(self, by_epoch, mode): + # Prepare LoggerHook + log_processor = LogProcessor(by_epoch=by_epoch) + log_processor._get_max_memory = MagicMock(return_value='100') + eta = 40 + self.runner.message_hub.update_info('eta', eta) + # Prepare training information. 
+ if mode == 'train': + train_logs = dict(lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0) + else: + train_logs = dict(time=1.0, data_time=1.0, loss_cls=1.0) + log_processor._collect_scalars = MagicMock(return_value=train_logs) + tag, out = log_processor.get_log_after_iter(self.runner, 1, mode) + # Verify that the correct context have been logged. + cur_loop = log_processor._get_cur_loop(self.runner, mode) + if by_epoch: + if mode in ['train', 'val']: + cur_epoch = log_processor._get_epoch(self.runner, mode) + log_str = (f'Epoch({mode}) [{cur_epoch}][2/' + f'{len(cur_loop.dataloader)}] ') + else: + log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ') + + if mode == 'train': + log_str += f"lr: {train_logs['lr']:.3e} " + else: + log_str += ' ' + + log_str += (f'eta: 0:00:40 ' + f"time: {train_logs['time']:.3f} " + f"data_time: {train_logs['data_time']:.3f} ") + + if torch.cuda.is_available(): + log_str += 'memory: 100 ' + if mode == 'train': + log_str += f"loss_cls: {train_logs['loss_cls']:.4f}" + assert out == log_str + else: + if mode == 'train': + max_iters = self.runner.train_loop.max_iters + log_str = f'Iter({mode}) [11/{max_iters}] ' + else: + max_iters = len(cur_loop.dataloader) + log_str = f'Iter({mode}) [2/{max_iters}] ' + + if mode == 'train': + log_str += f"lr: {train_logs['lr']:.3e} " + else: + log_str += ' ' + + log_str += (f'eta: 0:00:40 ' + f"time: {train_logs['time']:.3f} " + f"data_time: {train_logs['data_time']:.3f} ") + + if torch.cuda.is_available(): + log_str += 'memory: 100 ' + + if mode == 'train': + log_str += f"loss_cls: {train_logs['loss_cls']:.4f}" + assert out == log_str + + @pytest.mark.parametrize( + 'by_epoch,mode', + ([True, 'val'], [False, 'val'], [True, 'test'], [False, 'test'])) + def test_log_val(self, by_epoch, mode): + # Prepare LoggerHook + log_processor = LogProcessor(by_epoch=by_epoch) + # Prepare validation information. 
+ val_logs = dict(accuracy=0.9, data_time=1.0) + log_processor._collect_scalars = MagicMock(return_value=val_logs) + _, out = log_processor.get_log_after_epoch(self.runner, 2, mode) + if by_epoch: + if mode == 'test': + assert out == 'Epoch(test) [5/5] accuracy: 0.9000' + else: + assert out == 'Epoch(val) [1][10/10] accuracy: 0.9000' + else: + if mode == 'test': + assert out == 'Iter(test) [5/5] accuracy: 0.9000' + else: + assert out == 'Iter(val) [10/10] accuracy: 0.9000' + + def test_collect_scalars(self): + custom_cfg = [ + dict(data_src='time', method_name='mean', window_size=100), + dict(data_src='time', method_name='max', log_name='time_max') + ] + logger_hook = LogProcessor(custom_cfg=custom_cfg) + # Collect with prefix. + log_scalars = { + 'train/time': MagicMock(), + 'lr': MagicMock(), + 'train/loss_cls': MagicMock(), + 'val/metric': MagicMock() + } + self.runner.message_hub._log_scalars = log_scalars + tag = logger_hook._collect_scalars( + copy.deepcopy(custom_cfg), self.runner, mode='train') + # Test training key in tag. 
+ assert list(tag.keys()) == ['time', 'loss_cls', 'time_max'] + # `time` should use the custom 'max' statistic, `loss_cls` the default 'mean'. + log_scalars['train/time'].statistics.assert_called_with( + method_name='max') + log_scalars['train/loss_cls'].mean.assert_called() + + tag = logger_hook._collect_scalars( + copy.deepcopy(custom_cfg), self.runner, mode='val') + assert list(tag.keys()) == ['metric'] + log_scalars['val/metric'].current.assert_called() + + @patch('torch.cuda.max_memory_allocated', MagicMock()) + @patch('torch.cuda.reset_peak_memory_stats', MagicMock()) + def test_get_max_memory(self): + logger_hook = LogProcessor() + runner = MagicMock() + runner.world_size = 1 + runner.model = torch.nn.Linear(1, 1) + logger_hook._get_max_memory(runner) + torch.cuda.max_memory_allocated.assert_called() + torch.cuda.reset_peak_memory_stats.assert_called() + + def test_get_iter(self): + log_processor = LogProcessor() + # Get global iter when `batch_idx` is not provided + iter = log_processor._get_iter(self.runner) + assert iter == 11 + # Get inner iter + iter = log_processor._get_iter(self.runner, 1) + assert iter == 2 + # Still get global iter when `log_processor.by_epoch == False` + log_processor.by_epoch = False + iter = log_processor._get_iter(self.runner, 1) + assert iter == 11 + + def test_get_epoch(self): + log_processor = LogProcessor() + epoch = log_processor._get_epoch(self.runner, 'train') + assert epoch == 2 + epoch = log_processor._get_epoch(self.runner, 'val') + assert epoch == 1 + with pytest.raises(ValueError): + log_processor._get_epoch(self.runner, 'test') + + def test_get_cur_loop(self): + log_processor = LogProcessor() + loop = log_processor._get_cur_loop(self.runner, 'train') + assert len(loop.dataloader) == 20 + loop = log_processor._get_cur_loop(self.runner, 'val') + assert len(loop.dataloader) == 10 + loop = log_processor._get_cur_loop(self.runner, 'test') + assert len(loop.dataloader) == 5 + + def setup(self): + runner = MagicMock() + runner.epoch = 1 + runner.iter 
= 10 + runner.train_loop.max_iters = 50 + runner.train_loop.dataloader = [0] * 20 + runner.val_loop.dataloader = [0] * 10 + runner.test_loop.dataloader = [0] * 5 + logger = MMLogger.get_instance('log_processor_test') + runner.logger = logger + message_hub = MessageHub.get_instance('log_processor_test') + for i in range(10): + message_hub.update_scalar('train/loss', 10 - i) + for i in range(10): + message_hub.update_scalar('val/acc', i * 0.1) + runner.message_hub = message_hub + self.runner = runner diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py index d747b6bddb6fbbf3060c44f0eb080a79998052fd..6e8f337d89d5a762e71bb3808db4288d1989b3b9 100644 --- a/tests/test_optim/test_scheduler/test_lr_scheduler.py +++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py @@ -8,7 +8,7 @@ import torch.optim as optim from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR, LinearLR, MultiStepLR, - StepLR, _ParamScheduler) + PolyLR, StepLR, _ParamScheduler) from mmengine.testing import assert_allclose @@ -283,6 +283,21 @@ class TestLRScheduler(TestCase): scheduler = CosineAnnealingLR(self.optimizer, T_max=t, eta_min=eta_min) self._test_scheduler_value(scheduler, targets, epochs) + def test_poly_scheduler(self): + epochs = 10 + power = 0.9 + min_lr = 0.001 + iters = 4 + single_targets = [ + min_lr + (0.05 - min_lr) * (1 - i / iters)**power + for i in range(iters) + ] + [min_lr] * ( + epochs - iters) + targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = PolyLR( + self.optimizer, power=power, eta_min=min_lr, end=iters + 1) + self._test_scheduler_value(scheduler, targets, epochs=10) + def _check_scheduler_state_dict(self, construct, construct2, epochs=10): scheduler = construct() for _ in range(epochs): @@ -331,6 +346,12 @@ class TestLRScheduler(TestCase): lambda: LinearLR(self.optimizer, start_factor=0, end_factor=0.3), epochs=epochs) + def 
test_poly_scheduler_state_dict(self): + self._check_scheduler_state_dict( + lambda: PolyLR(self.optimizer, power=0.5, eta_min=0.001), + lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002), + epochs=10) + def test_multi_scheduler_without_overlap_linear_multi_step(self): # use Linear in the first 5 epochs and then use MultiStep epochs = 12 diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py index fd63a9b941686783a58680f2d2fc70195fdcca94..97d7af3b9a866578b19e0d2badd4ff5fb7a36870 100644 --- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py +++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py @@ -9,8 +9,8 @@ import torch.optim as optim from mmengine.optim.scheduler import (ConstantMomentum, CosineAnnealingMomentum, ExponentialMomentum, LinearMomentum, - MultiStepMomentum, StepMomentum, - _ParamScheduler) + MultiStepMomentum, PolyMomentum, + StepMomentum, _ParamScheduler) from mmengine.testing import assert_allclose @@ -284,6 +284,21 @@ class TestMomentumScheduler(TestCase): self.optimizer, T_max=t, eta_min=eta_min) self._test_scheduler_value(scheduler, targets, epochs) + def test_poly_scheduler(self): + epochs = 10 + power = 0.9 + min_lr = 0.001 + iters = 4 + single_targets = [ + min_lr + (0.05 - min_lr) * (1 - i / iters)**power + for i in range(iters) + ] + [min_lr] * ( + epochs - iters) + targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = PolyMomentum( + self.optimizer, power=power, eta_min=min_lr, end=iters + 1) + self._test_scheduler_value(scheduler, targets, epochs=10) + def _check_scheduler_state_dict(self, construct, construct2, epochs=10): scheduler = construct() for _ in range(epochs): @@ -333,6 +348,12 @@ class TestMomentumScheduler(TestCase): self.optimizer, start_factor=0, end_factor=0.3), epochs=epochs) + def test_poly_scheduler_state_dict(self): + self._check_scheduler_state_dict( + lambda: PolyMomentum(self.optimizer, 
power=0.5, eta_min=0.001), + lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002), + epochs=10) + def test_multi_scheduler_without_overlap_linear_multi_step(self): # use Linear in the first 5 epochs and then use MultiStep epochs = 12 diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index d1467828e15a74b2a4f42e850b83c9474b17ed77..c47033929725b921a418a53856036a174e8967c4 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -6,12 +6,15 @@ import torch import torch.nn.functional as F import torch.optim as optim +# yapf: disable from mmengine.optim.scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, ExponentialParamScheduler, LinearParamScheduler, MultiStepParamScheduler, - StepParamScheduler, _ParamScheduler) + PolyParamScheduler, StepParamScheduler, + _ParamScheduler) +# yapf: enable from mmengine.testing import assert_allclose @@ -336,6 +339,25 @@ class TestParameterScheduler(TestCase): self.optimizer, param_name='lr', T_max=t, eta_min=eta_min) self._test_scheduler_value(scheduler, targets, epochs) + def test_poly_scheduler(self): + epochs = 10 + power = 0.9 + min_lr = 0.001 + iters = 4 + single_targets = [ + min_lr + (0.05 - min_lr) * (1 - i / iters)**power + for i in range(iters) + ] + [min_lr] * ( + epochs - iters) + targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = PolyParamScheduler( + self.optimizer, + param_name='lr', + power=power, + eta_min=min_lr, + end=iters + 1) + self._test_scheduler_value(scheduler, targets, epochs=10) + def _check_scheduler_state_dict(self, construct, construct2, epochs=10): scheduler = construct() for _ in range(epochs): @@ -402,6 +424,14 @@ class TestParameterScheduler(TestCase): end_factor=0.3), epochs=epochs) + def test_poly_scheduler_state_dict(self): + self._check_scheduler_state_dict( + lambda: PolyParamScheduler( 
+ self.optimizer, param_name='lr', power=0.5, eta_min=0.001), + lambda: PolyParamScheduler( + self.optimizer, param_name='lr', power=0.8, eta_min=0.002), + epochs=10) + def test_multi_scheduler_without_overlap_linear_multi_step(self): # use Linear in the first 5 epochs and then use MultiStep epochs = 12 diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 29e7ee3658a7678489c9805eeac400204c6cbc62..2c085b9e23ee25717a93be8cbd6424e2810c39fd 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -222,7 +222,7 @@ class TestRunner(TestCase): self.iter_based_cfg.default_hooks = dict( timer=dict(type='IterTimerHook'), checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False), - logger=dict(type='LoggerHook', by_epoch=False), + logger=dict(type='LoggerHook'), optimizer=dict(type='OptimizerHook', grad_clip=None), param_scheduler=dict(type='ParamSchedulerHook'))
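The new `test_poly_scheduler` cases in the hunks above all build their expected values from the same polynomial-decay arithmetic. A minimal standalone sketch of that target computation (the base value `0.05` and the defaults below are taken from the tests; the helper name `poly_targets` is ours, not part of MMEngine):

```python
def poly_targets(base_lr=0.05, eta_min=0.001, power=0.9, iters=4, epochs=10):
    # Decay polynomially from base_lr toward eta_min over `iters` steps,
    # then hold eta_min for the remaining epochs — the same list the
    # tests pass to _test_scheduler_value as `single_targets`.
    decay = [
        eta_min + (base_lr - eta_min) * (1 - i / iters)**power
        for i in range(iters)
    ]
    return decay + [eta_min] * (epochs - iters)

targets = poly_targets()
```

With `end=iters + 1` in the scheduler call, the decay phase covers exactly `iters` updates before the value is clamped at `eta_min`, which is why the expected list is the comprehension plus `[min_lr] * (epochs - iters)`.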