From 5f8f36e6a5c807cec762fdcc2bc2a02ac2302929 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Fri, 15 Apr 2022 15:56:06 +0800 Subject: [PATCH] refactor visualization (#147) * [WIP] add inline * refactor vis module * [Refactor] according review * [Fix] fix comment * fix some error * Get sub visualizer be Visualizer.get_instance * fix conflict * fix lint * fix unit test * fix mypy * fix comment * fix lint * update docstr * update * update instancedata * remove replace __mro__ with issubclass Co-authored-by: PJLAB\huanghaian <1286304229@qq.com> Co-authored-by: HAOCHENYE <21724054@zju.edu.cn> --- docs/zh_cn/tutorials/visualization.md | 301 ------- mmengine/data/instance_data.py | 15 +- mmengine/hooks/logger_hook.py | 8 +- mmengine/hooks/naive_visualization_hook.py | 7 +- mmengine/logging/message_hub.py | 9 +- mmengine/registry/__init__.py | 4 +- mmengine/registry/registry.py | 10 +- mmengine/registry/root.py | 4 +- mmengine/runner/loops.py | 28 +- mmengine/runner/runner.py | 78 +- mmengine/visualization/__init__.py | 8 +- mmengine/visualization/utils.py | 62 +- mmengine/visualization/vis_backend.py | 494 +++++++++++ mmengine/visualization/visualizer.py | 801 ++++++++++------- mmengine/visualization/writer.py | 823 ------------------ tests/test_hook/test_logger_hook.py | 5 +- .../test_naive_visualization_hook.py | 2 +- tests/test_registry/test_registry.py | 15 + tests/test_runner/test_runner.py | 46 +- tests/test_visualizer/test_vis_backend.py | 200 +++++ tests/test_visualizer/test_visualizer.py | 240 +++-- tests/test_visualizer/test_writer.py | 484 ---------- 22 files changed, 1570 insertions(+), 2074 deletions(-) delete mode 100644 docs/zh_cn/tutorials/visualization.md create mode 100644 mmengine/visualization/vis_backend.py delete mode 100644 mmengine/visualization/writer.py create mode 100644 tests/test_visualizer/test_vis_backend.py delete mode 100644 tests/test_visualizer/test_writer.py diff --git a/docs/zh_cn/tutorials/visualization.md b/docs/zh_cn/tutorials/visualization.md deleted file mode 100644 index 80acabcc..00000000 --- a/docs/zh_cn/tutorials/visualization.md +++ /dev/null @@ -1,301 +0,0 @@ -# å¯è§†åŒ– (Visualization) - -## 概述 - -**(1) 总体介ç»** - -å¯è§†åŒ–å¯ä»¥ç»™æ·±åº¦å¦ä¹ 的模型è®ç»ƒå’Œæµ‹è¯•è¿‡ç¨‹æ供直观解释。在 OpenMMLab 算法库ä¸ï¼Œæˆ‘们期望å¯è§†åŒ–功能的设计能满足以下需求: - -- æ供丰富的开箱å³ç”¨å¯è§†åŒ–功能,能够满足大部分计算机视觉å¯è§†åŒ–任务 -- 高扩展性,å¯è§†åŒ–åŠŸèƒ½é€šå¸¸å¤šæ ·åŒ–ï¼Œåº”è¯¥èƒ½å¤Ÿé€šè¿‡ç®€å•æ‰©å±•å®žçŽ°å®šåˆ¶éœ€æ±‚ -- 能够在è®ç»ƒå’Œæµ‹è¯•æµç¨‹çš„ä»»æ„点ä½è¿›è¡Œå¯è§†åŒ– -- OpenMMLab å„个算法库具有统一å¯è§†åŒ–接å£ï¼Œåˆ©äºŽç”¨æˆ·ç†è§£å’Œç»´æŠ¤ - -基于上述需求,OpenMMLab 2.0 引入了绘制对象 Visualizer 和写端对象 Writer 的概念 - -- **Visualizer è´Ÿè´£å•å¼ 图片的绘制功能** - - MMEngine æ供了以 Matplotlib 库为绘制åŽç«¯çš„ `Visualizer` 类,其具备如下功能: - - - æä¾›äº†ä¸€ç³»åˆ—å’Œè§†è§‰ä»»åŠ¡æ— å…³çš„åŸºç¡€æ–¹æ³•ï¼Œä¾‹å¦‚ `draw_bboxes` å’Œ `draw_texts` ç‰ - - å„个基础方法支æŒé“¾å¼è°ƒç”¨ï¼Œæ–¹ä¾¿å åŠ ç»˜åˆ¶æ˜¾ç¤º - - 通过 `draw_featmap` æ供绘制特å¾å›¾åŠŸèƒ½ - - å„个下游算法库å¯ä»¥ç»§æ‰¿ `Visualizer` 并在 `draw` 接å£ä¸å®žçŽ°æ‰€éœ€çš„å¯è§†åŒ–功能,例如 MMDetection ä¸çš„ `DetVisualizer` 继承自 `Visualizer` 并在 `draw` 接å£ä¸å®žçŽ°å¯è§†åŒ–检测框ã€å®žä¾‹æŽ©ç å’Œè¯ä¹‰åˆ†å‰²å›¾ç‰åŠŸèƒ½ã€‚Visualizer 类的 UML 关系图如下 - - <div align="center"> - <img src="https://user-images.githubusercontent.com/17425982/154475592-7208a34b-f6cb-4171-b0be-9dbb13306862.png" > - </div> - -- **Writer 负责将å„类数æ®å†™å…¥åˆ°æŒ‡å®šåŽç«¯** - - 为了统一接å£è°ƒç”¨ï¼ŒMMEngine æ供了统一的抽象类 `BaseWriter`,和一些常用的 Writer 如 `LocalWriter` æ¥æ”¯æŒå°†æ•°æ®å†™å…¥æœ¬åœ°ï¼Œ`TensorboardWriter` æ¥æ”¯æŒå°†æ•°æ®å†™å…¥ Tensorboard,`WandbWriter` æ¥æ”¯æŒå°†æ•°æ®å†™å…¥ Wandb。用户也å¯ä»¥è‡ªå®šä¹‰ Writer æ¥å°†æ•°æ®å†™å…¥è‡ªå®šä¹‰åŽç«¯ã€‚写入的数æ®å¯ä»¥æ˜¯å›¾ç‰‡ï¼Œæ¨¡åž‹ç»“æž„å›¾ï¼Œæ ‡é‡å¦‚æ¨¡åž‹ç²¾åº¦æŒ‡æ ‡ç‰ã€‚ - - 考虑到在è®ç»ƒæˆ–者测试过程ä¸å¯èƒ½åŒæ—¶å˜åœ¨å¤šä¸ª Writer 对象,例如åŒæ—¶æƒ³è¿›è¡Œæœ¬åœ°å’Œè¿œç¨‹ç«¯å†™æ•°æ®ï¼Œä¸ºæ¤è®¾è®¡äº† `ComposedWriter` 负责管ç†æ‰€æœ‰è¿è¡Œä¸å®žä¾‹åŒ–çš„ Writer 对象,其会自动管ç†æ‰€æœ‰ Writer 对象,并é历调用所有 Writer 对象的方法。Writer 类的 UML 关系图如下 - <div align="center"> - <img src="https://user-images.githubusercontent.com/17425982/157000633-9f552539-f722-44b1-b253-1abaf4a8eba6.png" > - </div> - -**(2) Writer å’Œ Visualizer 关系** - -Writer å¯¹è±¡çš„æ ¸å¿ƒåŠŸèƒ½æ˜¯å†™å„类数æ®åˆ°æŒ‡å®šåŽç«¯ä¸ï¼Œä¾‹å¦‚写图片ã€å†™æ¨¡åž‹å›¾ã€å†™è¶…å‚å’Œå†™æ¨¡åž‹ç²¾åº¦æŒ‡æ ‡ç‰ï¼ŒåŽç«¯å¯ä»¥æŒ‡å®šä¸ºæœ¬åœ°å˜å‚¨ã€Wandb å’Œ Tensorboard ç‰ç‰ã€‚在写图片过程ä¸ï¼Œé€šå¸¸å¸Œæœ›èƒ½å¤Ÿå°†é¢„æµ‹ç»“æžœæˆ–è€…æ ‡æ³¨ç»“æžœç»˜åˆ¶åˆ°å›¾ç‰‡ä¸Šï¼Œç„¶åŽå†è¿›è¡Œå†™æ“作,为æ¤åœ¨ Writer 内部维护了 Visualizer 对象,将 Visualizer 作为 Writer 的一个属性。需è¦æ³¨æ„的是: - -- åªæœ‰è°ƒç”¨äº† Writer ä¸çš„ `add_image` 写图片功能时候æ‰å¯èƒ½ä¼šç”¨åˆ° Visualizer 对象,其余接å£å’Œ Visualizer 没有关系 -- 考虑到æŸäº› Writer åŽç«¯æœ¬èº«å°±å…·å¤‡ç»˜åˆ¶åŠŸèƒ½ä¾‹å¦‚ `WandbWriter`,æ¤æ—¶ `WandbWriter` ä¸çš„ Visualizer 属性就是å¯é€‰çš„,如果用户在åˆå§‹åŒ–æ—¶å€™ä¼ å…¥äº† Visualizer 对象,则在 `add_image` 时候会调用 Visualizer 对象,å¦åˆ™ä¼šç›´æŽ¥è°ƒç”¨ Wandb 本身 API 进行图片绘制 -- `LocalWriter` å’Œ `TensorboardWriter` 由于绘制功能å•ä¸€ï¼Œç›®å‰å¼ºåˆ¶ç”± Visualizer 对象绘制,所以这两个 Writer å¿…é¡»ä¼ å…¥ Visualizer 或者å类对象 - -`WandbWriter` 的一个简略的演示代ç 如下 - -```python -# 为了方便ç†è§£ï¼Œæ²¡æœ‰ç»§æ‰¿ BaseWriter -class WandbWriter: - def __init__(self, visualizer=None): - self._visualizer = None - if visualizer: - # 示例é…ç½® visualizer=dict(type='DetVisualizer') - self._visualizer = VISUALIZERS.build(visualizer) - - @property - def visualizer(self): - return self._visualizer - - def add_image(self, name, image, gt_sample=None, pred_sample=None, draw_gt=True, draw_pred=True, step=0, **kwargs): - if self._visualize: - self._visualize.draw(image, gt_sample, pred_sample, draw_gt, draw_pred) - # 调用 Writer API 写图片到åŽç«¯ - self.wandb.log({name: self.visualizer.get_image()}, ...) - ... - else: - # 调用 Writer API 汇总并写图片到åŽç«¯ - ... - - def add_scalar(self, name, value, step): - self.wandb.log({name: value}, ...) -``` - - -## 绘制对象 Visualizer - -绘制对象 Visualizer è´Ÿè´£å•å¼ 图片的å„类绘制功能,默认绘制åŽç«¯ä¸º Matplotlib。为了统一 OpenMMLab å„个算法库的å¯è§†åŒ–接å£ï¼ŒMMEngine 定义æ供了基础绘制功能的 `Visualizer` 类,下游库å¯ä»¥ç»§æ‰¿ `Visualizer` 并实现 `draw` 接å£æ¥æ»¡è¶³è‡ªå·±çš„绘制需求。 - -### Visualizer - -`Visualizer` æ供了基础而通用的绘制功能,主è¦æŽ¥å£å¦‚下: - -**(1) ç»˜åˆ¶æ— å…³çš„åŠŸèƒ½æ€§æŽ¥å£** - -- [set_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.set_image) 设置原始图片数æ®ï¼Œé»˜è®¤è¾“å…¥å›¾ç‰‡æ ¼å¼ä¸º RGB -- [get_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_image) 获å–绘制åŽçš„ Numpy æ ¼å¼å›¾ç‰‡æ•°æ®ï¼Œé»˜è®¤è¾“å‡ºæ ¼å¼ä¸º RGB -- [show](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.show) å¯è§†åŒ– -- [register_task](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.register_task) 注册绘制函数(其作用在 *自定义 Visualizer* å°èŠ‚æè¿°) - -**(2) 绘制相关接å£** - -- [draw](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw) ç”¨æˆ·ä½¿ç”¨çš„æŠ½è±¡ç»˜åˆ¶æŽ¥å£ -- [draw_featmap](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_featmap) 绘制特å¾å›¾ -- [draw_bboxes](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_bboxes) 绘制å•ä¸ªæˆ–者多个边界框 -- [draw_texts](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_texts) 绘制å•ä¸ªæˆ–者多个文本框 -- [draw_lines](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.lines) 绘制å•ä¸ªæˆ–者多个线段 -- [draw_circles](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_circles) 绘制å•ä¸ªæˆ–者多个圆 -- [draw_polygons](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_polygons) 绘制å•ä¸ªæˆ–者多个多边形 -- [draw_binary_masks](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_binary_mask) 绘制å•ä¸ªæˆ–者多个二值掩ç - -用户除了å¯ä»¥å•ç‹¬è°ƒç”¨ `Visualizer` ä¸åŸºç¡€ç»˜åˆ¶æŽ¥å£ï¼ŒåŒæ—¶ä¹Ÿæ供了链å¼è°ƒç”¨åŠŸèƒ½å’Œç‰¹å¾å›¾å¯è§†åŒ–功能。`draw` 函数是抽象接å£ï¼Œå†…部没有任何实现,继承了 Visualizer çš„ç±»å¯ä»¥å®žçŽ°è¯¥æŽ¥å£ï¼Œä»Žè€Œå¯¹å¤–æ供统一的绘制功能,而 `draw_xxx` ç‰ç›®çš„是æä¾›æœ€åŸºç¡€çš„ç»˜åˆ¶åŠŸèƒ½ï¼Œç”¨æˆ·ä¸€èˆ¬æ— éœ€é‡å†™ã€‚ - -**(1) 链å¼è°ƒç”¨** - -例如用户先绘制边界框,在æ¤åŸºç¡€ä¸Šç»˜åˆ¶æ–‡æœ¬ï¼Œç»˜åˆ¶çº¿æ®µï¼Œåˆ™è°ƒç”¨è¿‡ç¨‹ä¸ºï¼š - -```python -visualizer.set_image(image) -visualizer.draw_bboxes(...).draw_texts(...).draw_lines(...) -visualizer.show() # å¯è§†åŒ–绘制结果 -``` - -**(2) å¯è§†åŒ–特å¾å›¾** - -特å¾å›¾å¯è§†åŒ–是一个常è§çš„功能,通过调用 `draw_featmap` å¯ä»¥ç›´æŽ¥å¯è§†åŒ–特å¾å›¾ï¼Œå…¶å‚数定义为: - -```python -@staticmethod -def draw_featmap(tensor_chw: torch.Tensor, # è¾“å…¥æ ¼å¼è¦æ±‚为 CHW - image: Optional[np.ndarray] = None, # 如果åŒæ—¶è¾“入了 image æ•°æ®ï¼Œåˆ™ç‰¹å¾å›¾ä¼šå åŠ åˆ° image 上绘制 - mode: Optional[str] = 'mean', # 多个通é“压缩为å•é€šé“çš„ç–ç•¥ - topk: int = 10, # å¯é€‰æ‹©æ¿€æ´»åº¦æœ€é«˜çš„ topk 个特å¾å›¾æ˜¾ç¤º - arrangement: Tuple[int, int] = (5, 2), # 多通é“å±•å¼€ä¸ºå¤šå¼ å›¾æ—¶å€™å¸ƒå±€ - alpha: float = 0.3) -> np.ndarray: # 图片和特å¾å›¾ç»˜åˆ¶çš„å åŠ æ¯”ä¾‹ -``` - -特å¾å›¾å¯è§†åŒ–功能较多,目å‰ä¸æ”¯æŒ Batch 输入 - -- mode ä¸æ˜¯ None,topk æ— æ•ˆï¼Œä¼šå°†å¤šä¸ªé€šé“输出采用 mode 模å¼å‡½æ•°åŽ‹ç¼©ä¸ºå•é€šé“,å˜æˆå•å¼ å›¾ç‰‡æ˜¾ç¤ºï¼Œç›®å‰ mode ä»…æ”¯æŒ Noneã€'mean'ã€'max' å’Œ 'min' å‚数输入 -- mode 是 None,topk 有效,如果 topk ä¸æ˜¯ -1,则会按照激活度排åºé€‰æ‹© topk 个通é“显示,æ¤æ—¶å¯ä»¥é€šè¿‡ arrangement å‚数指定显示的布局 -- mode 是 None,topk 有效,如果 `topk = -1`,æ¤æ—¶é€šé“ C 必须是 1 或者 3 表示输入数æ®æ˜¯å›¾ç‰‡ï¼Œå¯ä»¥ç›´æŽ¥æ˜¾ç¤ºï¼Œå¦åˆ™æŠ¥é”™æ示用户应该设置 mode æ¥åŽ‹ç¼©é€šé“ - -```python -featmap=visualizer.draw_featmap(tensor_chw,image) -``` - -### 自定义 Visualizer - -自定义的 Visualizer ä¸å¤§éƒ¨åˆ†æƒ…况下åªéœ€è¦å®žçŽ° `get_image` å’Œ `draw` 接å£ã€‚`draw` 是最高层的用户调用接å£ï¼Œ`draw` 接å£è´Ÿè´£æ‰€æœ‰ç»˜åˆ¶åŠŸèƒ½ï¼Œä¾‹å¦‚绘制检测框ã€æ£€æµ‹æŽ©ç mask å’Œ 检测è¯ä¹‰åˆ†å‰²å›¾ç‰ç‰ã€‚ä¾æ®ä»»åŠ¡çš„ä¸åŒï¼Œ`draw` 接å£å®žçŽ°çš„å¤æ‚度也ä¸åŒã€‚ - -ä»¥ç›®æ ‡æ£€æµ‹å¯è§†åŒ–需求为例,å¯èƒ½éœ€è¦åŒæ—¶ç»˜åˆ¶è¾¹ç•Œæ¡† bboxã€æŽ©ç mask å’Œè¯ä¹‰åˆ†å‰²å›¾ seg_map,如果如æ¤å¤šåŠŸèƒ½å…¨éƒ¨å†™åˆ° `draw` 方法ä¸ä¼šéš¾ä»¥ç†è§£å’Œç»´æŠ¤ã€‚为了解决该问题,`Visualizer` 基于 OpenMMLab 2.0 抽象数æ®æŽ¥å£è§„范支æŒäº† `register_task` 函数。å‡è®¾ MMDetection ä¸éœ€è¦åŒæ—¶ç»˜åˆ¶é¢„测结果ä¸çš„ instances å’Œ sem_seg,å¯ä»¥åœ¨ MMDetection çš„ `DetVisualizer` ä¸å®žçŽ° `draw_instances` å’Œ `draw_sem_seg` 两个方法,用于绘制预测实例和预测è¯ä¹‰åˆ†å‰²å›¾ï¼Œ 我们希望åªè¦è¾“入数æ®ä¸å˜åœ¨ instances 或 sem_seg 时候,对应的两个绘制函数 `draw_instances` å’Œ `draw_sem_seg` 能够自动被调用,而用户ä¸éœ€è¦æ‰‹åŠ¨è°ƒç”¨ã€‚为了实现上述功能,å¯ä»¥é€šè¿‡åœ¨ `draw_instances` å’Œ `draw_sem_seg` ä¸¤ä¸ªå‡½æ•°åŠ ä¸Š `@Visualizer.register_task` 装饰器,æ¤æ—¶ `task_dict` ä¸å°±ä¼šå˜å‚¨å—ç¬¦ä¸²å’Œå‡½æ•°çš„æ˜ å°„å…³ç³»ï¼Œåœ¨è°ƒç”¨ `draw` 方法时候就å¯ä»¥é€šè¿‡ `self.task_dict`获å–到已ç»è¢«æ³¨å†Œçš„函数。一个简略的实现如下所示 - -```python -class DetVisualizer(Visualizer): - - def draw(self, image, gt_sample=None, pred_sample=None, draw_gt=True, draw_pred=True): - # 将图片和 matplotlib å¸ƒå±€å…³è” - self.set_image(image) - - if draw_gt: - # self.task_dict 内部å˜å‚¨å¦‚下信æ¯ï¼š - # dict(instances=draw_instance 方法,sem_seg=draw_sem_seg 方法) - for task in self.task_dict: - task_attr = 'gt_' + task - if task_attr in gt_sample: - self.task_dict[task](self, gt_sample[task_attr], 'gt') - if draw_pred: - for task in self.task_dict: - task_attr = 'pred_' + task - if task_attr in pred_sample: - self.task_dict[task](self, pred_sample[task_attr], 'pred') - - # data_type 用于区分当å‰ç»˜åˆ¶çš„å†…å®¹æ˜¯æ ‡æ³¨è¿˜æ˜¯é¢„æµ‹ç»“æžœ - @Visualizer.register_task('instances') - def draw_instance(self, instances, data_type): - ... - - # data_type 用于区分当å‰ç»˜åˆ¶çš„å†…å®¹æ˜¯æ ‡æ³¨è¿˜æ˜¯é¢„æµ‹ç»“æžœ - @Visualizer.register_task('sem_seg') - def draw_sem_seg(self, pixel_data, data_type): - ... -``` - -注æ„:是å¦ä½¿ç”¨ `register_task` 装饰器函数ä¸æ˜¯å¿…须的,如果用户自定义 Visualizer,并且 `draw` 实现éžå¸¸ç®€å•ï¼Œåˆ™æ— 需考虑 `register_task`。 - -在使用 Jupyter notebook 或者其他地方ä¸éœ€è¦å†™æ•°æ®åˆ°æŒ‡å®šåŽç«¯çš„情形下,用户å¯ä»¥è‡ªå·±å®žä¾‹åŒ– visualizer。一个简å•çš„例å如下 - -```python -# 实例化 visualizer -visualizer=dict(type='DetVisualizer') -visualizer = VISUALIZERS.build(visualizer) -visualizer.draw(image, datasample) -visualizer.show() # å¯è§†åŒ–绘制结果 -``` - -## 写端 Writer - -Visualizer åªå®žçŽ°äº†å•å¼ 图片的绘制功能,但是在è®ç»ƒæˆ–者测试过程ä¸ï¼Œå¯¹ä¸€äº›å…³é”®æŒ‡æ ‡æˆ–者模型è®ç»ƒè¶…å‚的记录éžå¸¸é‡è¦ï¼Œæ¤åŠŸèƒ½é€šè¿‡å†™ç«¯ Writer 实现。为了统一接å£è°ƒç”¨ï¼ŒMMEngine æ供了统一的抽象类 `BaseWriter`,和一些常用的 Writer 如 `LocalWriter` ã€`TensorboardWriter` å’Œ `WandbWriter` 。 - -### BaseWriter - -BaseWriter 定义了对外调用的接å£è§„范,主è¦æŽ¥å£å’Œå±žæ€§å¦‚下: - -- [add_params](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_params) 写超å‚到特定åŽç«¯ï¼Œå¸¸è§çš„è®ç»ƒè¶…å‚如åˆå§‹å¦ä¹ 率 LRã€æƒé‡è¡°å‡ç³»æ•°å’Œæ‰¹å¤§å°ç‰ç‰ -- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_graph) 写模型图到特定åŽç«¯ -- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_image) 写图片到特定åŽç«¯ -- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalar) å†™æ ‡é‡åˆ°ç‰¹å®šåŽç«¯ -- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalars) ä¸€æ¬¡æ€§å†™å¤šä¸ªæ ‡é‡åˆ°ç‰¹å®šåŽç«¯ -- [visualizer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.visualizer) 绘制对象 -- [experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.experiment) 写åŽç«¯å¯¹è±¡ï¼Œä¾‹å¦‚ Wandb 对象和 Tensorboard 对象 - -`BaseWriter` 定义了 5 个常è§çš„写数æ®æŽ¥å£ï¼Œè€ƒè™‘到æŸäº›å†™åŽç«¯åŠŸèƒ½éžå¸¸å¼ºå¤§ï¼Œä¾‹å¦‚ Wandbï¼Œå…¶å…·å¤‡å†™è¡¨æ ¼ï¼Œå†™è§†é¢‘ç‰ç‰åŠŸèƒ½ï¼Œé’ˆå¯¹è¿™ç±»éœ€æ±‚用户å¯ä»¥ç›´æŽ¥èŽ·å– experiment 对象,然åŽè°ƒç”¨å†™åŽç«¯å¯¹è±¡æœ¬èº«çš„ API å³å¯ã€‚ - -### LocalWriterã€TensorboardWriter å’Œ WandbWriter - -`LocalWriter` æ供了将数æ®å†™å…¥åˆ°æœ¬åœ°ç£ç›˜åŠŸèƒ½ã€‚如果用户需è¦å†™å›¾ç‰‡åˆ°ç¡¬ç›˜ï¼Œåˆ™**å¿…é¡»è¦é€šè¿‡åˆå§‹åŒ–å‚æ•°æä¾› Visualizer对象**。其典型用法为: - -```python -# é…置文件 -writer=dict(type='LocalWriter', save_dir='demo_dir', visualizer=dict(type='DetVisualizer')) -# 实例化和调用 -local_writer=WRITERS.build(writer) -# 写模型精度值 -local_writer.add_scalar('mAP', 0.9) -local_writer.add_scalars({'loss': 1.2, 'acc': 0.8}) -# å†™è¶…å‚ -local_writer.add_params(dict(lr=0.1, mode='linear')) -# 写图片 -local_writer.add_image('demo_image', image, datasample) -``` - -如果用户有自定义绘制需求,则å¯ä»¥é€šè¿‡èŽ·å–内部的 visualizer 属性æ¥å®žçŽ°ï¼Œå¦‚下所示 - -```python -# é…置文件 -writer=dict(type='LocalWriter', save_dir='demo_dir', visualizer=dict(type='DetVisualizer')) -# 实例化和调用 -local_writer=WRITERS.build(writer) -# 写图片 -local_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1])) -local_writer.add_image('img', local_writer.visualizer.get_image()) - -# 绘制特å¾å›¾å¹¶ä¿å˜åˆ°æœ¬åœ° -featmap_image=local_writer.visualizer.draw_featmap(tensor_chw) -local_writer.add_image('featmap', featmap_image) -``` - -`TensorboardWriter` æ供了将å„类数æ®å†™å…¥åˆ° Tensorboard 功能,其用法和 LocalWriter éžå¸¸ç±»ä¼¼ã€‚ 注æ„如果用户需è¦å†™å›¾ç‰‡åˆ° Tensorboard,则**å¿…é¡»è¦é€šè¿‡åˆå§‹åŒ–å‚æ•°æä¾› Visualizer对象**。 - -`WandbWriter` æ供了将å„类数æ®å†™å…¥åˆ° Wandb 功能。考虑到 Wandb 本身具备强大的图片功能,在调用 `WandbWriter` çš„ `add_image` 方法时 Visualizer 对象是å¯é€‰çš„,如果用户指定了 Visualizer 对象,则会调用 Visualizer 对象的绘制方法,å¦åˆ™ç›´æŽ¥è°ƒç”¨ Wandb 自带的图片处ç†åŠŸèƒ½ã€‚ - -## 组åˆå†™ç«¯ ComposedWriter - -考虑到在è®ç»ƒæˆ–者测试过程ä¸ï¼Œå¯èƒ½éœ€è¦åŒæ—¶è°ƒç”¨å¤šä¸ª Writer,例如想åŒæ—¶å†™åˆ°æœ¬åœ°å’Œ Wandb 端,为æ¤è®¾è®¡äº†å¯¹å¤–çš„ `ComposedWriter` 类,在è®ç»ƒæˆ–è€…æµ‹è¯•è¿‡ç¨‹ä¸ `ComposedWriter` 会ä¾æ¬¡è°ƒç”¨å„个 Writer 的接å£ï¼Œå…¶æŽ¥å£å’Œ `BaseWriter` 一致,主è¦æŽ¥å£å¦‚下: - -- [add_params](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_params) 写超å‚到所有已ç»åŠ 入的åŽç«¯ä¸ï¼Œå¸¸è§çš„è®ç»ƒè¶…å‚如åˆå§‹å¦ä¹ 率 LRã€æƒé‡è¡°å‡ç³»æ•°å’Œæ‰¹å¤§å°ç‰ç‰ -- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_graph) 写模型图到所有已ç»åŠ 入的åŽç«¯ä¸ -- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_image) 写图片到所有已ç»åŠ 入的åŽç«¯ä¸ -- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_scalar) å†™æ ‡é‡åˆ°æ‰€æœ‰å·²ç»åŠ 入的åŽç«¯ä¸ -- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_scalars) ä¸€æ¬¡æ€§å†™å¤šä¸ªæ ‡é‡åˆ°æ‰€æœ‰å·²ç»åŠ 入的åŽç«¯ä¸ -- [get_writer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_writer) 获å–指定索引的 Writer,任何一个 Writer ä¸åŒ…括了 experiment å’Œ visualizer 属性 -- [get_experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_experiment) 获å–指定索引的 experiment -- [get_visualizer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_visualizer) 获å–指定索引的 visualizer -- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.close) 调用所有 Writer çš„ close 方法 - -为了让用户å¯ä»¥åœ¨ä»£ç çš„ä»»æ„ä½ç½®è¿›è¡Œæ•°æ®å¯è§†åŒ–,`ComposedWriter` 类继承至 [全局å¯è®¿é—®åŸºç±» BaseGlobalAccessible](./logging.md/#全局å¯è®¿é—®åŸºç±»baseglobalaccessible)。一旦继承了全局å¯è®¿é—®åŸºç±», 用户就å¯ä»¥é€šè¿‡è°ƒç”¨ `ComposedWriter` 对象的 `get_instance` æ¥èŽ·å–全局对象。其基本用法如下 - -```python -# 创建实例 -writers=[dict(type='LocalWriter', save_dir='temp_dir', visualizer=dict(type='DetVisualizer')), dict(type='WandbWriter')] - -ComposedWriter.create_instance('composed_writer', writers=writers) -``` - -一旦创建实例åŽï¼Œå¯ä»¥åœ¨ä»£ç ä»»æ„ä½ç½®èŽ·å– `ComposedWriter` 对象 - -```python -composed_writer=ComposedWriter.get_instance('composed_writer') - -# 写模型精度值 -composed_writer.add_scalar('mAP', 0.9) -composed_writer.add_scalars({'loss': 1.2, 'acc': 0.8}) -# å†™è¶…å‚ -composed_writer.add_params(dict(lr=0.1, mode='linear')) -# 写图片 -composed_writer.add_image('demo_image', image, datasample) -# 写模型图 -composed_writer.add_graph(model, input_array) -``` - -对于一些用户需è¦çš„自定义绘制需求或者上述接å£æ— 法满足的需求,用户å¯ä»¥é€šè¿‡ `get_xxx` 方法获å–具体对象æ¥å®žçŽ°ç‰¹å®šéœ€æ±‚ - -```python -composed_writer=ComposedWriter.get_instance('composed_writer') - -# 绘制特å¾å›¾ï¼ŒèŽ·å– LocalWriter ä¸çš„ visualizer -visualizer=composed_writer.get_visualizer(0) -featmap_image=visualizer.draw_featmap(tensor_chw) -composed_writer.add_image('featmap', featmap_image) - -# 扩展 add 功能,例如利用 Wandb å¯¹è±¡ç»˜åˆ¶è¡¨æ ¼ -wandb=composed_writer.get_experiment(1) -val_table = wandb.Table(data=my_data, columns=column_names) -wandb.log({'my_val_table': val_table}) - -# é…ç½®ä¸å˜åœ¨å¤šä¸ª Writer,在ä¸æƒ³æ”¹åŠ¨é…置情况下åªä½¿ç”¨ LocalWriter -local_writer=composed_writer.get_writer(0) -local_writer.add_image('demo_image', image, datasample) -``` diff --git a/mmengine/data/instance_data.py b/mmengine/data/instance_data.py index 76d5e996..2c4932f7 100644 --- a/mmengine/data/instance_data.py +++ b/mmengine/data/instance_data.py @@ -7,6 +7,9 @@ import torch from .base_data_element import BaseDataElement +IndexType = Union[str, slice, int, torch.LongTensor, torch.cuda.LongTensor, + torch.BoolTensor, torch.cuda.BoolTensor, np.long, np.bool] + # Modified from # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa @@ -87,9 +90,7 @@ class InstanceData(BaseDataElement): f'{len(self)} ' super().__setattr__(name, value) - def __getitem__( - self, item: Union[str, slice, int, torch.LongTensor, torch.BoolTensor] - ) -> 'InstanceData': + def __getitem__(self, item: IndexType) -> 'InstanceData': """ Args: item (str, obj:`slice`, @@ -102,7 +103,8 @@ class InstanceData(BaseDataElement): assert len(self) > 0, ' This is a empty instance' assert isinstance( - item, (str, slice, int, torch.LongTensor, torch.BoolTensor)) + item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor, + torch.BoolTensor, torch.cuda.BoolTensor, np.bool, np.long)) if isinstance(item, str): return getattr(self, item) @@ -118,7 +120,7 @@ class InstanceData(BaseDataElement): if isinstance(item, torch.Tensor): assert item.dim() == 1, 'Only support to get the' \ ' values along the first dimension.' - if isinstance(item, torch.BoolTensor): + if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)): assert len(item) == len(self), f'The shape of the' \ f' input(BoolTensor)) ' \ f'{len(item)} ' \ @@ -136,7 +138,8 @@ class InstanceData(BaseDataElement): elif isinstance(v, list): r_list = [] # convert to indexes from boolTensor - if isinstance(item, torch.BoolTensor): + if isinstance(item, + (torch.BoolTensor, torch.cuda.BoolTensor)): indexes = torch.nonzero(item).view(-1) else: indexes = item diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index 427e1ff9..aca48652 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -165,11 +165,7 @@ class LoggerHook(Hook): self.json_log_path = osp.join(runner.work_dir, f'{runner.timestamp}.log.json') - self.yaml_log_path = osp.join(runner.work_dir, - f'{runner.timestamp}.log.json') self.start_iter = runner.iter - if runner.meta is not None: - runner.writer.add_params(runner.meta, file_path=self.yaml_log_path) def after_train_iter(self, runner, @@ -298,7 +294,7 @@ class LoggerHook(Hook): log_str += ', '.join(log_items) runner.logger.info(log_str) # Write logs to local, tensorboad, and wandb. - runner.writer.add_scalars( + runner.visualizer.add_scalars( tag, step=runner.iter + 1, file_path=self.json_log_path) def _log_val(self, runner) -> None: @@ -330,7 +326,7 @@ class LoggerHook(Hook): log_str += ', '.join(log_items) runner.logger.info(log_str) # Write tag. - runner.writer.add_scalars( + runner.visualizer.add_scalars( tag, step=cur_iter, file_path=self.json_log_path) def _get_window_size(self, runner, window_size: Union[int, str]) \ diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py index e8bd3834..2819563a 100644 --- a/mmengine/hooks/naive_visualization_hook.py +++ b/mmengine/hooks/naive_visualization_hook.py @@ -11,6 +11,8 @@ from mmengine.registry import HOOKS from mmengine.utils.misc import tensor2imgs +# TODO: Due to interface changes, the current class +# functions incorrectly @HOOKS.register_module() class NaiveVisualizationHook(Hook): """Show or Write the predicted results during the process of testing. @@ -68,5 +70,6 @@ class NaiveVisualizationHook(Hook): data_sample.get('scale', ori_shape)) origin_image = cv2.resize(input, ori_shape) name = osp.basename(data_sample.img_path) - runner.writer.add_image(name, origin_image, data_sample, - output, self.draw_gt, self.draw_pred) + runner.visualizer.add_datasample(name, origin_image, + data_sample, output, + self.draw_gt, self.draw_pred) diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 75a2a4bc..1a6f4e67 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy from collections import OrderedDict from typing import Any, Union @@ -103,7 +102,8 @@ class MessageHub(ManagerMixin): Returns: OrderedDict: A copy of all runtime information. """ - return copy.deepcopy(self._runtime_info) + # return copy.deepcopy(self._runtime_info) + return self._runtime_info def get_log(self, key: str) -> LogBuffer: """Get ``LogBuffer`` instance by key. @@ -136,7 +136,10 @@ class MessageHub(ManagerMixin): if key not in self.runtime_info: raise KeyError(f'{key} is not found in Messagehub.log_buffers: ' f'instance name is: {MessageHub.instance_name}') - return copy.deepcopy(self._runtime_info[key]) + + # TODO: There are restrictions on objects that can be saved + # return copy.deepcopy(self._runtime_info[key]) + return self._runtime_info[key] def _get_valid_value(self, key: str, value: Union[torch.Tensor, np.ndarray, int, float])\ diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index ead8cb0a..56c65b80 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -4,12 +4,12 @@ from .registry import Registry, build_from_cfg from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, - TRANSFORMS, VISUALIZERS, WEIGHT_INITIALIZERS, WRITERS) + TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS) __all__ = [ 'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', - 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS', + 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS', 'DefaultScope' ] diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 3ee7d4d6..7e2b4845 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -6,7 +6,7 @@ from collections.abc import Callable from typing import Any, Dict, List, Optional, Tuple, Type, Union from ..config import Config, ConfigDict -from ..utils import is_seq_of +from ..utils import ManagerMixin, is_seq_of from .default_scope import DefaultScope @@ -88,7 +88,13 @@ def build_from_cfg( f'type must be a str or valid type, but got {type(obj_type)}') try: - return obj_cls(**args) # type: ignore + # If `obj_cls` inherits from `ManagerMixin`, it should be instantiated + # by `ManagerMixin.get_instance` to ensure that it can be accessed + # globally. + if issubclass(obj_cls, ManagerMixin): + return obj_cls.get_instance(**args) # type: ignore + else: + return obj_cls(**args) # type: ignore except Exception as e: # Normal TypeError does not print class name. raise type(e)(f'{obj_cls.__name__}: {e}') # type: ignore diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index 571d55cb..62d72f70 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -43,5 +43,5 @@ TASK_UTILS = Registry('task util') # manage visualizer VISUALIZERS = Registry('visualizer') -# manage writer -WRITERS = Registry('writer') +# manage visualizer backend +VISBACKENDS = Registry('vis_backend') diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index d791c52c..f1821311 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -27,6 +27,14 @@ class EpochBasedTrainLoop(BaseLoop): super().__init__(runner, dataloader) self._max_epochs = max_epochs self._max_iters = max_epochs * len(self.dataloader) + if hasattr(self.dataloader.dataset, 'metainfo'): + self.runner.visualizer.dataset_meta = \ + self.dataloader.dataset.metainfo + else: + warnings.warn( + f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' + 'metainfo. ``dataset_meta`` in visualizer will be ' + 'None.') @property def max_epochs(self): @@ -100,6 +108,14 @@ class IterBasedTrainLoop(BaseLoop): max_iters: int) -> None: super().__init__(runner, dataloader) self._max_iters = max_iters + if hasattr(self.dataloader.dataset, 'metainfo'): + self.runner.visualizer.dataset_meta = \ + self.dataloader.dataset.metainfo + else: + warnings.warn( + f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' + 'metainfo. ``dataset_meta`` in visualizer will be ' + 'None.') self.dataloader = iter(self.dataloader) @property @@ -176,11 +192,13 @@ class ValLoop(BaseLoop): self.evaluator = evaluator # type: ignore if hasattr(self.dataloader.dataset, 'metainfo'): self.evaluator.dataset_meta = self.dataloader.dataset.metainfo + self.runner.visualizer.dataset_meta = \ + self.dataloader.dataset.metainfo else: warnings.warn( f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' - 'metainfo. ``dataset_meta`` in evaluator and metric will be ' - 'None.') + 'metainfo. ``dataset_meta`` in evaluator, metric and ' + 'visualizer will be None.') self.interval = interval def run(self): @@ -240,11 +258,13 @@ class TestLoop(BaseLoop): self.evaluator = evaluator # type: ignore if hasattr(self.dataloader.dataset, 'metainfo'): self.evaluator.dataset_meta = self.dataloader.dataset.metainfo + self.runner.visualizer.dataset_meta = \ + self.dataloader.dataset.metainfo else: warnings.warn( f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' - 'metainfo. ``dataset_meta`` in evaluator and metric will be ' - 'None.') + 'metainfo. ``dataset_meta`` in evaluator, metric and ' + 'visualizer will be None.') def run(self) -> None: """Launch test.""" diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 54e1e710..86a57f36 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -30,10 +30,10 @@ from mmengine.model import is_model_wrapper from mmengine.optim import _ParamScheduler, build_optimizer from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, - DefaultScope) + VISUALIZERS, DefaultScope) from mmengine.utils import (TORCH_VERSION, digit_version, find_latest_checkpoint, is_list_of, symlink) -from mmengine.visualization import ComposedWriter +from mmengine.visualization import Visualizer from .base_loop import BaseLoop from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, get_state_dict, save_checkpoint, weights_to_cpu) @@ -129,8 +129,8 @@ class Runner: dict(dist_cfg=dict(backend='nccl')). log_level (int or str): The log level of MMLogger handlers. Defaults to 'INFO'. - writer (ComposedWriter or dict, optional): A ComposedWriter object or a - dict build ComposedWriter object. Defaults to None. If not + visualizer (Visualizer or dict, optional): A Visualizer object or a + dict build Visualizer object. Defaults to None. If not specified, default config will be used. default_scope (str, optional): Used to reset registries location. Defaults to None. @@ -184,9 +184,9 @@ class Runner: param_scheduler=dict(type='ParamSchedulerHook')), launcher='none', env_cfg=dict(dist_cfg=dict(backend='nccl')), - writer=dict( - name='composed_writer', - writers=[dict(type='LocalWriter', save_dir='temp_dir')]) + visualizer=dict(type='Visualizer', + vis_backends=[dict(type='LocalVisBackend', + save_dir='temp_dir')]) ) >>> runner = Runner.from_cfg(cfg) >>> runner.train() @@ -218,7 +218,7 @@ class Runner: launcher: str = 'none', env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')), log_level: str = 'INFO', - writer: Optional[Union[ComposedWriter, Dict]] = None, + visualizer: Optional[Union[Visualizer, Dict]] = None, default_scope: Optional[str] = None, randomness: Dict = dict(seed=None), experiment_name: Optional[str] = None, @@ -310,16 +310,17 @@ class Runner: else: self._experiment_name = self.timestamp - self.logger = self.build_logger(log_level=log_level) - # message hub used for component interaction - self.message_hub = self.build_message_hub() - # writer used for writing log or visualizing all kinds of data - self.writer = self.build_writer(writer) # Used to reset registries location. See :meth:`Registry.build` for # more details. self.default_scope = DefaultScope.get_instance( self._experiment_name, scope_name=default_scope) + self.logger = self.build_logger(log_level=log_level) + # message hub used for component interaction + self.message_hub = self.build_message_hub() + # visualizer used for writing log or visualizing all kinds of data + self.visualizer = self.build_visualizer(visualizer) + self._load_from = load_from self._resume = resume # flag to mark whether checkpoint has been loaded or resumed @@ -378,7 +379,7 @@ class Runner: launcher=cfg.get('launcher', 'none'), env_cfg=cfg.get('env_cfg'), # type: ignore log_level=cfg.get('log_level', 'INFO'), - writer=cfg.get('writer'), + visualizer=cfg.get('visualizer'), default_scope=cfg.get('default_scope'), randomness=cfg.get('randomness', dict(seed=None)), experiment_name=cfg.get('experiment_name'), @@ -623,37 +624,42 @@ class Runner: return MessageHub.get_instance(**message_hub) - def build_writer( - self, - writer: Optional[Union[ComposedWriter, - Dict]] = None) -> ComposedWriter: - """Build a global asscessable ComposedWriter. + def build_visualizer( + self, + visualizer: Optional[Union[Visualizer, + Dict]] = None) -> Visualizer: + """Build a global asscessable Visualizer. Args: - writer (ComposedWriter or dict, optional): A ComposedWriter object - or a dict to build ComposedWriter object. If ``writer`` is a - ComposedWriter object, just returns itself. If not specified, - default config will be used to build ComposedWriter object. + visualizer (Visualizer or dict, optional): A Visualizer object + or a dict to build Visualizer object. If ``visualizer`` is a + Visualizer object, just returns itself. If not specified, + default config will be used to build Visualizer object. Defaults to None. Returns: - ComposedWriter: A ComposedWriter object build from ``writer``. + Visualizer: A Visualizer object build from ``visualizer``. """ - if isinstance(writer, ComposedWriter): - return writer - elif writer is None: - writer = dict( + if visualizer is None: + visualizer = dict( name=self._experiment_name, - writers=[dict(type='LocalWriter', save_dir=self._work_dir)]) - elif isinstance(writer, dict): - # ensure writer containing name key - writer.setdefault('name', self._experiment_name) + vis_backends=[ + dict(type='LocalVisBackend', save_dir=self._work_dir) + ]) + return Visualizer.get_instance(**visualizer) + + if isinstance(visualizer, Visualizer): + return visualizer + + if isinstance(visualizer, dict): + # ensure visualizer containing name key + visualizer.setdefault('name', self._experiment_name) + visualizer.setdefault('save_dir', self._work_dir) + return VISUALIZERS.build(visualizer) else: raise TypeError( - 'writer should be ComposedWriter object, a dict or None, ' - f'but got {writer}') - - return ComposedWriter.get_instance(**writer) + 'visualizer should be Visualizer object, a dict or None, ' + f'but got {visualizer}') def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module: """Build model. diff --git a/mmengine/visualization/__init__.py b/mmengine/visualization/__init__.py index 892c3daa..6c8b0bb5 100644 --- a/mmengine/visualization/__init__.py +++ b/mmengine/visualization/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .vis_backend import (BaseVisBackend, LocalVisBackend, + TensorboardVisBackend, WandbVisBackend) from .visualizer import Visualizer -from .writer import (BaseWriter, ComposedWriter, LocalWriter, - TensorboardWriter, WandbWriter) __all__ = [ - 'Visualizer', 'BaseWriter', 'LocalWriter', 'WandbWriter', - 'TensorboardWriter', 'ComposedWriter' + 'Visualizer', 'BaseVisBackend', 'LocalVisBackend', 'WandbVisBackend', + 'TensorboardVisBackend' ] diff --git a/mmengine/visualization/utils.py b/mmengine/visualization/utils.py index 97803ce2..a0033dac 100644 --- a/mmengine/visualization/utils.py +++ b/mmengine/visualization/utils.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, List, Tuple, Type, Union +from typing import Any, List, Optional, Tuple, Type, Union + +import cv2 +import matplotlib import numpy as np import torch @@ -84,3 +87,60 @@ def check_type_and_length(name: str, value: Any, """ check_type(name, value, valid_type) check_length(name, value, valid_length) + + +def color_val_matplotlib(colors): + """Convert various input in RGB order to normalized RGB matplotlib color + tuples, + Args: + color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs + Returns: + tuple[float]: A tuple of 3 normalized floats indicating RGB channels. + """ + if isinstance(colors, str): + return colors + elif isinstance(colors, tuple): + assert len(colors) == 3 + for channel in colors: + assert 0 <= channel <= 255 + colors = [channel / 255 for channel in colors] + return tuple(colors) + elif isinstance(colors, list): + colors = [color_val_matplotlib(color) for color in colors] + return colors + else: + raise TypeError(f'Invalid type for color: {type(colors)}') + + +def str_color_to_rgb(color): + color = matplotlib.colors.to_rgb(color) + color = tuple([int(c * 255) for c in color]) + return color + + +def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor], + img: Optional[np.ndarray] = None, + alpha: float = 0.5) -> np.ndarray: + """Convert feat_map to heatmap and overlay on image, if image is not None. + + Args: + feat_map (np.ndarray, torch.Tensor): The feat_map to convert + with of shape (H, W), where H is the image height and W is + the image width. + img (np.ndarray, optional): The origin image. The format + should be RGB. Defaults to None. + alpha (float): The transparency of origin image. Defaults to 0.5. + + Returns: + np.ndarray: heatmap + """ + if isinstance(feat_map, torch.Tensor): + feat_map = feat_map.detach().cpu().numpy() + norm_img = np.zeros(feat_map.shape) + norm_img = cv2.normalize(feat_map, norm_img, 0, 255, cv2.NORM_MINMAX) + norm_img = np.asarray(norm_img, dtype=np.uint8) + heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET) + heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB) + if img is not None: + heat_img = cv2.addWeighted(img, alpha, heat_img, 1 - alpha, 0) + return heat_img diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py new file mode 100644 index 00000000..13de36d8 --- /dev/null +++ b/mmengine/visualization/vis_backend.py @@ -0,0 +1,494 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import time +from abc import ABCMeta, abstractmethod +from typing import Any, Optional, Sequence, Union + +import cv2 +import numpy as np +import torch + +from mmengine.config import Config +from mmengine.fileio import dump +from mmengine.registry import VISBACKENDS +from mmengine.utils import TORCH_VERSION + + +class BaseVisBackend(metaclass=ABCMeta): + """Base class for vis backend. + + All backends must inherit ``BaseVisBackend`` and implement + the required functions. + + Args: + save_dir (str, optional): The root directory to save + the files produced by the backend. Default to None. + """ + + def __init__(self, save_dir: Optional[str] = None): + self._save_dir = save_dir + if self._save_dir: + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + self._save_dir = osp.join(self._save_dir, + f'vis_data_{timestamp}') # type: ignore + + @property + @abstractmethod + def experiment(self) -> Any: + """Return the experiment object associated with this writer. + + The experiment attribute can get the visualizer backend, such as wandb, + tensorboard. If you want to write other data, such as writing a table, + you can directly get the visualizer backend through experiment. + """ + pass + + def add_config(self, config: Config, **kwargs) -> None: + """Record a set of parameters. + + Args: + config (Config): The Config object + """ + pass + + def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], + **kwargs) -> None: + """Record graph. + + Args: + model (torch.nn.Module): Model to draw. + data_batch (Sequence[dict]): Batch of data from dataloader. + """ + pass + + def add_image(self, + name: str, + image: np.ndarray, + step: int = 0, + **kwargs) -> None: + """Record image. + + Args: + name (str): The unique identifier for the image to save. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Default to None. + step (int): Global step value to record. Default to 0. + """ + pass + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + """Record scalar. + + Args: + name (str): The unique identifier for the scalar to save. + value (float, int): Value to save. + step (int): Global step value to record. Default to 0. + """ + pass + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + """Record scalars' data. + + Args: + scalar_dict (dict): Key-value pair storing the tag and + corresponding values. + step (int): Global step value to record. Default to 0. + file_path (str, optional): The scalar's data will be + saved to the `file_path` file at the same time + if the `file_path` parameter is specified. + Default to None. + """ + pass + + def close(self) -> None: + """close an opened object.""" + pass + + +@VISBACKENDS.register_module() +class LocalVisBackend(BaseVisBackend): + """Local vis backend class. + + It can write image, config, scalars, etc. + to the local hard disk. You can get the drawing backend + through the visualizer property for custom drawing. + + Examples: + >>> from mmengine.visualization import LocalVisBackend + >>> import numpy as np + >>> local_vis_backend = LocalVisBackend(save_dir='temp_dir') + >>> img=np.random.randint(0, 256, size=(10, 10, 3)) + >>> local_vis_backend.add_image('img', img) + >>> local_vis_backend.add_scaler('mAP', 0.6) + >>> local_vis_backend.add_scalars({'loss': [1, 2, 3], 'acc': 0.8}) + >>> local_vis_backend.add_image('img', image) + + Args: + save_dir (str, optional): The root directory to save the files + produced by the writer. If it is none, it means no data + is stored. Default None. + img_save_dir (str): The directory to save images. + Default to 'writer_image'. + config_save_file (str): The file to save parameters. + Default to 'parameters.yaml'. + scalar_save_file (str): The file to save scalar values. + Default to 'scalars.json'. + """ + + def __init__(self, + save_dir: Optional[str] = None, + img_save_dir: str = 'vis_image', + config_save_file: str = 'config.py', + scalar_save_file: str = 'scalars.json'): + assert config_save_file.split('.')[-1] == 'py' + assert scalar_save_file.split('.')[-1] == 'json' + super(LocalVisBackend, self).__init__(save_dir) + if self._save_dir is not None: + os.makedirs(self._save_dir, exist_ok=True) # type: ignore + self._img_save_dir = osp.join( + self._save_dir, # type: ignore + img_save_dir) + self._scalar_save_file = osp.join( + self._save_dir, # type: ignore + scalar_save_file) + self._config_save_file = osp.join( + self._save_dir, # type: ignore + config_save_file) + + @property + def experiment(self) -> 'LocalVisBackend': + """Return the experiment object associated with this visualizer + backend.""" + return self + + def add_config(self, config: Config, **kwargs) -> None: + # TODO + assert isinstance(config, Config) + + def add_image(self, + name: str, + image: np.ndarray = None, + step: int = 0, + **kwargs) -> None: + """Record image to disk. + + Args: + name (str): The unique identifier for the image to save. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Default to None. + step (int): Global step value to record. Default to 0. + """ + + drawn_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + os.makedirs(self._img_save_dir, exist_ok=True) + save_file_name = f'{name}_{step}.png' + cv2.imwrite(osp.join(self._img_save_dir, save_file_name), drawn_image) + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + """Add scalar data to disk. + + Args: + name (str): The unique identifier for the scalar to save. + value (float, int): Value to save. + step (int): Global step value to record. Default to 0. + """ + self._dump({name: value, 'step': step}, self._scalar_save_file, 'json') + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + """Record scalars. The scalar dict will be written to the default and + specified files if ``file_name`` is specified. + + Args: + scalar_dict (dict): Key-value pair storing the tag and + corresponding values. + step (int): Global step value to record. Default to 0. + file_path (str, optional): The scalar's data will be + saved to the ``file_path`` file at the same time + if the ``file_path`` parameter is specified. + Default to None. + """ + assert isinstance(scalar_dict, dict) + scalar_dict.setdefault('step', step) + if file_path is not None: + assert file_path.split('.')[-1] == 'json' + new_save_file_path = osp.join( + self._save_dir, # type: ignore + file_path) + assert new_save_file_path != self._scalar_save_file, \ + '"file_path" and "scalar_save_file" have the same name, ' \ + 'please set "file_path" to another value' + self._dump(scalar_dict, new_save_file_path, 'json') + self._dump(scalar_dict, self._scalar_save_file, 'json') + + def _dump(self, value_dict: dict, file_path: str, + file_format: str) -> None: + """dump dict to file. + + Args: + value_dict (dict) : Save dict data. + file_path (str): The file path to save data. + file_format (str): The file format to save data. + """ + with open(file_path, 'a+') as f: + dump(value_dict, f, file_format=file_format) + f.write('\n') + + +@VISBACKENDS.register_module() +class WandbVisBackend(BaseVisBackend): + """Write various types of data to wandb. + + Examples: + >>> from mmengine.visualization import WandbVisBackend + >>> import numpy as np + >>> wandb_vis_backend = WandbVisBackend() + >>> img=np.random.randint(0, 256, size=(10, 10, 3)) + >>> wandb_vis_backend.add_image('img', img) + >>> wandb_vis_backend.add_scaler('mAP', 0.6) + >>> wandb_vis_backend.add_scalars({'loss': [1, 2, 3],'acc': 0.8}) + >>> wandb_vis_backend.add_image('img', img) + + Args: + init_kwargs (dict, optional): wandb initialization + input parameters. Default to None. + commit: (bool, optional) Save the metrics dict to the wandb server + and increment the step. If false `wandb.log` just + updates the current metrics dict with the row argument + and metrics won't be saved until `wandb.log` is called + with `commit=True`. Default to True. + save_dir (str, optional): The root directory to save the files + produced by the writer. Default to None. + """ + + def __init__(self, + init_kwargs: Optional[dict] = None, + commit: Optional[bool] = True, + save_dir: Optional[str] = None): + super(WandbVisBackend, self).__init__(save_dir) + self._commit = commit + self._wandb = self._setup_env(init_kwargs) + + @property + def experiment(self): + """Return wandb object. + + The experiment attribute can get the wandb backend, If you want to + write other data, such as writing a table, you can directly get the + wandb backend through experiment. + """ + return self._wandb + + def _setup_env(self, init_kwargs: Optional[dict] = None) -> Any: + """Setup env. + + Args: + init_kwargs (dict): The init args. + + Return: + :obj:`wandb` + """ + try: + import wandb + except ImportError: + raise ImportError( + 'Please run "pip install wandb" to install wandb') + if init_kwargs: + wandb.init(**init_kwargs) + else: + wandb.init() + + return wandb + + def add_config(self, config: Config, **kwargs) -> None: + # TODO + pass + + def add_image(self, + name: str, + image: np.ndarray = None, + step: int = 0, + **kwargs) -> None: + """Record image to wandb. + + Args: + name (str): The unique identifier for the image to save. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Default to None. + step (int): Global step value to record. Default to 0. + """ + self._wandb.log({name: image}, commit=self._commit, step=step) + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + """Record scalar data to wandb. + + Args: + name (str): The unique identifier for the scalar to save. + value (float, int): Value to save. + step (int): Global step value to record. Default to 0. + """ + self._wandb.log({name: value}, commit=self._commit, step=step) + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + """Record scalar's data to wandb. + + Args: + scalar_dict (dict): Key-value pair storing the tag and + corresponding values. + step (int): Global step value to record. Default to 0. + file_path (str, optional): Useless parameter. Just for + interface unification. Default to None. + """ + self._wandb.log(scalar_dict, commit=self._commit, step=step) + + def close(self) -> None: + """close an opened wandb object.""" + if hasattr(self, '_wandb'): + self._wandb.join() + + +@VISBACKENDS.register_module() +class TensorboardVisBackend(BaseVisBackend): + """Tensorboard class. It can write images, config, scalars, etc. to a + tensorboard file. + + Its drawing function is provided by Visualizer. + + Examples: + >>> from mmengine.visualization import TensorboardVisBackend + >>> import numpy as np + >>> tensorboard_visualizer = TensorboardVisBackend(save_dir='temp_dir') + >>> img=np.random.randint(0, 256, size=(10, 10, 3)) + >>> tensorboard_visualizer.add_image('img', img) + >>> tensorboard_visualizer.add_scaler('mAP', 0.6) + >>> tensorboard_visualizer.add_scalars({'loss': 0.1,'acc':0.8}) + >>> tensorboard_visualizer.add_image('img', image) + + Args: + save_dir (str): The root directory to save the files + produced by the backend. + log_dir (str): Save directory location. Default to 'tf_logs'. + """ + + def __init__(self, + save_dir: Optional[str] = None, + log_dir: str = 'tf_logs'): + super(TensorboardVisBackend, self).__init__(save_dir) + if save_dir is not None: + self._tensorboard = self._setup_env(log_dir) + + def _setup_env(self, log_dir: str): + """Setup env. + + Args: + log_dir (str): Save directory location. + + Return: + :obj:`SummaryWriter` + """ + if TORCH_VERSION == 'parrots': + try: + from tensorboardX import SummaryWriter + except ImportError: + raise ImportError('Please install tensorboardX to use ' + 'TensorboardLoggerHook.') + else: + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError: + raise ImportError( + 'Please run "pip install future tensorboard" to install ' + 'the dependencies to use torch.utils.tensorboard ' + '(applicable to PyTorch 1.1 or higher)') + if self._save_dir is None: + return SummaryWriter(f'./{log_dir}') + else: + self.log_dir = osp.join(self._save_dir, log_dir) # type: ignore + return SummaryWriter(self.log_dir) + + @property + def experiment(self): + """Return Tensorboard object.""" + return self._tensorboard + + def add_config(self, config: Config, **kwargs) -> None: + # TODO + pass + + def add_image(self, + name: str, + image: np.ndarray, + step: int = 0, + **kwargs) -> None: + """Record image to tensorboard. + + Args: + name (str): The unique identifier for the image to save. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Default to None. + step (int): Global step value to record. Default to 0. + """ + self._tensorboard.add_image(name, image, step, dataformats='HWC') + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + """Record scalar data to summary. + + Args: + name (str): The unique identifier for the scalar to save. + value (float, int): Value to save. + step (int): Global step value to record. Default to 0. + """ + self._tensorboard.add_scalar(name, value, step) + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + """Record scalar's data to summary. + + Args: + scalar_dict (dict): Key-value pair storing the tag and + corresponding values. + step (int): Global step value to record. Default to 0. + file_path (str, optional): Useless parameter. Just for + interface unification. Default to None. + """ + assert isinstance(scalar_dict, dict) + assert 'step' not in scalar_dict, 'Please set it directly ' \ + 'through the step parameter' + for key, value in scalar_dict.items(): + self.add_scalar(key, value, step) + + def close(self): + """close an opened tensorboard object.""" + if hasattr(self, '_tensorboard'): + self._tensorboard.close() diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index ae6ff113..d8025616 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -1,25 +1,32 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings -from typing import Callable, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import cv2 import matplotlib.pyplot as plt import numpy as np import torch +from matplotlib.backend_bases import CloseEvent from matplotlib.backends.backend_agg import FigureCanvasAgg from matplotlib.collections import (LineCollection, PatchCollection, PolyCollection) from matplotlib.figure import Figure from matplotlib.patches import Circle +from mmengine.config import Config from mmengine.data import BaseDataElement -from mmengine.registry import VISUALIZERS -from .utils import (check_type, check_type_and_length, tensor2ndarray, - value2list) +from mmengine.registry import VISBACKENDS, VISUALIZERS +from mmengine.utils import ManagerMixin +from mmengine.visualization.utils import (check_type, check_type_and_length, + color_val_matplotlib, + convert_overlay_heatmap, + str_color_to_rgb, tensor2ndarray, + value2list) +from mmengine.visualization.vis_backend import BaseVisBackend @VISUALIZERS.register_module() -class Visualizer: +class Visualizer(ManagerMixin): """MMEngine provides a Visualizer class that uses the ``Matplotlib`` library as the backend. It has the following functions: @@ -67,15 +74,15 @@ class Visualizer: >>> # Basic drawing methods >>> vis = Visualizer(metadata=metadata, image=image) - >>> vis.draw_bboxes(np.array([0, 0, 1, 1]), edgecolors='g') + >>> vis.draw_bboxes(np.array([0, 0, 1, 1]), edge_colors='g') >>> vis.draw_bboxes(bbox=np.array([[1, 1, 2, 2], [2, 2, 3, 3]]), - edgecolors=['g', 'r'], is_filling=True) + edge_colors=['g', 'r'], is_filling=True) >>> vis.draw_lines(x_datas=np.array([1, 3]), y_datas=np.array([1, 3]), - colors='r', linewidths=1) + colors='r', line_widths=1) >>> vis.draw_lines(x_datas=np.array([[1, 3], [2, 4]]), y_datas=np.array([[1, 3], [2, 4]]), - colors=['r', 'r'], linewidths=[1, 2]) + colors=['r', 'r'], line_widths=[1, 2]) >>> vis.draw_texts(text='MMEngine', position=np.array([2, 2]), colors='b') @@ -87,10 +94,10 @@ class Visualizer: radius=np.array[1, 2], colors=['g', 'r'], is_filling=True) >>> vis.draw_polygons(np.array([0, 0, 1, 0, 1, 1, 0, 1]), - edgecolors='g') + edge_colors='g') >>> vis.draw_polygons(bbox=[np.array([0, 0, 1, 0, 1, 1, 0, 1], np.array([2, 2, 3, 2, 3, 3, 2, 3]]), - edgecolors=['g', 'r'], is_filling=True) + edge_colors=['g', 'r'], is_filling=True) >>> vis.draw_binary_masks(binary_mask, alpha=0.6) >>> # chain calls @@ -106,80 +113,99 @@ class Visualizer: >>> # inherit >>> class DetVisualizer2(Visualizer): - >>> @Visualizer.register_task('instances') - >>> def draw_instance(self, - >>> instances: 'BaseDataInstance', - >>> data_type: Type): - >>> pass - >>> def draw(self, + >>> def add_datasample(self, >>> image: Optional[np.ndarray] = None, >>> gt_sample: Optional['BaseDataElement'] = None, >>> pred_sample: Optional['BaseDataElement'] = None, >>> show_gt: bool = True, - >>> show_pred: bool = True) -> None: + >>> show_pred: bool = True, + >>> show:bool = True) -> None: >>> pass """ - task_dict: dict = {} - - def __init__(self, - image: Optional[np.ndarray] = None, - metadata: Optional[dict] = None) -> None: - self._metadata = metadata + def __init__( + self, + name='visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + fig_save_cfg=dict(frameon=False), + fig_show_cfg=dict(frameon=False, num='show') + ) -> None: + super().__init__(name) + self._dataset_meta: Union[None, dict] = None + self._vis_backends: Union[Dict, Dict[str, 'BaseVisBackend']] = dict() + + if vis_backends: + with_name = False + without_name = False + for vis_backend in vis_backends: + if 'name' in vis_backend: + with_name = True + else: + without_name = True + if with_name and without_name: + raise AssertionError + + for vis_backend in vis_backends: + name = vis_backend.pop('name', vis_backend['type']) + assert name not in self._vis_backends + vis_backend.setdefault('save_dir', save_dir) + self._vis_backends[name] = VISBACKENDS.build(vis_backend) + + self.is_inline = 'inline' in plt.get_backend() + + self.fig_save = None + self.fig_show = None + self.fig_save_num = fig_save_cfg.get('num', None) + self.fig_show_num = fig_show_cfg.get('num', None) + self.fig_save_cfg = fig_save_cfg + self.fig_show_cfg = fig_show_cfg + + (self.fig_save, self.ax_save, + self.fig_save_num) = self._initialize_fig(fig_save_cfg) + self.dpi = self.fig_save.get_dpi() if image is not None: - self._setup_fig(image) - - def draw(self, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True) -> None: - pass + self.set_image(image) - def show(self, wait_time: int = 0) -> None: + @property + def dataset_meta(self) -> Optional[dict]: + return self._dataset_meta + + @dataset_meta.setter + def dataset_meta(self, dataset_meta: dict) -> None: + self._dataset_meta = dataset_meta + + def show(self, + drawn_img: Optional[np.ndarray] = None, + win_name: str = 'image', + wait_time: int = 0, + continue_key=' ') -> None: """Show the drawn image. Args: wait_time (int, optional): Delay in milliseconds. 0 is the special value that means "forever". Defaults to 0. """ - if wait_time == 0: - plt.show() - else: - plt.show(block=False) - plt.pause(wait_time) - - def close(self) -> None: - """Close the figure.""" - plt.close(self.fig) - - @classmethod - def register_task(cls, task_name: str, force: bool = False) -> Callable: - """Register a function. - - A record will be added to ``task_dict``, whose key is the task_name - and value is the decorated function. - - Args: - cls (type): Module class to be registered. - task_name (str or list of str, optional): The module name to be - registered. - force (bool): Whether to override an existing function with the - same name. Defaults to False. - """ - - def _register(task_func): - - if (task_name not in cls.task_dict) or force: - cls.task_dict[task_name] = task_func - else: - raise KeyError( - f'"{task_name}" is already registered in task_dict, ' - 'add "force=True" if you want to override it') - return task_func - - return _register + if self.is_inline: + return + if self.fig_show is None or not plt.fignum_exists(self.fig_show_num): + (self.fig_show, self.ax_show, + self.fig_show_num) = self._initialize_fig(self.fig_show_cfg) + img = self.get_image() if drawn_img is None else drawn_img + # dpi = self.fig_show.get_dpi() + # height, width = img.shape[:2] + # self.fig_show.set_size_inches((width + 1e-2) / dpi, + # (height + 1e-2) / dpi) + self.ax_show.cla() + self.ax_show.axis(False) + # self.ax_show.set_title(win_name) + # self.fig_show.set_label(win_name) + + # Refresh canvas, necessary for Qt5 backend. + self.ax_show.imshow(img) + self.fig_show.canvas.draw() # type: ignore + self._wait_continue(timeout=wait_time, continue_key=continue_key) def set_image(self, image: np.ndarray) -> None: """Set the image to draw. @@ -188,7 +214,23 @@ class Visualizer: image (np.ndarray): The image to draw. """ assert image is not None - self._setup_fig(image) + image = image.astype('uint8') + self._image = image + self.width, self.height = image.shape[1], image.shape[0] + self._default_font_size = max( + np.sqrt(self.height * self.width) // 90, 10) + + # add a small 1e-2 to avoid precision lost due to matplotlib's + # truncation (https://github.com/matplotlib/matplotlib/issues/15363) + self.fig_save.set_size_inches( # type: ignore + (self.width + 1e-2) / self.dpi, (self.height + 1e-2) / self.dpi) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + self.ax_save.cla() + self.ax_save.axis(False) + self.ax_save.imshow( + image, + extent=(0, self.width, self.height, 0), + interpolation='none') def get_image(self) -> np.ndarray: """Get the drawn image. The format is RGB. @@ -197,43 +239,24 @@ class Visualizer: np.ndarray: the drawn image which channel is rgb. """ assert self._image is not None, 'Please set image using `set_image`' - canvas = self.canvas + canvas = self.fig_save.canvas # type: ignore s, (width, height) = canvas.print_to_buffer() buffer = np.frombuffer(s, dtype='uint8') img_rgba = buffer.reshape(height, width, 4) rgb, alpha = np.split(img_rgba, [3], axis=2) return rgb.astype('uint8') - def _setup_fig(self, image: np.ndarray) -> None: - """Set the image to draw. + def _initialize_fig(self, fig_cfg): + fig = plt.figure(**fig_cfg) + ax = fig.add_subplot() + ax.axis(False) - Args: - image (np.ndarray): The image to draw.The format - should be RGB. - """ - image = image.astype('uint8') - self._image = image - self.width, self.height = image.shape[1], image.shape[0] - self._default_font_size = max( - np.sqrt(self.height * self.width) // 90, 10) - fig = plt.figure(frameon=False) + # remove white edges by set subplot margin + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + return fig, ax, fig.number - self.dpi = fig.get_dpi() - # add a small 1e-2 to avoid precision lost due to matplotlib's - # truncation (https://github.com/matplotlib/matplotlib/issues/15363) - fig.set_size_inches((self.width + 1e-2) / self.dpi, - (self.height + 1e-2) / self.dpi) - self.canvas = fig.canvas - # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) - plt.subplots_adjust(left=0, right=1, bottom=0, top=1) - plt.axis('off') - ax = plt.gca() - self.fig = fig - self.ax = ax - self.ax.imshow( - image, - extent=(0, self.width, self.height, 0), - interpolation='none') + def get_backend(self, name) -> 'BaseVisBackend': + return self._vis_backends.get(name) # type: ignore def _is_posion_valid(self, position: np.ndarray) -> bool: """Judge whether the position is in image. @@ -251,14 +274,86 @@ class Visualizer: (position[..., 1] >= 0).all() return flag + def _wait_continue(self, timeout: int = 0, continue_key=' ') -> int: + """Show the image and wait for the user's input. + + This implementation refers to + https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py + + Args: + timeout (int): If positive, continue after ``timeout`` seconds. + Defaults to 0. + continue_key (str): The key for users to continue. Defaults to + the space key. + + Returns: + int: If zero, means time out or the user pressed ``continue_key``, + and if one, means the user closed the show figure. + """ # noqa: E501 + if self.is_inline: + # If use inline backend, interactive input and timeout is no use. + return 0 + + if self.fig_show.canvas.manager: # type: ignore + # Ensure that the figure is shown + self.fig_show.show() # type: ignore + + while True: + + # Connect the events to the handler function call. + event = None + + def handler(ev): + # Set external event variable + nonlocal event + # Qt backend may fire two events at the same time, + # use a condition to avoid missing close event. + event = ev if not isinstance(event, CloseEvent) else event + self.fig_show.canvas.stop_event_loop() + + cids = [ + self.fig_show.canvas.mpl_connect(name, handler) # type: ignore + for name in ('key_press_event', 'close_event') + ] + + try: + self.fig_show.canvas.start_event_loop(timeout) # type: ignore + finally: # Run even on exception like ctrl-c. + # Disconnect the callbacks. + for cid in cids: + self.fig_show.canvas.mpl_disconnect(cid) # type: ignore + + if isinstance(event, CloseEvent): + return 1 # Quit for close. + elif event is None or event.key == continue_key: + return 0 # Quit for continue. + + def draw_points(self, + positions: Union[np.ndarray, torch.Tensor], + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + marker: Optional[str] = None, + sizes: Optional[Union[np.ndarray, torch.Tensor]] = None): + check_type('positions', positions, (np.ndarray, torch.Tensor)) + positions = tensor2ndarray(positions) + + if len(positions.shape) == 1: + positions = positions[None] + assert positions.shape[-1] == 2, ( + 'The shape of `positions` should be (N, 2), ' + f'but got {positions.shape}') + colors = color_val_matplotlib(colors) + self.ax_save.scatter( + positions[:, 0], positions[:, 1], c=colors, s=sizes, marker=marker) + return self + def draw_texts( self, texts: Union[str, List[str]], positions: Union[np.ndarray, torch.Tensor], font_sizes: Optional[Union[int, List[int]]] = None, - colors: Union[str, List[str]] = 'g', - verticalalignments: Union[str, List[str]] = 'top', - horizontalalignments: Union[str, List[str]] = 'left', + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + vertical_alignments: Union[str, List[str]] = 'top', + horizontal_alignments: Union[str, List[str]] = 'left', font_families: Union[str, List[str]] = 'sans-serif', rotations: Union[int, str, List[Union[int, str]]] = 0, bboxes: Optional[Union[dict, List[dict]]] = None) -> 'Visualizer': @@ -273,29 +368,29 @@ class Visualizer: texts. ``font_sizes`` can have the same length with texts or just single value. If ``font_sizes`` is single value, all the texts will have the same font size. Defaults to None. - colors (Union[str, List[str]]): The colors of texts. ``colors`` - can have the same length with texts or just single value. - If ``colors`` is single value, all the texts will have the same - colors. Reference to + colors (Union[str, tuple, List[str], List[tuple]]): The colors + of texts. ``colors`` can have the same length with texts or + just single value. If ``colors`` is single value, all the + texts will have the same colors. Reference to https://matplotlib.org/stable/gallery/color/named_colors.html for more details. Defaults to 'g. - verticalalignments (Union[str, List[str]]): The verticalalignment + vertical_alignments (Union[str, List[str]]): The verticalalignment of texts. verticalalignment controls whether the y positional argument for the text indicates the bottom, center or top side of the text bounding box. - ``verticalalignments`` can have the same length with - texts or just single value. If ``verticalalignments`` is single - value, all the texts will have the same verticalalignment. - verticalalignment can be 'center' or 'top', 'bottom' or - 'baseline'. Defaults to 'top'. - horizontalalignments (Union[str, List[str]]): The + ``vertical_alignments`` can have the same length with + texts or just single value. If ``vertical_alignments`` is + single value, all the texts will have the same + verticalalignment. verticalalignment can be 'center' or + 'top', 'bottom' or 'baseline'. Defaults to 'top'. + horizontal_alignments (Union[str, List[str]]): The horizontalalignment of texts. Horizontalalignment controls whether the x positional argument for the text indicates the left, center or right side of the text bounding box. - ``horizontalalignments`` can have + ``horizontal_alignments`` can have the same length with texts or just single value. - If ``horizontalalignments`` is single value, all the texts will - have the same horizontalalignment. Horizontalalignment + If ``horizontal_alignments`` is single value, all the texts + will have the same horizontalalignment. Horizontalalignment can be 'center','right' or 'left'. Defaults to 'left'. font_families (Union[str, List[str]]): The font family of texts. ``font_families`` can have the same length with texts or @@ -335,19 +430,22 @@ class Visualizer: if font_sizes is None: font_sizes = self._default_font_size - check_type_and_length('font_sizes', font_sizes, (int, list), num_text) - font_sizes = value2list(font_sizes, int, num_text) + check_type_and_length('font_sizes', font_sizes, (int, float, list), + num_text) + font_sizes = value2list(font_sizes, (int, float), num_text) - check_type_and_length('colors', colors, (str, list), num_text) - colors = value2list(colors, str, num_text) + check_type_and_length('colors', colors, (str, tuple, list), num_text) + colors = value2list(colors, (str, tuple), num_text) + colors = color_val_matplotlib(colors) - check_type_and_length('verticalalignments', verticalalignments, + check_type_and_length('vertical_alignments', vertical_alignments, (str, list), num_text) - verticalalignments = value2list(verticalalignments, str, num_text) + vertical_alignments = value2list(vertical_alignments, str, num_text) - check_type_and_length('horizontalalignments', horizontalalignments, + check_type_and_length('horizontal_alignments', horizontal_alignments, (str, list), num_text) - horizontalalignments = value2list(horizontalalignments, str, num_text) + horizontal_alignments = value2list(horizontal_alignments, str, + num_text) check_type_and_length('rotations', rotations, (int, list), num_text) rotations = value2list(rotations, int, num_text) @@ -363,14 +461,14 @@ class Visualizer: bboxes = value2list(bboxes, dict, num_text) for i in range(num_text): - self.ax.text( + self.ax_save.text( positions[i][0], positions[i][1], texts[i], size=font_sizes[i], # type: ignore bbox=bboxes[i], # type: ignore - verticalalignment=verticalalignments[i], - horizontalalignment=horizontalalignments[i], + verticalalignment=vertical_alignments[i], + horizontalalignment=horizontal_alignments[i], family=font_families[i], color=colors[i]) return self @@ -379,9 +477,9 @@ class Visualizer: self, x_datas: Union[np.ndarray, torch.Tensor], y_datas: Union[np.ndarray, torch.Tensor], - colors: Union[str, List[str]] = 'g', - linestyles: Union[str, List[str]] = '-', - linewidths: Union[Union[int, float], List[Union[int, float]]] = 1 + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + line_styles: Union[str, List[str]] = '-', + line_widths: Union[Union[int, float], List[Union[int, float]]] = 2 ) -> 'Visualizer': """Draw single or multiple line segments. @@ -390,24 +488,24 @@ class Visualizer: each line' start and end points. y_datas (Union[np.ndarray, torch.Tensor]): The y coordinate of each line' start and end points. - colors (Union[str, List[str]]): The colors of lines. ``colors`` - can have the same length with lines or just single value. - If ``colors`` is single value, all the lines will have the same - colors. Reference to + colors (Union[str, tuple, List[str], List[tuple]]): The colors of + lines. ``colors`` can have the same length with lines or just + single value. If ``colors`` is single value, all the lines + will have the same colors. Reference to https://matplotlib.org/stable/gallery/color/named_colors.html for more details. Defaults to 'g'. - linestyles (Union[str, List[str]]): The linestyle - of lines. ``linestyles`` can have the same length with - texts or just single value. If ``linestyles`` is single + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single value, all the lines will have the same linestyle. Reference to https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle for more details. Defaults to '-'. - linewidths (Union[Union[int, float], List[Union[int, float]]]): The - linewidth of lines. ``linewidths`` can have + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have the same length with lines or just single value. - If ``linewidths`` is single value, all the lines will - have the same linewidth. Defaults to 1. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. """ check_type('x_datas', x_datas, (np.ndarray, torch.Tensor)) x_datas = tensor2ndarray(x_datas) @@ -421,31 +519,31 @@ class Visualizer: if len(x_datas.shape) == 1: x_datas = x_datas[None] y_datas = y_datas[None] - + colors = color_val_matplotlib(colors) lines = np.concatenate( (x_datas.reshape(-1, 2, 1), y_datas.reshape(-1, 2, 1)), axis=-1) if not self._is_posion_valid(lines): - warnings.warn( 'Warning: The line is out of bounds,' ' the drawn line may not be in the image', UserWarning) line_collect = LineCollection( lines.tolist(), colors=colors, - linestyles=linestyles, - linewidths=linewidths) - self.ax.add_collection(line_collect) + linestyles=line_styles, + linewidths=line_widths) + self.ax_save.add_collection(line_collect) return self - def draw_circles(self, - center: Union[np.ndarray, torch.Tensor], - radius: Union[np.ndarray, torch.Tensor], - alpha: Union[float, int] = 0.8, - edgecolors: Union[str, List[str]] = 'g', - linestyles: Union[str, List[str]] = '-', - linewidths: Union[Union[int, float], - List[Union[int, float]]] = 1, - is_filling: bool = False) -> 'Visualizer': + def draw_circles( + self, + center: Union[np.ndarray, torch.Tensor], + radius: Union[np.ndarray, torch.Tensor], + alpha: Union[float, int] = 0.8, + edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', + line_styles: Union[str, List[str]] = '-', + line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, + face_colors: Union[str, tuple, List[str], List[tuple]] = 'none' + ) -> 'Visualizer': """Draw single or multiple circles. Args: @@ -453,24 +551,24 @@ class Visualizer: each line' start and end points. radius (Union[np.ndarray, torch.Tensor]): The y coordinate of each line' start and end points. - edgecolors (Union[str, List[str]]): The colors of circles. - ``colors`` can have the same length with lines or just single - value. If ``colors`` is single value, all the lines will have - the same colors. Reference to + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of circles. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, + all the lines will have the same colors. Reference to https://matplotlib.org/stable/gallery/color/named_colors.html for more details. Defaults to 'g. - linestyles (Union[str, List[str]]): The linestyle - of lines. ``linestyles`` can have the same length with - texts or just single value. If ``linestyles`` is single + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single value, all the lines will have the same linestyle. Reference to https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle for more details. Defaults to '-'. - linewidths (Union[Union[int, float], List[Union[int, float]]]): The - linewidth of lines. ``linewidths`` can have + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have the same length with lines or just single value. - If ``linewidths`` is single value, all the lines will - have the same linewidth. Defaults to 1. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. is_filling (bool): Whether to fill all the circles. Defaults to False. """ @@ -493,57 +591,59 @@ class Visualizer: center = center.tolist() radius = radius.tolist() + edge_colors = color_val_matplotlib(edge_colors) + face_colors = color_val_matplotlib(face_colors) circles = [] for i in range(len(center)): circles.append(Circle(tuple(center[i]), radius[i])) - if is_filling: - p = PatchCollection(circles, alpha=alpha, facecolor=edgecolors) - else: - if isinstance(linewidths, (int, float)): - linewidths = [linewidths] * len(circles) - linewidths = [ - min(max(linewidth, 1), self._default_font_size / 4) - for linewidth in linewidths - ] - p = PatchCollection( - circles, - alpha=alpha, - facecolor='none', - edgecolor=edgecolors, - linewidth=linewidths, - linestyles=linestyles) - self.ax.add_collection(p) + + if isinstance(line_widths, (int, float)): + line_widths = [line_widths] * len(circles) + line_widths = [ + min(max(linewidth, 1), self._default_font_size / 4) + for linewidth in line_widths + ] + p = PatchCollection( + circles, + alpha=alpha, + facecolors=face_colors, + edgecolors=edge_colors, + linewidths=line_widths, + linestyles=line_styles) + self.ax_save.add_collection(p) return self - def draw_bboxes(self, - bboxes: Union[np.ndarray, torch.Tensor], - alpha: Union[int, float] = 0.8, - edgecolors: Union[str, List[str]] = 'g', - linestyles: Union[str, List[str]] = '-', - linewidths: Union[Union[int, float], - List[Union[int, float]]] = 1, - is_filling: bool = False) -> 'Visualizer': + def draw_bboxes( + self, + bboxes: Union[np.ndarray, torch.Tensor], + alpha: Union[int, float] = 0.8, + edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', + line_styles: Union[str, List[str]] = '-', + line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, + face_colors: Union[str, tuple, List[str], List[tuple]] = 'none' + ) -> 'Visualizer': """Draw single or multiple bboxes. Args: bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw with the format of(x1,y1,x2,y2). - edgecolors (Union[str, List[str]]): The colors of bboxes. - ``colors`` can have the same length with lines or just single - value. If ``colors`` is single value, all the lines will have - the same colors. Refer to `matplotlib.colors` for full list of - formats that are accepted.. Defaults to 'g'. - linestyles (Union[str, List[str]]): The linestyle - of lines. ``linestyles`` can have the same length with - texts or just single value. If ``linestyles`` is single + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of bboxes. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, all + the lines will have the same colors. Refer to `matplotlib. + colors` for full list of formats that are accepted. + Defaults to 'g'. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single value, all the lines will have the same linestyle. Reference to https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle for more details. Defaults to '-'. - linewidths (Union[Union[int, float], List[Union[int, float]]]): The - linewidth of lines. ``linewidths`` can have + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have the same length with lines or just single value. - If ``linewidths`` is single value, all the lines will + If ``line_widths`` is single value, all the lines will have the same linewidth. Defaults to 1. is_filling (bool): Whether to fill all the bboxes. Defaults to False. @@ -570,47 +670,51 @@ class Visualizer: return self.draw_polygons( poly, alpha=alpha, - edgecolors=edgecolors, - linestyles=linestyles, - linewidths=linewidths, - is_filling=is_filling) - - def draw_polygons(self, - polygons: Union[Union[np.ndarray, torch.Tensor], - List[Union[np.ndarray, torch.Tensor]]], - alpha: Union[int, float] = 0.8, - edgecolors: Union[str, List[str]] = 'g', - linestyles: Union[str, List[str]] = '-', - linewidths: Union[Union[int, float], - List[Union[int, float]]] = 1.0, - is_filling: bool = False) -> 'Visualizer': + edge_colors=edge_colors, + line_styles=line_styles, + line_widths=line_widths, + face_colors=face_colors) + + def draw_polygons( + self, + polygons: Union[Union[np.ndarray, torch.Tensor], + List[Union[np.ndarray, torch.Tensor]]], + alpha: Union[int, float] = 0.8, + edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', + line_styles: Union[str, List[str]] = '-', + line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, + face_colors: Union[str, tuple, List[str], List[tuple]] = 'none' + ) -> 'Visualizer': """Draw single or multiple bboxes. Args: polygons (Union[Union[np.ndarray, torch.Tensor], List[Union[np.ndarray, torch.Tensor]]]): The polygons to draw with the format of (x1,y1,x2,y2,...,xn,yn). - edgecolors (Union[str, List[str]]): The colors of polygons. - ``colors`` can have the same length with lines or just single - value. If ``colors`` is single value, all the lines will have - the same colors. Refer to `matplotlib.colors` for full list of - formats that are accepted.. Defaults to 'g. - linestyles (Union[str, List[str]]): The linestyle - of lines. ``linestyles`` can have the same length with - texts or just single value. If ``linestyles`` is single + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of polygons. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, + all the lines will have the same colors. Refer to + `matplotlib.colors` for full list of formats that are accepted. + Defaults to 'g. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single value, all the lines will have the same linestyle. Reference to https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle for more details. Defaults to '-'. - linewidths (Union[Union[int, float], List[Union[int, float]]]): The - linewidth of lines. ``linewidths`` can have + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have the same length with lines or just single value. - If ``linewidths`` is single value, all the lines will - have the same linewidth. Defaults to 1. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. is_filling (bool): Whether to fill all the polygons. Defaults to False. """ check_type('polygons', polygons, (list, np.ndarray, torch.Tensor)) + edge_colors = color_val_matplotlib(edge_colors) + face_colors = color_val_matplotlib(face_colors) if isinstance(polygons, (np.ndarray, torch.Tensor)): polygons = [polygons] @@ -625,32 +729,29 @@ class Visualizer: warnings.warn( 'Warning: The polygon is out of bounds,' ' the drawn polygon may not be in the image', UserWarning) - if is_filling: - polygon_collection = PolyCollection( - polygons, alpha=alpha, facecolor=edgecolors) - else: - if isinstance(linewidths, (int, float)): - linewidths = [linewidths] * len(polygons) - linewidths = [ - min(max(linewidth, 1), self._default_font_size / 4) - for linewidth in linewidths - ] - polygon_collection = PolyCollection( - polygons, - alpha=alpha, - facecolor='none', - linestyles=linestyles, - edgecolors=edgecolors, - linewidths=linewidths) - - self.ax.add_collection(polygon_collection) + if isinstance(line_widths, (int, float)): + line_widths = [line_widths] * len(polygons) + line_widths = [ + min(max(linewidth, 1), self._default_font_size / 4) + for linewidth in line_widths + ] + polygon_collection = PolyCollection( + polygons, + alpha=alpha, + facecolor=face_colors, + linestyles=line_styles, + edgecolors=edge_colors, + linewidths=line_widths) + + self.ax_save.add_collection(polygon_collection) return self def draw_binary_masks( - self, - binary_masks: Union[np.ndarray, torch.Tensor], - colors: np.ndarray = np.array([0, 255, 0]), - alphas: Union[float, List[float]] = 0.5) -> 'Visualizer': + self, + binary_masks: Union[np.ndarray, torch.Tensor], + alphas: Union[float, List[float]] = 0.8, + colors: Union[str, tuple, List[str], + List[tuple]] = 'g') -> 'Visualizer': """Draw single or multiple binary masks. Args: @@ -677,12 +778,24 @@ class Visualizer: binary_masks = binary_masks[None] assert img.shape[:2] == binary_masks.shape[ 1:], '`binary_marks` must have the same shpe with image' - assert isinstance(colors, np.ndarray) - if colors.ndim == 1: - colors = np.tile(colors, (binary_masks.shape[0], 1)) - assert colors.shape == (binary_masks.shape[0], 3) + binary_mask_len = binary_masks.shape[0] + + check_type_and_length('colors', colors, (str, tuple, list), + binary_mask_len) + colors = value2list(colors, (str, tuple), binary_mask_len) + colors = [ + str_color_to_rgb(color) if isinstance(color, str) else color + for color in colors + ] + for color in colors: + assert len(color) == 3 + for channel in color: + assert 0 <= channel <= 255 # type: ignore + colors = np.array(colors) + if colors.ndim == 1: # type: ignore + colors = np.tile(colors, (binary_mask_len, 1)) if isinstance(alphas, float): - alphas = [alphas] * binary_masks.shape[0] + alphas = [alphas] * binary_mask_len for binary_mask, color, alpha in zip(binary_masks, colors, alphas): binary_mask_complement = cv2.bitwise_not(binary_mask) @@ -692,8 +805,8 @@ class Visualizer: img_complement = cv2.bitwise_and( img, img, mask=binary_mask_complement) rgb = rgb + img_complement - img = cv2.addWeighted(img, alpha, rgb, 1 - alpha, 0) - self.ax.imshow( + img = cv2.addWeighted(img, 1 - alpha, rgb, alpha, 0) + self.ax_save.imshow( img, extent=(0, self.width, self.height, 0), interpolation='nearest') @@ -705,7 +818,7 @@ class Visualizer: mode: str = 'mean', topk: int = 10, arrangement: Tuple[int, int] = (5, 2), - alpha: float = 0.3) -> np.ndarray: + alpha: float = 0.8) -> np.ndarray: """Draw featmap. If img is not None, the final image will be the weighted sum of img and featmap. It support the mode: @@ -738,37 +851,6 @@ class Visualizer: Returns: np.ndarray: featmap. """ - - def concat_heatmap(feat_map: Union[np.ndarray, torch.Tensor], - img: Optional[np.ndarray] = None, - alpha: float = 0.5) -> np.ndarray: - """Convert feat_map to heatmap and sum to image, if image is not - None. - - Args: - feat_map (np.ndarray, torch.Tensor): The feat_map to convert - with of shape (H, W), where H is the image height and W is - the image width. - img (np.ndarray, optional): The origin image. The format - should be RGB. Defaults to None. - alphas (Union[int, List[int]]): The transparency of origin - image. Defaults to 0.5. - - Returns: - np.ndarray: heatmap - """ - if isinstance(feat_map, torch.Tensor): - feat_map = feat_map.detach().cpu().numpy() - norm_img = np.zeros(feat_map.shape) - norm_img = cv2.normalize(feat_map, norm_img, 0, 255, - cv2.NORM_MINMAX) - norm_img = np.asarray(norm_img, dtype=np.uint8) - heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET) - heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB) - if img is not None: - heat_img = cv2.addWeighted(img, alpha, heat_img, 1 - alpha, 0) - return heat_img - assert isinstance( tensor_chw, torch.Tensor), (f'`tensor_chw` should be {torch.Tensor} ' @@ -785,11 +867,9 @@ class Visualizer: ], (f'Mode only support "mean", "max", "min", but got {mode}') if mode == 'max': feat_map, _ = torch.max(tensor_chw, dim=0) - elif mode == 'min': - feat_map, _ = torch.min(tensor_chw, dim=0) elif mode == 'mean': feat_map = torch.mean(tensor_chw, dim=0) - return concat_heatmap(feat_map, image, alpha) + return convert_overlay_heatmap(feat_map, image, alpha) if topk <= 0: tensor_chw_channel = tensor_chw.shape[0] @@ -801,16 +881,15 @@ class Visualizer: ' mode parameter or set topk greater than 0 to solve ' 'the error') if tensor_chw_channel == 1: - return concat_heatmap(tensor_chw[0], image, alpha) + return convert_overlay_heatmap(tensor_chw[0], image, alpha) else: tensor_chw = tensor_chw.permute(1, 2, 0).numpy() - norm_img = np.zeros(tensor_chw.shape) norm_img = cv2.normalize(tensor_chw, None, 0, 255, cv2.NORM_MINMAX) heat_img = np.asarray(norm_img, dtype=np.uint8) if image is not None: - heat_img = cv2.addWeighted(image, alpha, heat_img, - 1 - alpha, 0) + heat_img = cv2.addWeighted(image, 1 - alpha, heat_img, + alpha, 0) return heat_img else: row, col = arrangement @@ -833,9 +912,133 @@ class Visualizer: for i in range(topk): axes = fig.add_subplot(row, col, i + 1) axes.axis('off') - axes.imshow(concat_heatmap(topk_tensor[i], image, alpha)) + axes.imshow( + convert_overlay_heatmap(topk_tensor[i], image, alpha)) s, (width, height) = canvas.print_to_buffer() buffer = np.frombuffer(s, dtype='uint8') img_rgba = buffer.reshape(height, width, 4) rgb, alpha = np.split(img_rgba, [3], axis=2) return rgb.astype('uint8') + + def add_config(self, config: Config, **kwargs): + """Record parameters. + + Args: + config (Config): The Config object. + """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_config(config, **kwargs) # type: ignore + + def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], + **kwargs) -> None: + """Record graph data. + + Args: + model (torch.nn.Module): Model to draw. + data_batch (Sequence[dict]): Batch of data from dataloader. + """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_graph(model, data_batch, **kwargs) # type: ignore + + def add_image(self, name: str, image: np.ndarray, step: int = 0) -> None: + """Record image. + + Args: + name (str): The unique identifier for the image to save. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Default to None. + step (int): Global step value to record. Default to 0. + """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_image(name, image, step) # type: ignore + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + """Record scalar data. + + Args: + name (str): The unique identifier for the scalar to save. + value (float, int): Value to save. + step (int): Global step value to record. Default to 0. + """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_scalar(name, value, step, **kwargs) # type: ignore + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + """Record scalars' data. + + Args: + scalar_dict (dict): Key-value pair storing the tag and + corresponding values. + step (int): Global step value to record. Default to 0. + file_path (str, optional): The scalar's data will be + saved to the `file_path` file at the same time + if the `file_path` parameter is specified. + Default to None. + """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_scalars( # type: ignore + scalar_dict, step, file_path, **kwargs) + + def add_datasample(self, + name, + image: np.ndarray, + gt_sample: Optional['BaseDataElement'] = None, + pred_sample: Optional['BaseDataElement'] = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: int = 0, + step: int = 0) -> None: + pass + + def close(self) -> None: + """close an opened object.""" + plt.close(self.fig_save) + if self.fig_show is not None: + plt.close(self.fig_show) + for vis_backend in self._vis_backends.values(): + vis_backend.close() # type: ignore + + @classmethod + def get_instance(cls, name: str, **kwargs) -> 'Visualizer': + """Make subclass can get latest created instance by + ``Visualizer.get_current_instance()``. + + Downstream codebase may need to get the latest created instance + without knowing the specific Visualizer type. For example, mmdetection + builds visualizer in runner and some component which cannot access + runner wants to get latest created visualizer. In this case, + the component does not know which type of visualizer has been built + and cannot get target instance. Therefore, :class:`Visualizer` + overrides the :meth:`get_instance` and its subclass will register + the created instance to :attr:`_instance_dict` additionally. + :meth:`get_current_instance` will return the latest created subclass + instance. + + Examples: + >>> class DetLocalVisualizer(Visualizer): + >>> def __init__(self, name): + >>> super().__init__(name) + >>> + >>> visualizer1 = DetLocalVisualizer.get_instance('name1') + >>> visualizer2 = Visualizer.get_current_instance() + >>> visualizer3 = DetLocalVisualizer.get_current_instance() + >>> assert id(visualizer1) == id(visualizer2) == id(visualizer3) + + Args: + name (str): Name of instance. Defaults to ''. + + Returns: + object: Corresponding name instance. + """ + instance = super().get_instance(name, **kwargs) + Visualizer._instance_dict[name] = instance + return instance diff --git a/mmengine/visualization/writer.py b/mmengine/visualization/writer.py deleted file mode 100644 index 72217ac8..00000000 --- a/mmengine/visualization/writer.py +++ /dev/null @@ -1,823 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import os.path as osp -import time -from abc import ABCMeta, abstractmethod -from typing import Any, List, Optional, Union - -import cv2 -import numpy as np -import torch - -from mmengine.data import BaseDataElement -from mmengine.dist import master_only -from mmengine.fileio import dump -from mmengine.registry import VISUALIZERS, WRITERS -from mmengine.utils import TORCH_VERSION, ManagerMixin -from .visualizer import Visualizer - - -class BaseWriter(metaclass=ABCMeta): - """Base class for writer. - - Each writer can inherit ``BaseWriter`` and implement - the required functions. - - Args: - visualizer (dict, :obj:`Visualizer`, optional): - Visualizer instance or dictionary. Default to None. - save_dir (str, optional): The root directory to save - the files produced by the writer. Default to None. - """ - - def __init__(self, - visualizer: Optional[Union[dict, 'Visualizer']] = None, - save_dir: Optional[str] = None): - self._save_dir = save_dir - if self._save_dir: - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) - self._save_dir = osp.join( - self._save_dir, f'write_data_{timestamp}') # type: ignore - self._visualizer = visualizer - if visualizer: - if isinstance(visualizer, dict): - self._visualizer = VISUALIZERS.build(visualizer) - else: - assert isinstance(visualizer, Visualizer), \ - 'visualizer should be an instance of Visualizer, ' \ - f'but got {type(visualizer)}' - - @property - def visualizer(self) -> 'Visualizer': - """Return the visualizer object. - - You can get the drawing backend through the visualizer property for - custom drawing. - """ - return self._visualizer # type: ignore - - @property - @abstractmethod - def experiment(self) -> Any: - """Return the experiment object associated with this writer. - - The experiment attribute can get the write backend, such as wandb, - tensorboard. If you want to write other data, such as writing a table, - you can directly get the write backend through experiment. - """ - pass - - def add_params(self, params_dict: dict, **kwargs) -> None: - """Record a set of parameters. - - Args: - params_dict (dict): Each key-value pair in the dictionary is the - name of the parameters and it's corresponding value. - """ - pass - - def add_graph(self, model: torch.nn.Module, - input_tensor: Union[torch.Tensor, - List[torch.Tensor]], **kwargs) -> None: - """Record graph. - - Args: - model (torch.nn.Module): Model to draw. - input_tensor (torch.Tensor, list[torch.Tensor]): A variable - or a tuple of variables to be fed. - """ - pass - - def add_image(self, - name: str, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default: True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - pass - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record scalar. - - Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. - """ - pass - - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record scalars' data. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. - file_path (str, optional): The scalar's data will be - saved to the `file_path` file at the same time - if the `file_path` parameter is specified. - Default to None. - """ - pass - - def close(self) -> None: - """close an opened object.""" - pass - - -@WRITERS.register_module() -class LocalWriter(BaseWriter): - """Local write class. - - It can write image, hyperparameters, scalars, etc. - to the local hard disk. You can get the drawing backend - through the visualizer property for custom drawing. - - Examples: - >>> from mmengine.visualization import LocalWriter - >>> import numpy as np - >>> local_writer = LocalWriter(dict(type='DetVisualizer'),\ - save_dir='temp_dir') - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> local_writer.add_image('img', img) - >>> local_writer.add_scaler('mAP', 0.6) - >>> local_writer.add_scalars({'loss': [1, 2, 3], 'acc': 0.8}) - >>> local_writer.add_params(dict(lr=0.1, mode='linear')) - - >>> local_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \ - edgecolors='g') - >>> local_writer.add_image('img', \ - local_writer.visualizer.get_image()) - - Args: - save_dir (str): The root directory to save the files - produced by the writer. - visualizer (dict, :obj:`Visualizer`, optional): Visualizer - instance or dictionary. Default to None - img_save_dir (str): The directory to save images. - Default to 'writer_image'. - params_save_file (str): The file to save parameters. - Default to 'parameters.yaml'. - scalar_save_file (str): The file to save scalar values. - Default to 'scalars.json'. - img_show (bool): Whether to show the image when calling add_image. - Default to False. - """ - - def __init__(self, - save_dir: str, - visualizer: Optional[Union[dict, 'Visualizer']] = None, - img_save_dir: str = 'writer_image', - params_save_file: str = 'parameters.yaml', - scalar_save_file: str = 'scalars.json', - img_show: bool = False): - assert params_save_file.split('.')[-1] == 'yaml' - assert scalar_save_file.split('.')[-1] == 'json' - super(LocalWriter, self).__init__(visualizer, save_dir) - os.makedirs(self._save_dir, exist_ok=True) # type: ignore - self._img_save_dir = osp.join( - self._save_dir, # type: ignore - img_save_dir) - self._scalar_save_file = osp.join( - self._save_dir, # type: ignore - scalar_save_file) - self._params_save_file = osp.join( - self._save_dir, # type: ignore - params_save_file) - self._img_show = img_show - - @property - def experiment(self) -> 'LocalWriter': - """Return the experiment object associated with this writer.""" - return self - - def add_params(self, params_dict: dict, **kwargs) -> None: - """Record parameters to disk. - - Args: - params_dict (dict): The dict of parameters to save. - """ - assert isinstance(params_dict, dict) - self._dump(params_dict, self._params_save_file, 'yaml') - - def add_image(self, - name: str, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image to disk. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default to True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - assert self.visualizer, 'Please instantiate the visualizer ' \ - 'object with initialization parameters.' - self.visualizer.draw(image, gt_sample, pred_sample, draw_gt, draw_pred) - if self._img_show: - self.visualizer.show() - else: - drawn_image = cv2.cvtColor(self.visualizer.get_image(), - cv2.COLOR_RGB2BGR) - os.makedirs(self._img_save_dir, exist_ok=True) - save_file_name = f'{name}_{step}.png' - cv2.imwrite( - osp.join(self._img_save_dir, save_file_name), drawn_image) - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Add scalar data to disk. - - Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. - """ - self._dump({name: value, 'step': step}, self._scalar_save_file, 'json') - - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record scalars. The scalar dict will be written to the default and - specified files if ``file_name`` is specified. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. - file_path (str, optional): The scalar's data will be - saved to the ``file_path`` file at the same time - if the ``file_path`` parameter is specified. - Default to None. - """ - assert isinstance(scalar_dict, dict) - scalar_dict.setdefault('step', step) - if file_path is not None: - assert file_path.split('.')[-1] == 'json' - new_save_file_path = osp.join( - self._save_dir, # type: ignore - file_path) - assert new_save_file_path != self._scalar_save_file, \ - '"file_path" and "scalar_save_file" have the same name, ' \ - 'please set "file_path" to another value' - self._dump(scalar_dict, new_save_file_path, 'json') - self._dump(scalar_dict, self._scalar_save_file, 'json') - - def _dump(self, value_dict: dict, file_path: str, - file_format: str) -> None: - """dump dict to file. - - Args: - value_dict (dict) : Save dict data. - file_path (str): The file path to save data. - file_format (str): The file format to save data. - """ - with open(file_path, 'a+') as f: - dump(value_dict, f, file_format=file_format) - f.write('\n') - - -@WRITERS.register_module() -class WandbWriter(BaseWriter): - """Write various types of data to wandb. - - Examples: - >>> from mmengine.visualization import WandbWriter - >>> import numpy as np - >>> wandb_writer = WandbWriter(dict(type='DetVisualizer')) - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> wandb_writer.add_image('img', img) - >>> wandb_writer.add_scaler('mAP', 0.6) - >>> wandb_writer.add_scalars({'loss': [1, 2, 3],'acc': 0.8}) - >>> wandb_writer.add_params(dict(lr=0.1, mode='linear')) - - >>> wandb_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \ - edgecolors='g') - >>> wandb_writer.add_image('img', \ - wandb_writer.visualizer.get_image()) - - >>> wandb_writer = WandbWriter() - >>> assert wandb_writer.visualizer is None - >>> wandb_writer.add_image('img', img) - - Args: - init_kwargs (dict, optional): wandb initialization - input parameters. Default to None. - commit: (bool, optional) Save the metrics dict to the wandb server - and increment the step. If false `wandb.log` just - updates the current metrics dict with the row argument - and metrics won't be saved until `wandb.log` is called - with `commit=True`. Default to True. - visualizer (dict, :obj:`Visualizer`, optional): - Visualizer instance or dictionary. Default to None. - save_dir (str, optional): The root directory to save the files - produced by the writer. Default to None. - """ - - def __init__(self, - init_kwargs: Optional[dict] = None, - commit: Optional[bool] = True, - visualizer: Optional[Union[dict, 'Visualizer']] = None, - save_dir: Optional[str] = None): - super(WandbWriter, self).__init__(visualizer, save_dir) - self._commit = commit - self._wandb = self._setup_env(init_kwargs) - - @property - def experiment(self): - """Return wandb object. - - The experiment attribute can get the wandb backend, If you want to - write other data, such as writing a table, you can directly get the - wandb backend through experiment. - """ - return self._wandb - - def _setup_env(self, init_kwargs: Optional[dict] = None) -> Any: - """Setup env. - - Args: - init_kwargs (dict): The init args. - - Return: - :obj:`wandb` - """ - try: - import wandb - except ImportError: - raise ImportError( - 'Please run "pip install wandb" to install wandb') - if init_kwargs: - wandb.init(**init_kwargs) - else: - wandb.init() - - return wandb - - def add_params(self, params_dict: dict, **kwargs) -> None: - """Record a set of parameters to be compared in wandb. - - Args: - params_dict (dict): Each key-value pair in the dictionary - is the name of the parameters and it's - corresponding value. - """ - assert isinstance(params_dict, dict) - self._wandb.log(params_dict, commit=self._commit) - - def add_image(self, - name: str, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image to wandb. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default: True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - if self.visualizer: - self.visualizer.draw(image, gt_sample, pred_sample, draw_gt, - draw_pred) - self._wandb.log({name: self.visualizer.get_image()}, - commit=self._commit, - step=step) - else: - self.add_image_to_wandb(name, image, gt_sample, pred_sample, - draw_gt, draw_pred, step, **kwargs) - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record scalar data to wandb. - - Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. - """ - self._wandb.log({name: value}, commit=self._commit, step=step) - - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record scalar's data to wandb. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Default to None. - """ - self._wandb.log(scalar_dict, commit=self._commit, step=step) - - def add_image_to_wandb(self, - name: str, - image: np.ndarray, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image to wandb. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray): The image to be saved. The format - should be BGR. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default to True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - raise NotImplementedError() - - def close(self) -> None: - """close an opened wandb object.""" - if hasattr(self, '_wandb'): - self._wandb.join() - - -@WRITERS.register_module() -class TensorboardWriter(BaseWriter): - """Tensorboard write class. It can write images, hyperparameters, scalars, - etc. to a tensorboard file. - - Its drawing function is provided by Visualizer. - - Examples: - >>> from mmengine.visualization import TensorboardWriter - >>> import numpy as np - >>> tensorboard_writer = TensorboardWriter(dict(type='DetVisualizer'),\ - save_dir='temp_dir') - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> tensorboard_writer.add_image('img', img) - >>> tensorboard_writer.add_scaler('mAP', 0.6) - >>> tensorboard_writer.add_scalars({'loss': 0.1,'acc':0.8}) - >>> tensorboard_writer.add_params(dict(lr=0.1, mode='linear')) - - >>> tensorboard_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \ - edgecolors='g') - >>> tensorboard_writer.add_image('img', \ - tensorboard_writer.visualizer.get_image()) - - Args: - save_dir (str): The root directory to save the files - produced by the writer. - visualizer (dict, :obj:`Visualizer`, optional): Visualizer instance - or dictionary. Default to None. - log_dir (str): Save directory location. Default to 'tf_writer'. - """ - - def __init__(self, - save_dir: str, - visualizer: Optional[Union[dict, 'Visualizer']] = None, - log_dir: str = 'tf_logs'): - super(TensorboardWriter, self).__init__(visualizer, save_dir) - self._tensorboard = self._setup_env(log_dir) - - def _setup_env(self, log_dir: str): - """Setup env. - - Args: - log_dir (str): Save directory location. Default 'tf_writer'. - - Return: - :obj:`SummaryWriter` - """ - if TORCH_VERSION == 'parrots': - try: - from tensorboardX import SummaryWriter - except ImportError: - raise ImportError('Please install tensorboardX to use ' - 'TensorboardLoggerHook.') - else: - try: - from torch.utils.tensorboard import SummaryWriter - except ImportError: - raise ImportError( - 'Please run "pip install future tensorboard" to install ' - 'the dependencies to use torch.utils.tensorboard ' - '(applicable to PyTorch 1.1 or higher)') - - self.log_dir = osp.join(self._save_dir, log_dir) # type: ignore - return SummaryWriter(self.log_dir) - - @property - def experiment(self): - """Return Tensorboard object.""" - return self._tensorboard - - def add_graph(self, model: torch.nn.Module, - input_tensor: Union[torch.Tensor, - List[torch.Tensor]], **kwargs) -> None: - """Record graph data to tensorboard. - - Args: - model (torch.nn.Module): Model to draw. - input_tensor (torch.Tensor, list[torch.Tensor]): A variable - or a tuple of variables to be fed. - """ - if isinstance(input_tensor, list): - for array in input_tensor: - assert array.ndim == 4 - assert isinstance(array, torch.Tensor) - else: - assert isinstance(input_tensor, - torch.Tensor) and input_tensor.ndim == 4 - self._tensorboard.add_graph(model, input_tensor) - - def add_params(self, params_dict: dict, **kwargs) -> None: - """Record a set of hyperparameters to be compared in TensorBoard. - - Args: - params_dict (dict): Each key-value pair in the dictionary is the - name of the hyper parameter and it's corresponding value. - The type of the value can be one of `bool`, `string`, - `float`, `int`, or `None`. - """ - assert isinstance(params_dict, dict) - self._tensorboard.add_hparams(params_dict, {}) - - def add_image(self, - name: str, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image to tensorboard. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default to True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - assert self.visualizer, 'Please instantiate the visualizer ' \ - 'object with initialization parameters.' - self.visualizer.draw(image, gt_sample, pred_sample, draw_gt, draw_pred) - self._tensorboard.add_image( - name, self.visualizer.get_image(), step, dataformats='HWC') - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record scalar data to summary. - - Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. - """ - self._tensorboard.add_scalar(name, value, step) - - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record scalar's data to summary. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Default to None. - """ - assert isinstance(scalar_dict, dict) - assert 'step' not in scalar_dict, 'Please set it directly ' \ - 'through the step parameter' - for key, value in scalar_dict.items(): - self.add_scalar(key, value, step) - - def close(self): - """close an opened tensorboard object.""" - if hasattr(self, '_tensorboard'): - self._tensorboard.close() - - -class ComposedWriter(ManagerMixin): - """Wrapper class to compose multiple a subclass of :class:`BaseWriter` - instances. By inheriting ManagerMixin, it can be accessed anywhere once - instantiated. - - Examples: - >>> from mmengine.visualization import ComposedWriter - >>> import numpy as np - >>> composed_writer= ComposedWriter.get_instance( \ - 'composed_writer', writers=[dict(type='LocalWriter', \ - visualizer=dict(type='DetVisualizer'), \ - save_dir='temp_dir'), dict(type='WandbWriter')]) - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> composed_writer.add_image('img', img) - >>> composed_writer.add_scalar('mAP', 0.6) - >>> composed_writer.add_scalars({'loss': 0.1,'acc':0.8}) - >>> composed_writer.add_params(dict(lr=0.1, mode='linear')) - - Args: - name (str): The name of the instance. Defaults: 'composed_writer'. - writers (list, optional): The writers to compose. Default to None - """ - - def __init__(self, - name: str = 'composed_writer', - writers: Optional[List[Union[dict, 'BaseWriter']]] = None): - super().__init__(name) - self._writers = [] - if writers is not None: - assert isinstance(writers, list) - for writer in writers: - if isinstance(writer, dict): - self._writers.append(WRITERS.build(writer)) - else: - assert isinstance(writer, BaseWriter), \ - f'writer should be an instance of a subclass of ' \ - f'BaseWriter, but got {type(writer)}' - self._writers.append(writer) - - def __len__(self): - return len(self._writers) - - def get_writer(self, index: int) -> 'BaseWriter': - """Returns the writer object corresponding to the specified index.""" - return self._writers[index] - - def get_experiment(self, index: int) -> Any: - """Returns the writer's experiment object corresponding to the - specified index.""" - return self._writers[index].experiment - - def get_visualizer(self, index: int) -> 'Visualizer': - """Returns the writer's visualizer object corresponding to the - specified index.""" - return self._writers[index].visualizer - - def add_params(self, params_dict: dict, **kwargs): - """Record parameters. - - Args: - params_dict (dict): The dictionary of parameters to save. - """ - for writer in self._writers: - writer.add_params(params_dict, **kwargs) - - def add_graph(self, model: torch.nn.Module, - input_array: Union[torch.Tensor, - List[torch.Tensor]], **kwargs) -> None: - """Record graph data. - - Args: - model (torch.nn.Module): Model to draw. - input_array (torch.Tensor, list[torch.Tensor]): A variable - or a tuple of variables to be fed. - """ - for writer in self._writers: - writer.add_graph(model, input_array, **kwargs) - - def add_image(self, - name: str, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default to True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - for writer in self._writers: - writer.add_image(name, image, gt_sample, pred_sample, draw_gt, - draw_pred, step, **kwargs) - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record scalar data. - - Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. - """ - for writer in self._writers: - writer.add_scalar(name, value, step, **kwargs) - - @master_only - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record scalars' data. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. - file_path (str, optional): The scalar's data will be - saved to the `file_path` file at the same time - if the `file_path` parameter is specified. - Default to None. - """ - for writer in self._writers: - writer.add_scalars(scalar_dict, step, file_path, **kwargs) - - def close(self) -> None: - """close an opened object.""" - for writer in self._writers: - writer.close() diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index 70e93d1c..1f9a3b76 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -61,7 +61,6 @@ class TestLoggerHook: assert logger_hook.json_log_path == osp.join('work_dir', 'timestamp.log.json') assert logger_hook.start_iter == runner.iter - runner.writer.add_params.assert_called() def test_after_run(self, tmp_path): out_dir = tmp_path / 'out_dir' @@ -151,7 +150,7 @@ class TestLoggerHook: logger_hook._collect_info = MagicMock(return_value=train_infos) logger_hook._log_train(runner) # Verify that the correct variables have been written. - runner.writer.add_scalars.assert_called_with( + runner.visualizer.add_scalars.assert_called_with( train_infos, step=11, file_path='tmp.json') # Verify that the correct context have been logged. out, _ = capsys.readouterr() @@ -209,7 +208,7 @@ class TestLoggerHook: logger_hook._log_val(runner) # Verify that the correct context have been logged. out, _ = capsys.readouterr() - runner.writer.add_scalars.assert_called_with( + runner.visualizer.add_scalars.assert_called_with( metric, step=11, file_path='tmp.json') if by_epoch: assert out == 'Epoch(val) [1][5] accuracy: 0.9000, ' \ diff --git a/tests/test_hook/test_naive_visualization_hook.py b/tests/test_hook/test_naive_visualization_hook.py index 0bbe47df..e06dd281 100644 --- a/tests/test_hook/test_naive_visualization_hook.py +++ b/tests/test_hook/test_naive_visualization_hook.py @@ -12,7 +12,7 @@ class TestNaiveVisualizationHook: def test_after_train_iter(self): naive_visualization_hook = NaiveVisualizationHook() runner = Mock(iter=1) - runner.writer.add_image = Mock() + runner.visualizer.add_image = Mock() inputs = torch.randn(1, 3, 15, 15) batch_idx = 10 # test with normalize, resize, pad diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 76f1d7ce..cc7d4156 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -5,6 +5,7 @@ import pytest from mmengine.config import Config, ConfigDict # type: ignore from mmengine.registry import DefaultScope, Registry, build_from_cfg +from mmengine.utils import ManagerMixin class TestRegistry: @@ -482,3 +483,17 @@ def test_build_from_cfg(cfg_type): "<class 'str'>")): cfg = cfg_type(dict(type='ResNet', depth=50)) model = build_from_cfg(cfg, 'BACKBONES') + + VISUALIZER = Registry('visualizer') + + @VISUALIZER.register_module() + class Visualizer(ManagerMixin): + + def __init__(self, name): + super().__init__(name) + + with pytest.raises(RuntimeError): + Visualizer.get_current_instance() + cfg = dict(type='Visualizer', name='visualizer') + build_from_cfg(cfg, VISUALIZER) + Visualizer.get_current_instance() diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index b066838a..3e694be0 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -25,7 +25,7 @@ from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop, Runner, TestLoop, ValLoop) from mmengine.runner.priority import Priority, get_priority from mmengine.utils import is_list_of -from mmengine.visualization.writer import ComposedWriter +from mmengine.visualization import Visualizer @MODELS.register_module() @@ -308,24 +308,24 @@ class TestRunner(TestCase): self.assertFalse(runner.distributed) self.assertFalse(runner.deterministic) - # 1.5 message_hub, logger and writer + # 1.5 message_hub, logger and visualizer # they are all not specified cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_init12' runner = Runner(**cfg) self.assertIsInstance(runner.logger, MMLogger) self.assertIsInstance(runner.message_hub, MessageHub) - self.assertIsInstance(runner.writer, ComposedWriter) + self.assertIsInstance(runner.visualizer, Visualizer) # they are all specified cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_init13' cfg.log_level = 'INFO' - cfg.writer = dict(name='test_writer') + cfg.visualizer = None runner = Runner(**cfg) self.assertIsInstance(runner.logger, MMLogger) self.assertIsInstance(runner.message_hub, MessageHub) - self.assertIsInstance(runner.writer, ComposedWriter) + self.assertIsInstance(runner.visualizer, Visualizer) assert runner.distributed is False assert runner.seed is not None @@ -446,32 +446,34 @@ class TestRunner(TestCase): with self.assertRaisesRegex(TypeError, 'message_hub should be'): runner.build_message_hub('invalid-type') - def test_build_writer(self): - self.epoch_based_cfg.experiment_name = 'test_build_writer1' + def test_build_visualizer(self): + self.epoch_based_cfg.experiment_name = 'test_build_visualizer1' runner = Runner.from_cfg(self.epoch_based_cfg) - self.assertIsInstance(runner.writer, ComposedWriter) - self.assertEqual(runner.experiment_name, runner.writer.instance_name) + self.assertIsInstance(runner.visualizer, Visualizer) + self.assertEqual(runner.experiment_name, + runner.visualizer.instance_name) - # input is a ComposedWriter object + # input is a Visualizer object self.assertEqual( - id(runner.build_writer(runner.writer)), id(runner.writer)) + id(runner.build_visualizer(runner.visualizer)), + id(runner.visualizer)) # input is a dict - writer_cfg = dict(name='test_build_writer2') - writer = runner.build_writer(writer_cfg) - self.assertIsInstance(writer, ComposedWriter) - self.assertEqual(writer.instance_name, 'test_build_writer2') + visualizer_cfg = dict(type='Visualizer', name='test_build_visualizer2') + visualizer = runner.build_visualizer(visualizer_cfg) + self.assertIsInstance(visualizer, Visualizer) + self.assertEqual(visualizer.instance_name, 'test_build_visualizer2') # input is a dict but does not contain name key - runner._experiment_name = 'test_build_writer3' - writer_cfg = dict() - writer = runner.build_writer(writer_cfg) - self.assertIsInstance(writer, ComposedWriter) - self.assertEqual(writer.instance_name, 'test_build_writer3') + runner._experiment_name = 'test_build_visualizer3' + visualizer_cfg = None + visualizer = runner.build_visualizer(visualizer_cfg) + self.assertIsInstance(visualizer, Visualizer) + self.assertEqual(visualizer.instance_name, 'test_build_visualizer3') # input is not a valid type - with self.assertRaisesRegex(TypeError, 'writer should be'): - runner.build_writer('invalid-type') + with self.assertRaisesRegex(TypeError, 'visualizer should be'): + runner.build_visualizer('invalid-type') def test_default_scope(self): TOY_SCHEDULERS = Registry( diff --git a/tests/test_visualizer/test_vis_backend.py b/tests/test_visualizer/test_vis_backend.py new file mode 100644 index 00000000..da662a65 --- /dev/null +++ b/tests/test_visualizer/test_vis_backend.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import shutil +import sys +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from mmengine.fileio import load +from mmengine.registry import VISBACKENDS +from mmengine.visualization import (LocalVisBackend, TensorboardVisBackend, + WandbVisBackend) + + +class TestLocalVisBackend: + + def test_init(self): + + # 'config_save_file' format must be py + with pytest.raises(AssertionError): + LocalVisBackend('temp_dir', config_save_file='a.txt') + + # 'scalar_save_file' format must be json + with pytest.raises(AssertionError): + LocalVisBackend('temp_dir', scalar_save_file='a.yaml') + + local_vis_backend = LocalVisBackend('temp_dir') + assert os.path.exists(local_vis_backend._save_dir) + shutil.rmtree('temp_dir') + + local_vis_backend = VISBACKENDS.build( + dict(type='LocalVisBackend', save_dir='temp_dir')) + assert os.path.exists(local_vis_backend._save_dir) + shutil.rmtree('temp_dir') + + def test_experiment(self): + local_vis_backend = LocalVisBackend('temp_dir') + assert local_vis_backend.experiment == local_vis_backend + shutil.rmtree('temp_dir') + + def test_add_config(self): + local_vis_backend = LocalVisBackend('temp_dir') + + # 'params_dict' must be dict + with pytest.raises(AssertionError): + local_vis_backend.add_config(['lr', 0]) + + # TODO + + shutil.rmtree('temp_dir') + + def test_add_image(self): + image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) + local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend.add_image('img', image) + assert os.path.exists( + os.path.join(local_vis_backend._img_save_dir, 'img_0.png')) + + local_vis_backend.add_image('img', image, step=2) + assert os.path.exists( + os.path.join(local_vis_backend._img_save_dir, 'img_2.png')) + + shutil.rmtree('temp_dir') + + def test_add_scalar(self): + local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend.add_scalar('map', 0.9) + out_dict = load(local_vis_backend._scalar_save_file, 'json') + assert out_dict == {'map': 0.9, 'step': 0} + shutil.rmtree('temp_dir') + + # test append mode + local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend.add_scalar('map', 0.9, step=0) + local_vis_backend.add_scalar('map', 0.95, step=1) + with open(local_vis_backend._scalar_save_file) as f: + out_dict = f.read() + assert out_dict == '{"map": 0.9, "step": 0}\n{"map": ' \ + '0.95, "step": 1}\n' + shutil.rmtree('temp_dir') + + def test_add_scalars(self): + local_vis_backend = LocalVisBackend('temp_dir') + input_dict = {'map': 0.7, 'acc': 0.9} + local_vis_backend.add_scalars(input_dict) + out_dict = load(local_vis_backend._scalar_save_file, 'json') + assert out_dict == {'map': 0.7, 'acc': 0.9, 'step': 0} + + # test append mode + local_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) + with open(local_vis_backend._scalar_save_file) as f: + out_dict = f.read() + assert out_dict == '{"map": 0.7, "acc": 0.9, ' \ + '"step": 0}\n{"map": 0.8, "acc": 0.8, "step": 1}\n' + + # test file_path + local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend.add_scalars(input_dict, file_path='temp.json') + assert os.path.exists(local_vis_backend._scalar_save_file) + assert os.path.exists( + os.path.join(local_vis_backend._save_dir, 'temp.json')) + + # file_path and scalar_save_file cannot be the same + with pytest.raises(AssertionError): + local_vis_backend.add_scalars(input_dict, file_path='scalars.json') + + shutil.rmtree('temp_dir') + + +class TestTensorboardVisBackend: + sys.modules['torch.utils.tensorboard'] = MagicMock() + sys.modules['tensorboardX'] = MagicMock() + + def test_init(self): + + TensorboardVisBackend('temp_dir') + VISBACKENDS.build( + dict(type='TensorboardVisBackend', save_dir='temp_dir')) + + def test_experiment(self): + tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + assert (tensorboard_vis_backend.experiment == + tensorboard_vis_backend._tensorboard) + + def test_add_graph(self): + # TODO + pass + + def test_add_config(self): + # TODO + pass + + def test_add_image(self): + image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) + + tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + tensorboard_vis_backend.add_image('img', image) + + tensorboard_vis_backend.add_image('img', image, step=2) + + def test_add_scalar(self): + tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + tensorboard_vis_backend.add_scalar('map', 0.9) + # test append mode + tensorboard_vis_backend.add_scalar('map', 0.9, step=0) + tensorboard_vis_backend.add_scalar('map', 0.95, step=1) + + def test_add_scalars(self): + tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + # The step value must be passed through the parameter + with pytest.raises(AssertionError): + tensorboard_vis_backend.add_scalars({ + 'map': 0.7, + 'acc': 0.9, + 'step': 1 + }) + + input_dict = {'map': 0.7, 'acc': 0.9} + tensorboard_vis_backend.add_scalars(input_dict) + # test append mode + tensorboard_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) + + +class TestWandbVisBackend: + sys.modules['wandb'] = MagicMock() + + def test_init(self): + WandbVisBackend() + VISBACKENDS.build(dict(type='WandbVisBackend', save_dir='temp_dir')) + + def test_experiment(self): + wandb_vis_backend = WandbVisBackend() + assert wandb_vis_backend.experiment == wandb_vis_backend._wandb + + def test_add_config(self): + # TODO + pass + + def test_add_image(self): + image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) + + wandb_vis_backend = WandbVisBackend() + wandb_vis_backend.add_image('img', image) + + wandb_vis_backend.add_image('img', image, step=2) + + def test_add_scalar(self): + wandb_vis_backend = WandbVisBackend() + wandb_vis_backend.add_scalar('map', 0.9) + # test append mode + wandb_vis_backend.add_scalar('map', 0.9, step=0) + wandb_vis_backend.add_scalar('map', 0.95, step=1) + + def test_add_scalars(self): + wandb_vis_backend = WandbVisBackend() + input_dict = {'map': 0.7, 'acc': 0.9} + wandb_vis_backend.add_scalars(input_dict) + # test append mode + wandb_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py index 5a7da41b..ce3de94d 100644 --- a/tests/test_visualizer/test_visualizer.py +++ b/tests/test_visualizer/test_visualizer.py @@ -1,16 +1,64 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional +import copy +from typing import Any, List, Optional, Union from unittest import TestCase import matplotlib.pyplot as plt import numpy as np import pytest import torch +import torch.nn as nn -from mmengine.data import BaseDataElement +from mmengine import VISBACKENDS from mmengine.visualization import Visualizer +@VISBACKENDS.register_module() +class MockVisBackend: + + def __init__(self, save_dir: Optional[str] = None): + self._save_dir = save_dir + self._close = False + + @property + def experiment(self) -> Any: + return self + + def add_config(self, params_dict: dict, **kwargs) -> None: + self._add_config = True + + def add_graph(self, model: torch.nn.Module, + input_tensor: Union[torch.Tensor, + List[torch.Tensor]], **kwargs) -> None: + + self._add_graph = True + + def add_image(self, + name: str, + image: np.ndarray, + step: int = 0, + **kwargs) -> None: + self._add_image = True + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + self._add_scalar = True + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + self._add_scalars = True + + def close(self) -> None: + """close an opened object.""" + self._close = True + + class TestVisualizer(TestCase): def setUp(self): @@ -21,11 +69,27 @@ class TestVisualizer(TestCase): """ self.image = np.random.randint( 0, 256, size=(10, 10, 3)).astype('uint8') + self.vis_backend_cfg = [ + dict(type='MockVisBackend', name='mock1', save_dir='tmp'), + dict(type='MockVisBackend', name='mock2', save_dir='tmp') + ] def test_init(self): visualizer = Visualizer(image=self.image) visualizer.get_image() + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) + assert isinstance(visualizer.get_backend('mock1'), MockVisBackend) + assert len(visualizer._vis_backends) == 2 + + # test global + visualizer = Visualizer.get_instance( + 'visualizer', vis_backends=copy.deepcopy(self.vis_backend_cfg)) + assert len(visualizer._vis_backends) == 2 + visualizer_any = Visualizer.get_instance('visualizer') + assert visualizer_any == visualizer + def test_set_image(self): visualizer = Visualizer() visualizer.set_image(self.image) @@ -45,7 +109,7 @@ class TestVisualizer(TestCase): visualizer.draw_bboxes(torch.tensor([1, 1, 1, 2])) bboxes = torch.tensor([[1, 1, 2, 2], [1, 2, 2, 2.5]]) visualizer.draw_bboxes( - bboxes, alpha=0.5, edgecolors='b', linestyles='-') + bboxes, alpha=0.5, edge_colors=(255, 0, 0), line_styles='-') bboxes = bboxes.numpy() visualizer.draw_bboxes(bboxes) @@ -66,19 +130,26 @@ class TestVisualizer(TestCase): visualizer.draw_bboxes([1, 1, 2, 2]) def test_close(self): - visualizer = Visualizer(image=self.image) - fig_num = visualizer.fig.number + visualizer = Visualizer( + image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg)) + fig_num = visualizer.fig_save_num assert fig_num in plt.get_fignums() + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._close is False visualizer.close() assert fig_num not in plt.get_fignums() + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._close is True def test_draw_texts(self): visualizer = Visualizer(image=self.image) # only support tensor and numpy - visualizer.draw_texts('text1', positions=torch.tensor([5, 5])) + visualizer.draw_texts( + 'text1', positions=torch.tensor([5, 5]), colors=(0, 255, 0)) visualizer.draw_texts(['text1', 'text2'], - positions=torch.tensor([[5, 5], [3, 3]])) + positions=torch.tensor([[5, 5], [3, 3]]), + colors=[(255, 0, 0), (255, 0, 0)]) visualizer.draw_texts('text1', positions=np.array([5, 5])) visualizer.draw_texts(['text1', 'text2'], positions=np.array([[5, 5], [3, 3]])) @@ -111,11 +182,11 @@ class TestVisualizer(TestCase): with pytest.raises(AssertionError): visualizer.draw_texts(['text1', 'test2'], positions=torch.tensor([[5, 5], [3, 3]]), - verticalalignments=['top']) + vertical_alignments=['top']) with pytest.raises(AssertionError): visualizer.draw_texts(['text1', 'test2'], positions=torch.tensor([[5, 5], [3, 3]]), - horizontalalignments=['left']) + horizontal_alignments=['left']) with pytest.raises(AssertionError): visualizer.draw_texts(['text1', 'test2'], positions=torch.tensor([[5, 5], [3, 3]]), @@ -140,8 +211,8 @@ class TestVisualizer(TestCase): x_datas=np.array([[1, 5], [2, 4]]), y_datas=np.array([[2, 6], [4, 7]]), colors='r', - linestyles=['-', '-.'], - linewidths=[1, 2]) + line_styles=['-', '-.'], + line_widths=[1, 2]) # test out of bounds with pytest.warns( UserWarning, @@ -171,19 +242,20 @@ class TestVisualizer(TestCase): visualizer.draw_circles( torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2])) - # test filling + # test face_colors visualizer.draw_circles( torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2]), - is_filling=True) + face_colors=(255, 0, 0), + edge_colors=(255, 0, 0)) # test config visualizer.draw_circles( torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2]), - edgecolors=['g', 'r'], - linestyles=['-', '-.'], - linewidths=[1, 2]) + edge_colors=['g', 'r'], + line_styles=['-', '-.'], + line_widths=[1, 2]) # test out of bounds with pytest.warns( @@ -220,15 +292,16 @@ class TestVisualizer(TestCase): np.array([[1, 1], [2, 2], [3, 4]]), torch.tensor([[1, 1], [2, 2], [3, 4]]) ], - is_filling=True) + face_colors=(255, 0, 0), + edge_colors=(255, 0, 0)) visualizer.draw_polygons( polygons=[ np.array([[1, 1], [2, 2], [3, 4]]), torch.tensor([[1, 1], [2, 2], [3, 4]]) ], - edgecolors=['r', 'g'], - linestyles='-', - linewidths=[2, 1]) + edge_colors=['r', 'g'], + line_styles='-', + line_widths=[2, 1]) # test out of bounds with pytest.warns( @@ -242,7 +315,10 @@ class TestVisualizer(TestCase): visualizer = Visualizer(image=self.image) visualizer.draw_binary_masks(binary_mask) visualizer.draw_binary_masks(torch.from_numpy(binary_mask)) - + # multi binary + binary_mask = np.random.randint(0, 2, size=(2, 10, 10)).astype(np.bool) + visualizer = Visualizer(image=self.image) + visualizer.draw_binary_masks(binary_mask, colors=['r', (0, 255, 0)]) # test the error that the size of mask and image are different. with pytest.raises(AssertionError): binary_mask = np.random.randint(0, 2, size=(8, 10)).astype(np.bool) @@ -269,7 +345,7 @@ class TestVisualizer(TestCase): visualizer.draw_featmap(torch.randn(1, 1, 3, 3)) # test mode parameter - # mode only supports 'mean' and 'max' and 'min + # mode only supports 'mean' and 'max' with pytest.raises(AssertionError): visualizer.draw_featmap(torch.randn(2, 3, 3), mode='xx') # test tensor_chw and img have difference height and width @@ -289,7 +365,6 @@ class TestVisualizer(TestCase): visualizer.draw_featmap(torch.randn(6, 3, 3), mode='mean') visualizer.draw_featmap(torch.randn(1, 3, 3), mode='mean') visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max') - visualizer.draw_featmap(torch.randn(6, 3, 3), mode='min') visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max', topk=10) visualizer.draw_featmap(torch.randn(1, 3, 3), mode=None, topk=-1) visualizer.draw_featmap( @@ -325,57 +400,76 @@ class TestVisualizer(TestCase): draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]])). \ draw_binary_masks(binary_mask) - def test_register_task(self): + def test_get_backend(self): + visualizer = Visualizer( + image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg)) + for name in ['mock1', 'mock2']: + assert isinstance(visualizer.get_backend(name), MockVisBackend) - class DetVisualizer(Visualizer): + def test_add_config(self): + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) - @Visualizer.register_task('instances') - def draw_instance(self, instances, data_type): - pass + params_dict = dict(lr=0.1, wd=0.2, mode='linear') + visualizer.add_config(params_dict) + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._add_config is True - assert len(Visualizer.task_dict) == 1 - assert 'instances' in Visualizer.task_dict + def test_add_graph(self): + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) - # test registration of the same names. - with pytest.raises( - KeyError, - match=('"instances" is already registered in task_dict, ' - 'add "force=True" if you want to override it')): - - class DetVisualizer1(Visualizer): - - @Visualizer.register_task('instances') - def draw_instance1(self, instances, data_type): - pass - - @Visualizer.register_task('instances') - def draw_instance2(self, instances, data_type): - pass - - Visualizer.task_dict = dict() - - class DetVisualizer2(Visualizer): - - @Visualizer.register_task('instances') - def draw_instance1(self, instances, data_type): - pass - - @Visualizer.register_task('instances', force=True) - def draw_instance2(self, instances, data_type): - pass - - def draw(self, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True) -> None: - return super().draw(image, gt_sample, pred_sample, draw_gt, - draw_pred) - - det_visualizer = DetVisualizer2() - det_visualizer.draw(gt_sample={}, pred_sample={}) - assert len(det_visualizer.task_dict) == 1 - assert 'instances' in det_visualizer.task_dict - assert det_visualizer.task_dict[ - 'instances'].__name__ == 'draw_instance2' + class Model(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(1, 2, 1) + + def forward(self, x, y=None): + return self.conv(x) + + visualizer.add_graph(Model(), np.zeros([1, 1, 3, 3])) + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._add_graph is True + + def test_add_image(self): + image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) + + visualizer.add_image('img', image) + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._add_image is True + + def test_add_scalar(self): + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) + visualizer.add_scalar('map', 0.9, step=0) + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._add_scalar is True + + def test_add_scalars(self): + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) + input_dict = {'map': 0.7, 'acc': 0.9} + visualizer.add_scalars(input_dict) + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._add_scalars is True + + def test_get_instance(self): + + class DetLocalVisualizer(Visualizer): + + def __init__(self, name): + super().__init__(name) + + visualizer1 = DetLocalVisualizer.get_instance('name1') + visualizer2 = Visualizer.get_current_instance() + visualizer3 = DetLocalVisualizer.get_current_instance() + assert id(visualizer1) == id(visualizer2) == id(visualizer3) + + +if __name__ == '__main__': + t = TestVisualizer() + t.setUp() + t.test_init() diff --git a/tests/test_visualizer/test_writer.py b/tests/test_visualizer/test_writer.py deleted file mode 100644 index 447a246d..00000000 --- a/tests/test_visualizer/test_writer.py +++ /dev/null @@ -1,484 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import shutil -import sys -from unittest.mock import MagicMock, Mock, patch - -import numpy as np -import pytest -import torch -import torch.nn as nn - -from mmengine.fileio import load -from mmengine.registry import VISUALIZERS, WRITERS -from mmengine.visualization import (ComposedWriter, LocalWriter, - TensorboardWriter, WandbWriter) - - -def draw(self, image, gt_sample, pred_sample, show_gt=True, show_pred=True): - self.set_image(image) - - -class TestLocalWriter: - - def test_init(self): - # visuailzer must be a dictionary or an instance - # of Visualizer and its subclasses - with pytest.raises(AssertionError): - LocalWriter('temp_dir', [dict(type='Visualizer')]) - - # 'params_save_file' format must be yaml - with pytest.raises(AssertionError): - LocalWriter('temp_dir', params_save_file='a.txt') - - # 'scalar_save_file' format must be json - with pytest.raises(AssertionError): - LocalWriter('temp_dir', scalar_save_file='a.yaml') - - local_writer = LocalWriter('temp_dir') - assert os.path.exists(local_writer._save_dir) - shutil.rmtree('temp_dir') - - local_writer = WRITERS.build( - dict( - type='LocalWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir')) - assert os.path.exists(local_writer._save_dir) - shutil.rmtree('temp_dir') - - def test_experiment(self): - local_writer = LocalWriter('temp_dir') - assert local_writer.experiment == local_writer - shutil.rmtree('temp_dir') - - def test_add_params(self): - local_writer = LocalWriter('temp_dir') - - # 'params_dict' must be dict - with pytest.raises(AssertionError): - local_writer.add_params(['lr', 0]) - - params_dict = dict(lr=0.1, wd=[1.0, 0.1, 0.001], mode='linear') - local_writer.add_params(params_dict) - out_dict = load(local_writer._params_save_file, 'yaml') - assert out_dict == params_dict - shutil.rmtree('temp_dir') - - @patch('mmengine.visualization.visualizer.Visualizer.draw', draw) - def test_add_image(self): - image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - - # The visuailzer parameter must be set when - # the local_writer object is instantiated and - # the `add_image` method is called. - with pytest.raises(AssertionError): - local_writer = LocalWriter('temp_dir') - local_writer.add_image('img', image) - - local_writer = LocalWriter('temp_dir', dict(type='Visualizer')) - local_writer.add_image('img', image) - assert os.path.exists( - os.path.join(local_writer._img_save_dir, 'img_0.png')) - - bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]]) - local_writer.visualizer.draw_bboxes(bboxes) - local_writer.add_image( - 'img', local_writer.visualizer.get_image(), step=2) - assert os.path.exists( - os.path.join(local_writer._img_save_dir, 'img_2.png')) - - visuailzer = VISUALIZERS.build(dict(type='Visualizer')) - local_writer = LocalWriter('temp_dir', visuailzer) - local_writer.add_image('img', image) - assert os.path.exists( - os.path.join(local_writer._img_save_dir, 'img_0.png')) - - shutil.rmtree('temp_dir') - - def test_add_scalar(self): - local_writer = LocalWriter('temp_dir') - local_writer.add_scalar('map', 0.9) - out_dict = load(local_writer._scalar_save_file, 'json') - assert out_dict == {'map': 0.9, 'step': 0} - shutil.rmtree('temp_dir') - - # test append mode - local_writer = LocalWriter('temp_dir') - local_writer.add_scalar('map', 0.9, step=0) - local_writer.add_scalar('map', 0.95, step=1) - with open(local_writer._scalar_save_file) as f: - out_dict = f.read() - assert out_dict == '{"map": 0.9, "step": 0}\n{"map": ' \ - '0.95, "step": 1}\n' - shutil.rmtree('temp_dir') - - def test_add_scalars(self): - local_writer = LocalWriter('temp_dir') - input_dict = {'map': 0.7, 'acc': 0.9} - local_writer.add_scalars(input_dict) - out_dict = load(local_writer._scalar_save_file, 'json') - assert out_dict == {'map': 0.7, 'acc': 0.9, 'step': 0} - - # test append mode - local_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) - with open(local_writer._scalar_save_file) as f: - out_dict = f.read() - assert out_dict == '{"map": 0.7, "acc": 0.9, ' \ - '"step": 0}\n{"map": 0.8, "acc": 0.8, "step": 1}\n' - - # test file_path - local_writer = LocalWriter('temp_dir') - local_writer.add_scalars(input_dict, file_path='temp.json') - assert os.path.exists(local_writer._scalar_save_file) - assert os.path.exists( - os.path.join(local_writer._save_dir, 'temp.json')) - - # file_path and scalar_save_file cannot be the same - with pytest.raises(AssertionError): - local_writer.add_scalars(input_dict, file_path='scalars.json') - - shutil.rmtree('temp_dir') - - -class TestTensorboardWriter: - sys.modules['torch.utils.tensorboard'] = MagicMock() - sys.modules['tensorboardX'] = MagicMock() - - def test_init(self): - # visuailzer must be a dictionary or an instance - # of Visualizer and its subclasses - with pytest.raises(AssertionError): - LocalWriter('temp_dir', [dict(type='Visualizer')]) - - TensorboardWriter('temp_dir') - WRITERS.build( - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir')) - - def test_experiment(self): - tensorboard_writer = TensorboardWriter('temp_dir') - assert tensorboard_writer.experiment == tensorboard_writer._tensorboard - - def test_add_graph(self): - - class Model(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 2, 1) - - def forward(self, x, y=None): - return self.conv(x) - - tensorboard_writer = TensorboardWriter('temp_dir') - - # input must be tensor - with pytest.raises(AssertionError): - tensorboard_writer.add_graph(Model(), np.zeros([1, 1, 3, 3])) - - # input must be 4d tensor - with pytest.raises(AssertionError): - tensorboard_writer.add_graph(Model(), torch.zeros([1, 3, 3])) - - # If the input is a list, the inner element must be a 4d tensor - with pytest.raises(AssertionError): - tensorboard_writer.add_graph( - Model(), [torch.zeros([1, 1, 3, 3]), - torch.zeros([1, 3, 3])]) - - tensorboard_writer.add_graph(Model(), torch.zeros([1, 1, 3, 3])) - tensorboard_writer.add_graph( - Model(), [torch.zeros([1, 1, 3, 3]), - torch.zeros([1, 1, 3, 3])]) - - def test_add_params(self): - tensorboard_writer = TensorboardWriter('temp_dir') - - # 'params_dict' must be dict - with pytest.raises(AssertionError): - tensorboard_writer.add_params(['lr', 0]) - - params_dict = dict(lr=0.1, wd=0.2, mode='linear') - tensorboard_writer.add_params(params_dict) - - @patch('mmengine.visualization.visualizer.Visualizer.draw', draw) - def test_add_image(self): - image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - - # The visuailzer parameter must be set when - # the local_writer object is instantiated and - # the `add_image` method is called. - with pytest.raises(AssertionError): - tensorboard_writer = TensorboardWriter('temp_dir') - tensorboard_writer.add_image('img', image) - - tensorboard_writer = TensorboardWriter('temp_dir', - dict(type='Visualizer')) - tensorboard_writer.add_image('img', image) - - bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]]) - tensorboard_writer.visualizer.draw_bboxes(bboxes) - tensorboard_writer.add_image( - 'img', tensorboard_writer.visualizer.get_image(), step=2) - - visuailzer = VISUALIZERS.build(dict(type='Visualizer')) - tensorboard_writer = TensorboardWriter('temp_dir', visuailzer) - tensorboard_writer.add_image('img', image) - - def test_add_scalar(self): - tensorboard_writer = TensorboardWriter('temp_dir') - tensorboard_writer.add_scalar('map', 0.9) - # test append mode - tensorboard_writer.add_scalar('map', 0.9, step=0) - tensorboard_writer.add_scalar('map', 0.95, step=1) - - def test_add_scalars(self): - tensorboard_writer = TensorboardWriter('temp_dir') - # The step value must be passed through the parameter - with pytest.raises(AssertionError): - tensorboard_writer.add_scalars({'map': 0.7, 'acc': 0.9, 'step': 1}) - - input_dict = {'map': 0.7, 'acc': 0.9} - tensorboard_writer.add_scalars(input_dict) - # test append mode - tensorboard_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) - - -class TestWandbWriter: - sys.modules['wandb'] = MagicMock() - - def test_init(self): - WandbWriter() - WRITERS.build( - dict( - type='WandbWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir')) - - def test_experiment(self): - wandb_writer = WandbWriter() - assert wandb_writer.experiment == wandb_writer._wandb - - def test_add_params(self): - wandb_writer = WandbWriter() - - # 'params_dict' must be dict - with pytest.raises(AssertionError): - wandb_writer.add_params(['lr', 0]) - - params_dict = dict(lr=0.1, wd=0.2, mode='linear') - wandb_writer.add_params(params_dict) - - @patch('mmengine.visualization.visualizer.Visualizer.draw', draw) - @patch('mmengine.visualization.writer.WandbWriter.add_image_to_wandb', - Mock) - def test_add_image(self): - image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - - wandb_writer = WandbWriter() - wandb_writer.add_image('img', image) - - wandb_writer = WandbWriter(visualizer=dict(type='Visualizer')) - bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]]) - wandb_writer.visualizer.set_image(image) - wandb_writer.visualizer.draw_bboxes(bboxes) - wandb_writer.add_image( - 'img', wandb_writer.visualizer.get_image(), step=2) - - visuailzer = VISUALIZERS.build(dict(type='Visualizer')) - wandb_writer = WandbWriter(visualizer=visuailzer) - wandb_writer.add_image('img', image) - - def test_add_scalar(self): - wandb_writer = WandbWriter() - wandb_writer.add_scalar('map', 0.9) - # test append mode - wandb_writer.add_scalar('map', 0.9, step=0) - wandb_writer.add_scalar('map', 0.95, step=1) - - def test_add_scalars(self): - wandb_writer = WandbWriter() - input_dict = {'map': 0.7, 'acc': 0.9} - wandb_writer.add_scalars(input_dict) - # test append mode - wandb_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) - - -class TestComposedWriter: - sys.modules['torch.utils.tensorboard'] = MagicMock() - sys.modules['tensorboardX'] = MagicMock() - sys.modules['wandb'] = MagicMock() - - def test_init(self): - - class A: - pass - - # The writers inner element must be a dictionary or a - # subclass of Writer. - with pytest.raises(AssertionError): - ComposedWriter(writers=[A()]) - - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - assert len(composed_writer._writers) == 2 - - # test global - composed_writer = ComposedWriter.get_instance( - 'composed_writer', - writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - assert len(composed_writer._writers) == 2 - composed_writer_any = ComposedWriter.get_instance('composed_writer') - assert composed_writer_any == composed_writer - - def test_get_writer(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - assert isinstance(composed_writer.get_writer(0), WandbWriter) - assert isinstance(composed_writer.get_writer(1), TensorboardWriter) - - def test_get_experiment(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - assert composed_writer.get_experiment( - 0) == composed_writer._writers[0].experiment - assert composed_writer.get_experiment( - 1) == composed_writer._writers[1].experiment - - def test_get_visualizer(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - assert composed_writer.get_visualizer( - 0) == composed_writer._writers[0].visualizer - assert composed_writer.get_visualizer( - 1) == composed_writer._writers[1].visualizer - - def test_add_params(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - - # 'params_dict' must be dict - with pytest.raises(AssertionError): - composed_writer.add_params(['lr', 0]) - - params_dict = dict(lr=0.1, wd=0.2, mode='linear') - composed_writer.add_params(params_dict) - - def test_add_graph(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - - class Model(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 2, 1) - - def forward(self, x, y=None): - return self.conv(x) - - # input must be tensor - with pytest.raises(AssertionError): - composed_writer.add_graph(Model(), np.zeros([1, 1, 3, 3])) - - # input must be 4d tensor - with pytest.raises(AssertionError): - composed_writer.add_graph(Model(), torch.zeros([1, 3, 3])) - - # If the input is a list, the inner element must be a 4d tensor - with pytest.raises(AssertionError): - composed_writer.add_graph( - Model(), [torch.zeros([1, 1, 3, 3]), - torch.zeros([1, 3, 3])]) - - composed_writer.add_graph(Model(), torch.zeros([1, 1, 3, 3])) - composed_writer.add_graph( - Model(), [torch.zeros([1, 1, 3, 3]), - torch.zeros([1, 1, 3, 3])]) - - @patch('mmengine.visualization.visualizer.Visualizer.draw', draw) - @patch('mmengine.visualization.writer.WandbWriter.add_image_to_wandb', - Mock) - def test_add_image(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - - image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - composed_writer.add_image('img', image) - - bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]]) - composed_writer.get_writer(1).visualizer.draw_bboxes(bboxes) - composed_writer.get_writer(1).add_image( - 'img', - composed_writer.get_writer(1).visualizer.get_image(), - step=2) - - def test_add_scalar(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - composed_writer.add_scalar('map', 0.9) - # test append mode - composed_writer.add_scalar('map', 0.9, step=0) - composed_writer.add_scalar('map', 0.95, step=1) - - def test_add_scalars(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - input_dict = {'map': 0.7, 'acc': 0.9} - composed_writer.add_scalars(input_dict) - # test append mode - composed_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) -- GitLab