diff --git a/docs/zh_cn/tutorials/evaluator.md b/docs/zh_cn/tutorials/evaluator.md
deleted file mode 100644
index 1dfbb85af5a8e66723859e3941b2ef6f5dd3712c..0000000000000000000000000000000000000000
--- a/docs/zh_cn/tutorials/evaluator.md
+++ /dev/null
@@ -1,158 +0,0 @@
-# Evaluator
-
-During model validation and testing, it is usually necessary to evaluate model accuracy quantitatively. MMEngine implements the [Evaluator](Todo:evaluator-doc-link) for this purpose. An evaluator computes specific metrics from the model's input data and predictions. Evaluators are decoupled from datasets, so users can freely combine any test data with any evaluator. For example, [COCOEvaluator](Todo:coco-evaluator-doc-link) can compute AP, AR, and other metrics on the COCO dataset, and it can also be used on other object detection datasets.
-
-## Evaluating model accuracy
-
-The process of computing model accuracy with an evaluator is shown in the figure below.
-
-The test data is usually split into batches. In a loop, each batch is fed to the model to obtain predictions, and the predictions together with the model inputs are passed to the evaluator through its `process()` method. After the loop finishes, calling the evaluator's `evaluate()` method computes the corresponding metrics.
-
-In practice, these steps are performed by the runner. Users only need to select the evaluator in the config file and set its parameters.
-
-<div align="center">
-  <img src="https://user-images.githubusercontent.com/15977946/154652635-f4bda588-9f94-462f-b68f-b900690e6215.png"/>
-</div>
-
-### Configuring an evaluator in the config file
-
-When configuring an evaluator in the config file, you need to specify its type, its parameters, and how it is invoked. The invocation settings usually apply to the validation stage and include the interval unit (epoch or iteration), the interval, and the main metric, i.e., the metric used to select the best checkpoint.
-
-For example, to use the COCO evaluator during validation, evaluate every 10 epochs, and use AP as the main metric, the corresponding part of the config is:
-
-```python
-validation_cfg=dict(
-    evaluator=dict(type='COCO'),  # use the COCO evaluator without arguments
-    main_metric='AP',  # AP is the main metric
-    interval=10,  # evaluate every 10 epochs
-    by_epoch=True,
-)
-```
-
-### Using multiple evaluators
-
-Evaluators can be combined. Users can configure several evaluators to compute multiple metrics during validation or testing. To do so, simply put all evaluator configs into a list:
-
-```python
-validation_cfg=dict(
-    evaluator=[
-        dict(type='Accuracy', top_k=1),  # classification accuracy evaluator
-        dict(type='F1Score')  # F1-score evaluator
-    ],
-    main_metric='accuracy',
-    interval=10,
-    by_epoch=True,
-)
-```
-
-When multiple evaluators are used, metric names may collide. In the example below, two `COCOEvaluator`s evaluate the bounding-box and keypoint predictions respectively, and both report `AP`, `AR`, etc. To avoid the ambiguity caused by duplicated metric names, `Evaluator` supports prefixing metric names via the `prefix` argument. An `Evaluator` usually has a default prefix, and users can also specify one in the config file.
-
-```python
-validation_cfg=dict(
-    evaluator=[
-        dict(type='COCO', iou_type='bbox'),  # use the default prefix `COCO`
-        dict(type='COCO', iou_type='keypoints', prefix='COCOKpts')  # custom prefix `COCOKpts`
-    ],
-    # Use the AP with prefix `COCO` as the main metric.
-    # If there is no name collision, the prefix can be omitted here
-    # and the metric name alone is enough.
-    main_metric='COCO/AP',
-    interval=10,
-    by_epoch=True,
-)
-```
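As a rough illustration of the prefix behaviour in the deleted section above, the results of such a two-evaluator setup could look like the dict below. This is a hypothetical sketch: only the `COCO/AP` key format appears in the original config, and the numbers are made up.

```python
# Hypothetical result dict for the two-evaluator config above: duplicated
# metric names are disambiguated by their prefixes, so
# `main_metric='COCO/AP'` selects the bbox AP unambiguously.
results = {
    'COCO/AP': 0.512,        # bbox AP, default prefix `COCO`
    'COCO/AR': 0.601,
    'COCOKpts/AP': 0.433,    # keypoint AP, custom prefix `COCOKpts`
    'COCOKpts/AR': 0.557,
}
```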
-
-## Adding a custom evaluator
-
-The OpenMMLab algorithm libraries already implement common evaluators for their respective fields, e.g., MMDetection provides the COCO evaluator and MMClassification provides the Accuracy and F1Score evaluators.
-
-Users can also add custom evaluators for their own needs. A custom evaluator should inherit from the evaluator base class [BaseEvaluator](Todo:baseevaluator-doc-link) provided by MMEngine and implement the corresponding abstract methods.
-
-### Evaluator base class
-
-The evaluator base class `BaseEvaluator` is an abstract class with the following 2 abstract methods:
-
-- `process()`: processes each batch of test data and model predictions. The processed results should be stored in the `self.results` list, which is used to compute the metrics after all test data has been processed.
-- `compute_metrics()`: computes the metrics and returns them in a dictionary.
-
-`compute_metrics()` is called inside the `evaluate()` method; before computing the metrics, `evaluate()` collects and aggregates the intermediate results from different ranks during distributed testing. Both `process()` and `evaluate()` are called by the runner, so after inheriting from `BaseEvaluator`, users only need to implement `process()` and `compute_metrics()`.
-
-Note that what exactly is stored in `self.results` depends on the implementation of the custom evaluator. For example, when the test samples or model outputs are large (e.g., semantic segmentation or image generation) and should not be kept entirely in memory, the evaluator can store the per-batch metrics in `self.results` and aggregate them in `compute_metrics()`, or store the per-batch intermediate results in temporary files and keep only the file paths in `self.results`, letting `compute_metrics()` read the data from the files to compute the metrics.
-
-### Custom evaluator class
-
-We take a classification accuracy (Classification Accuracy) evaluator as an example to show how to implement a custom evaluator.
-
-First, the custom evaluator class should inherit from `BaseEvaluator` and be registered in the `EVALUATORS` registry (see the [registry documentation](docs\zh_cn\tutorials\registry.md)).
-
-The `process()` method takes 2 arguments: a batch of test data samples `data_batch` and the model predictions `predictions`. We extract the ground-truth labels and the predicted labels from them and store them in `self.results`.
-
-The `compute_metrics()` method takes 1 argument `results`, which holds the results of all batches processed by `process()`. We extract the ground-truth and predicted labels from it and compute the classification accuracy `acc`. Finally, the computed metrics are returned as a dictionary.
-
-In addition, we recommend assigning the class attribute `default_prefix` in the subclass. If `prefix` is not specified in the initialization arguments (i.e., in the config), `default_prefix` is automatically used as the prefix of the metric names. The evaluator's `default_prefix` and all its metrics should also be documented in the docstring.
-
-The implementation is as follows:
-
-```python
-from mmengine.evaluator import BaseMetric
-from mmengine.registry import METRICS
-
-import numpy as np
-
-
-@METRICS.register_module()
-class Accuracy(BaseMetric):
-    """ Accuracy Evaluator
-
-    Default prefix: ACC
-
-    Metrics:
-        - accuracy (float): classification accuracy
-    """
-
-    default_prefix = 'ACC'
-
-    def process(self, data_batch: Sequence[Tuple[Any, dict]],
-                predictions: Sequence[dict]):
-        """Process one batch of data and predictions. The processed
-        results should be stored in `self.results`, which will be used
-        to compute the metrics when all batches have been processed.
-
-        Args:
-            data_batch (Sequence[Tuple[Any, dict]]): A batch of data
-                from the dataloader.
-            predictions (Sequence[dict]): A batch of outputs from
-                the model.
-        """
-
-        # take out the predicted and ground-truth labels
-        result = {
-            'pred': predictions['pred_label'],
-            'gt': data_batch['gt_label']
-        }
-
-        # store the results of the current batch into self.results
-        self.results.append(result)
-
-    def compute_metrics(self, results: List):
-        """Compute the metrics from processed results.
-
-        Args:
-            results (list): The processed results of each batch.
-
-        Returns:
-            Dict: The computed metrics. The keys are the names of the metrics,
-            and the values are corresponding results.
- """ - - # æ±‡æ€»æ‰€æœ‰æ ·æœ¬çš„åˆ†ç±»é¢„æµ‹ç»“æžœå’Œç±»åˆ«æ ‡ç¾ - preds = np.concatenate([res['pred'] for res in results]) - gts = np.concatenate([res['gt'] for res in results]) - - # 计算分类æ£ç¡®çŽ‡ - acc = (preds == gts).sum() / preds.size - - # è¿”å›žè¯„æµ‹æŒ‡æ ‡ç»“æžœ - return {'accuracy': acc} - -``` diff --git a/docs/zh_cn/tutorials/metric_and_evaluator.md b/docs/zh_cn/tutorials/metric_and_evaluator.md new file mode 100644 index 0000000000000000000000000000000000000000..426a8d293d021d02a9156376a02ada403a0b4469 --- /dev/null +++ b/docs/zh_cn/tutorials/metric_and_evaluator.md @@ -0,0 +1,133 @@ +# è¯„æµ‹æŒ‡æ ‡ï¼ˆMetric)和评测器(Evaluator) + +在模型验è¯å’Œæ¨¡åž‹æµ‹è¯•ä¸ï¼Œé€šå¸¸éœ€è¦å¯¹æ¨¡åž‹ç²¾åº¦åšå®šé‡è¯„测。在 MMEngine ä¸å®žçŽ°äº†[è¯„æµ‹æŒ‡æ ‡](Todo:metric-doc-link)å’Œ[评测器](Todo:evaluator-doc-linek)æ¥å®Œæˆè¿™ä¸€åŠŸèƒ½ã€‚ + +**è¯„æµ‹æŒ‡æ ‡** æ ¹æ®æ¨¡åž‹çš„输入数æ®å’Œé¢„测结果,完æˆç‰¹å®šæŒ‡æ ‡ä¸‹æ¨¡åž‹ç²¾åº¦çš„è®¡ç®—ã€‚è¯„æµ‹æŒ‡æ ‡ä¸Žæ•°æ®é›†ä¹‹é—´ç›¸äº’解耦,这使得用户å¯ä»¥ä»»æ„组åˆæ‰€éœ€çš„测试数æ®å’Œè¯„æµ‹æŒ‡æ ‡ã€‚å¦‚ [COCOMetric](Todo:coco-metric-doc-link) å¯ç”¨äºŽè®¡ç®— COCO æ•°æ®é›†çš„ AP,AR ç‰è¯„æµ‹æŒ‡æ ‡ï¼Œä¹Ÿå¯ç”¨äºŽå…¶ä»–çš„ç›®æ ‡æ£€æµ‹æ•°æ®é›†ä¸Šã€‚ +**评测器** æ˜¯è¯„æµ‹æŒ‡æ ‡çš„ä¸Šå±‚æ¨¡å—,通常包å«ä¸€ä¸ªæˆ–å¤šä¸ªè¯„æµ‹æŒ‡æ ‡ã€‚è¯„æµ‹å™¨çš„ä½œç”¨æ˜¯åœ¨æ¨¡åž‹è¯„æµ‹æ—¶å®Œæˆå¿…è¦çš„æ•°æ®æ ¼å¼è½¬æ¢ï¼Œå¹¶è°ƒç”¨è¯„æµ‹æŒ‡æ ‡è®¡ç®—æ¨¡åž‹ç²¾åº¦ã€‚è¯„æµ‹å™¨é€šå¸¸ç”±[执行器](TODO:runner-doc-link)或测试脚本构建,分别用于在线评测和离线评测。 + +用户通常ä¸éœ€è¦æ·±å…¥äº†è§£æˆ–æ‰‹åŠ¨ä¿®æ”¹è¯„æµ‹å™¨ï¼Œå› æ¤è¯¥æ–‡æ¡£å°†é‡ç‚¹ä»‹ç»è¯„æµ‹æŒ‡æ ‡çš„åŽŸç†å’Œä½¿ç”¨æ–¹å¼ã€‚ + +## 模型精度评测 + +通常,模型精度评测的过程如下图所示。 + +**在线评测**:测试数æ®é€šå¸¸ä¼šè¢«åˆ’分为若干批次(batch)。通过一个循环,ä¾æ¬¡å°†æ¯ä¸ªæ‰¹æ¬¡çš„æ•°æ®é€å…¥æ¨¡åž‹ï¼Œå¾—到对应的预测结果,并将测试数æ®å’Œæ¨¡åž‹é¢„测结果é€å…¥è¯„æµ‹å™¨ã€‚è¯„æµ‹å™¨ä¼šè°ƒç”¨è¯„æµ‹æŒ‡æ ‡çš„ `process()` 方法对数æ®å’Œé¢„测结果进行处ç†ã€‚当循环结æŸåŽï¼Œè¯„æµ‹å™¨ä¼šè°ƒç”¨è¯„æµ‹æŒ‡æ ‡çš„ `evaluate()` 方法,å¯è®¡ç®—å¾—åˆ°å¯¹åº”æŒ‡æ ‡çš„æ¨¡åž‹ç²¾åº¦ã€‚ + +**离线评测**:与在线评测过程类似,区别是直接读å–预先ä¿å˜çš„模型预测结果æ¥è¿›è¡Œè¯„测。评测器æ供了 `offline_evaluate` 接å£ï¼Œç”¨äºŽåœ¨ç¦»çº¿æ–¹å¼ä¸‹è°ƒç”¨è¯„æµ‹æŒ‡æ ‡æ¥è®¡ç®—模型精度。为了é¿å…åŒæ—¶å¤„ç†å¤§é‡æ•°æ®å¯¼è‡´å†…å˜æº¢å‡ºï¼Œç¦»çº¿è¯„测时会将测试数æ®å’Œé¢„测结果分æˆè‹¥å¹²ä¸ªå—(Chunk)进行处ç†ï¼Œç±»ä¼¼åœ¨çº¿è¯„测ä¸çš„批次。 + +<div align="center"> + <img src="https://user-images.githubusercontent.com/15977946/163718224-20a4970a-e540-4a3a-8b01-bf0a604c6841.jpg" width="500"/> +</div> + +## 在é…置文件ä¸é…ç½®è¯„æµ‹æŒ‡æ ‡ + +在é…置文件ä¸å¯ä»¥é€šè¿‡ `val_evaluator` å’Œ `test_evaluator` 2 个å—段分别指定模型验è¯å’Œæµ‹è¯•é˜¶æ®µçš„è¯„æµ‹æŒ‡æ ‡ã€‚ä¾‹å¦‚ï¼Œç”¨æˆ·åœ¨è®ç»ƒåˆ†ç±»æ¨¡åž‹æ—¶ï¼Œå¸Œæœ›åœ¨æ¨¡åž‹éªŒè¯é˜¶æ®µä½¿ç”¨åˆ†ç±»æ£ç¡®çŽ‡å’Œ F1 Score ä¸¤ä¸ªè¯„æµ‹æŒ‡æ ‡ï¼Œå¯ä»¥æŒ‰ä»¥ä¸‹æ–¹å¼é…置: + +```python +val_evaluator = [ + dict(type='Accuracy', top_k=1), # 使用分类æ£ç¡®çŽ‡è¯„æµ‹æŒ‡æ ‡ + dict(type='F1Score') # 使用 F1_score è¯„æµ‹æŒ‡æ ‡ +] +``` + +é…ç½®ä¸çš„`val_evaluator` 会被用于构建一个包å«å¤šä¸ªè¯„æµ‹æŒ‡æ ‡çš„è¯„æµ‹å™¨ï¼Œå…¶ä¸çš„æ¯ä¸ªå—å…¸å¯¹åº”äºŽä¸€ä¸ªè¯„æµ‹æŒ‡æ ‡çš„ç±»åˆ«å’Œå‚数。 +如果åªä½¿ç”¨å•ä¸ªè¯„æµ‹æŒ‡æ ‡ï¼Œä¹Ÿå¯ä»¥çœç•¥æŽ‰é…ç½®ä¸çš„åˆ—è¡¨ï¼Œç›´æŽ¥æŒ‡å®šè¯„æµ‹æŒ‡æ ‡å‚数。例如,在模型测试阶段使用分类æ£ç¡®çŽ‡è¯„æµ‹æŒ‡æ ‡ï¼Œå¯¹åº”çš„é…置如下: + +```python +test_evaluator = dict(type='Accuracy', top_k=1) +``` + +## å¢žåŠ è‡ªå®šä¹‰è¯„æµ‹æŒ‡æ ‡ + +在 OpenMMLab çš„å„个算法库ä¸ï¼Œå·²ç»å®žçŽ°äº†å¯¹åº”æ–¹å‘çš„å¸¸ç”¨è¯„æµ‹æŒ‡æ ‡ã€‚å¦‚ MMDetection ä¸æ供了 COCO è¯„æµ‹æŒ‡æ ‡ï¼ŒMMClassification ä¸æ供了 Accuracyã€F1Score ç‰è¯„æµ‹æŒ‡æ ‡ç‰ã€‚ + +用户也å¯ä»¥å¢žåŠ è‡ªå®šä¹‰çš„è¯„æµ‹æŒ‡æ ‡ã€‚åœ¨å®žçŽ°è‡ªå®šä¹‰è¯„æµ‹æŒ‡æ ‡æ—¶ï¼Œéœ€è¦ç»§æ‰¿ MMEngine ä¸æä¾›çš„è¯„æµ‹æŒ‡æ ‡åŸºç±» [BaseMetric](Todo:basemetric-doc-link),并实现对应的抽象方法。 + +### è¯„æµ‹æŒ‡æ ‡åŸºç±» + +è¯„æµ‹æŒ‡æ ‡åŸºç±» `BaseMetric` 
+
+## Adding a custom metric
+
+The OpenMMLab algorithm libraries already implement common metrics for their respective fields, e.g., MMDetection provides the COCO metric and MMClassification provides the Accuracy and F1Score metrics.
+
+Users can also add custom metrics. A custom metric should inherit from the metric base class [BaseMetric](Todo:basemetric-doc-link) provided by MMEngine and implement the corresponding abstract methods.
+
+### Metric base class
+
+The metric base class `BaseMetric` is an abstract class with the following 2 abstract methods:
+
+- `process()`: processes each batch of test data and model predictions. The processed results should be stored in the `self.results` list, which is used to compute the metrics after all test data has been processed.
+- `compute_metrics()`: computes the metrics and returns them in a dictionary.
+
+`compute_metrics()` is called inside the `evaluate()` method; before computing the metrics, `evaluate()` collects and aggregates the intermediate results from different ranks during distributed testing.
+
+Note that what exactly is stored in `self.results` depends on the implementation of the metric subclass. For example, when the test samples or model outputs are large (e.g., semantic segmentation or image generation) and should not be kept entirely in memory, the metric can store the per-batch metrics in `self.results` and aggregate them in `compute_metrics()`, or store the per-batch intermediate results in temporary files and keep only the file paths in `self.results`, letting `compute_metrics()` read the data from the files to compute the metrics.
+
+### Custom metric class
+
+We take a classification accuracy (Classification Accuracy) metric as an example to show how to implement a custom metric.
+
+First, the metric class should inherit from `BaseMetric` and be registered in the `METRICS` registry (see the [registry documentation](docs/zh_cn/tutorials/registry.md)).
+
+The `process()` method takes 2 arguments: a batch of test data samples `data_batch` and the model predictions `predictions`. We extract the ground-truth labels and the predicted labels from them and store them in `self.results`.
+
+The `compute_metrics()` method takes 1 argument `results`, which holds the results of all batches processed by `process()`. We extract the ground-truth and predicted labels from it and compute the classification accuracy `acc`. Finally, the computed metrics are returned as a dictionary.
+
+In addition, we recommend assigning the class attribute `default_prefix` in the subclass. If `prefix` is not specified in the initialization arguments (i.e., in the config), `default_prefix` is automatically used as the prefix of the metric names. The class docstring should also document the metric's `default_prefix` and all the metric names it returns.
+
+The implementation is as follows:
+
+```python
+from typing import List, Sequence
+
+import numpy as np
+
+from mmengine.evaluator import BaseMetric
+from mmengine.registry import METRICS
+
+
+@METRICS.register_module()  # register the Accuracy class into the METRICS registry
+class Accuracy(BaseMetric):
+    """Accuracy Evaluator
+
+    Default prefix: ACC
+
+    Metrics:
+        - accuracy (float): classification accuracy
+    """
+
+    default_prefix = 'ACC'  # set the default prefix
+
+    def process(self, data_batch: Sequence[dict],
+                predictions: Sequence[dict]):
+        """Process one batch of data and predictions. The processed
+        results should be stored in `self.results`, which will be used
+        to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (Sequence[dict]): A batch of data
+                from the dataloader.
+            predictions (Sequence[dict]): A batch of outputs from
+                the model.
+        """
+
+        # take out the predicted and ground-truth labels
+        result = {
+            'pred': predictions['pred_label'],
+            'gt': data_batch['data_sample']['gt_label']
+        }
+
+        # store the results of the current batch into self.results
+        self.results.append(result)
+
+    def compute_metrics(self, results: List):
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+
+        # aggregate the predictions and labels of all samples
+        preds = np.concatenate([res['pred'] for res in results])
+        gts = np.concatenate([res['gt'] for res in results])
+
+        # compute the classification accuracy
+        acc = (preds == gts).sum() / preds.size
+
+        # return the metrics
+        return {'accuracy': acc}
+
+```
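To see the metric above in action outside of a runner, here is a small, self-contained sanity check. The fake batches simply mirror the `data_batch`/`predictions` structure assumed by `process()` in this tutorial, and calling `compute_metrics()` directly skips the distributed collection that `evaluate()` would normally perform.

```python
import numpy as np

# Direct instantiation; the registry is only needed when building from a config.
metric = Accuracy()

# Two fake batches in the structure assumed by `process()` above.
fake_batches = [
    (dict(data_sample=dict(gt_label=np.array([0, 1]))),
     dict(pred_label=np.array([0, 1]))),   # both predictions correct
    (dict(data_sample=dict(gt_label=np.array([2, 2]))),
     dict(pred_label=np.array([2, 0]))),   # one correct, one wrong
]

for data_batch, predictions in fake_batches:
    metric.process(data_batch, predictions)

print(metric.compute_metrics(metric.results))  # {'accuracy': 0.75}
```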
+ """ + + # æ±‡æ€»æ‰€æœ‰æ ·æœ¬çš„åˆ†ç±»é¢„æµ‹ç»“æžœå’Œç±»åˆ«æ ‡ç¾ + preds = np.concatenate([res['pred'] for res in results]) + gts = np.concatenate([res['gt'] for res in results]) + + # 计算分类æ£ç¡®çŽ‡ + acc = (preds == gts).sum() / preds.size + + # è¿”å›žè¯„æµ‹æŒ‡æ ‡ç»“æžœ + return {'accuracy': acc} + +``` diff --git a/docs/zh_cn/tutorials/visualization.md b/docs/zh_cn/tutorials/visualization.md deleted file mode 100644 index 80acabccf07a68e872cdfc0362f00aef9e34f509..0000000000000000000000000000000000000000 --- a/docs/zh_cn/tutorials/visualization.md +++ /dev/null @@ -1,301 +0,0 @@ -# å¯è§†åŒ– (Visualization) - -## 概述 - -**(1) 总体介ç»** - -å¯è§†åŒ–å¯ä»¥ç»™æ·±åº¦å¦ä¹ 的模型è®ç»ƒå’Œæµ‹è¯•è¿‡ç¨‹æ供直观解释。在 OpenMMLab 算法库ä¸ï¼Œæˆ‘们期望å¯è§†åŒ–功能的设计能满足以下需求: - -- æ供丰富的开箱å³ç”¨å¯è§†åŒ–功能,能够满足大部分计算机视觉å¯è§†åŒ–任务 -- 高扩展性,å¯è§†åŒ–åŠŸèƒ½é€šå¸¸å¤šæ ·åŒ–ï¼Œåº”è¯¥èƒ½å¤Ÿé€šè¿‡ç®€å•æ‰©å±•å®žçŽ°å®šåˆ¶éœ€æ±‚ -- 能够在è®ç»ƒå’Œæµ‹è¯•æµç¨‹çš„ä»»æ„点ä½è¿›è¡Œå¯è§†åŒ– -- OpenMMLab å„个算法库具有统一å¯è§†åŒ–接å£ï¼Œåˆ©äºŽç”¨æˆ·ç†è§£å’Œç»´æŠ¤ - -基于上述需求,OpenMMLab 2.0 引入了绘制对象 Visualizer 和写端对象 Writer 的概念 - -- **Visualizer è´Ÿè´£å•å¼ 图片的绘制功能** - - MMEngine æ供了以 Matplotlib 库为绘制åŽç«¯çš„ `Visualizer` 类,其具备如下功能: - - - æä¾›äº†ä¸€ç³»åˆ—å’Œè§†è§‰ä»»åŠ¡æ— å…³çš„åŸºç¡€æ–¹æ³•ï¼Œä¾‹å¦‚ `draw_bboxes` å’Œ `draw_texts` ç‰ - - å„个基础方法支æŒé“¾å¼è°ƒç”¨ï¼Œæ–¹ä¾¿å åŠ ç»˜åˆ¶æ˜¾ç¤º - - 通过 `draw_featmap` æ供绘制特å¾å›¾åŠŸèƒ½ - - å„个下游算法库å¯ä»¥ç»§æ‰¿ `Visualizer` 并在 `draw` 接å£ä¸å®žçŽ°æ‰€éœ€çš„å¯è§†åŒ–功能,例如 MMDetection ä¸çš„ `DetVisualizer` 继承自 `Visualizer` 并在 `draw` 接å£ä¸å®žçŽ°å¯è§†åŒ–检测框ã€å®žä¾‹æŽ©ç å’Œè¯ä¹‰åˆ†å‰²å›¾ç‰åŠŸèƒ½ã€‚Visualizer 类的 UML 关系图如下 - - <div align="center"> - <img src="https://user-images.githubusercontent.com/17425982/154475592-7208a34b-f6cb-4171-b0be-9dbb13306862.png" > - </div> - -- **Writer 负责将å„类数æ®å†™å…¥åˆ°æŒ‡å®šåŽç«¯** - - 为了统一接å£è°ƒç”¨ï¼ŒMMEngine æ供了统一的抽象类 `BaseWriter`,和一些常用的 Writer 如 `LocalWriter` æ¥æ”¯æŒå°†æ•°æ®å†™å…¥æœ¬åœ°ï¼Œ`TensorboardWriter` æ¥æ”¯æŒå°†æ•°æ®å†™å…¥ Tensorboard,`WandbWriter` æ¥æ”¯æŒå°†æ•°æ®å†™å…¥ Wandb。用户也å¯ä»¥è‡ªå®šä¹‰ Writer æ¥å°†æ•°æ®å†™å…¥è‡ªå®šä¹‰åŽç«¯ã€‚写入的数æ®å¯ä»¥æ˜¯å›¾ç‰‡ï¼Œæ¨¡åž‹ç»“æž„å›¾ï¼Œæ ‡é‡å¦‚æ¨¡åž‹ç²¾åº¦æŒ‡æ ‡ç‰ã€‚ - - 考虑到在è®ç»ƒæˆ–者测试过程ä¸å¯èƒ½åŒæ—¶å˜åœ¨å¤šä¸ª Writer 对象,例如åŒæ—¶æƒ³è¿›è¡Œæœ¬åœ°å’Œè¿œç¨‹ç«¯å†™æ•°æ®ï¼Œä¸ºæ¤è®¾è®¡äº† `ComposedWriter` 负责管ç†æ‰€æœ‰è¿è¡Œä¸å®žä¾‹åŒ–çš„ Writer 对象,其会自动管ç†æ‰€æœ‰ Writer 对象,并é历调用所有 Writer 对象的方法。Writer 类的 UML 关系图如下 - <div align="center"> - <img src="https://user-images.githubusercontent.com/17425982/157000633-9f552539-f722-44b1-b253-1abaf4a8eba6.png" > - </div> - -**(2) Writer å’Œ Visualizer 关系** - -Writer å¯¹è±¡çš„æ ¸å¿ƒåŠŸèƒ½æ˜¯å†™å„类数æ®åˆ°æŒ‡å®šåŽç«¯ä¸ï¼Œä¾‹å¦‚写图片ã€å†™æ¨¡åž‹å›¾ã€å†™è¶…å‚å’Œå†™æ¨¡åž‹ç²¾åº¦æŒ‡æ ‡ç‰ï¼ŒåŽç«¯å¯ä»¥æŒ‡å®šä¸ºæœ¬åœ°å˜å‚¨ã€Wandb å’Œ Tensorboard ç‰ç‰ã€‚在写图片过程ä¸ï¼Œé€šå¸¸å¸Œæœ›èƒ½å¤Ÿå°†é¢„æµ‹ç»“æžœæˆ–è€…æ ‡æ³¨ç»“æžœç»˜åˆ¶åˆ°å›¾ç‰‡ä¸Šï¼Œç„¶åŽå†è¿›è¡Œå†™æ“作,为æ¤åœ¨ Writer 内部维护了 Visualizer 对象,将 Visualizer 作为 Writer 的一个属性。需è¦æ³¨æ„的是: - -- åªæœ‰è°ƒç”¨äº† Writer ä¸çš„ `add_image` 写图片功能时候æ‰å¯èƒ½ä¼šç”¨åˆ° Visualizer 对象,其余接å£å’Œ Visualizer 没有关系 -- 考虑到æŸäº› Writer åŽç«¯æœ¬èº«å°±å…·å¤‡ç»˜åˆ¶åŠŸèƒ½ä¾‹å¦‚ `WandbWriter`,æ¤æ—¶ `WandbWriter` ä¸çš„ Visualizer 属性就是å¯é€‰çš„,如果用户在åˆå§‹åŒ–æ—¶å€™ä¼ å…¥äº† Visualizer 对象,则在 `add_image` 时候会调用 Visualizer 对象,å¦åˆ™ä¼šç›´æŽ¥è°ƒç”¨ Wandb 本身 API 进行图片绘制 -- `LocalWriter` å’Œ `TensorboardWriter` 由于绘制功能å•ä¸€ï¼Œç›®å‰å¼ºåˆ¶ç”± Visualizer 对象绘制,所以这两个 Writer å¿…é¡»ä¼ å…¥ Visualizer 或者å类对象 - -`WandbWriter` 的一个简略的演示代ç 如下 - -```python -# 为了方便ç†è§£ï¼Œæ²¡æœ‰ç»§æ‰¿ BaseWriter -class WandbWriter: - def __init__(self, visualizer=None): - self._visualizer = 
-
-
-## Drawing object: Visualizer
-
-The Visualizer object handles all kinds of drawing on a single image, with Matplotlib as the default drawing backend. To unify the visualization interface across OpenMMLab libraries, MMEngine defines the `Visualizer` class, which provides the basic drawing functions; downstream libraries can inherit `Visualizer` and implement the `draw` interface to meet their own drawing needs.
-
-### Visualizer
-
-`Visualizer` provides basic, generic drawing functions. Its main interfaces are:
-
-**(1) Functional interfaces unrelated to drawing**
-
-- [set_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.set_image) sets the original image data; the default input image format is RGB
-- [get_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_image) gets the drawn image as Numpy data; the default output format is RGB
-- [show](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.show) shows the visualization
-- [register_task](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.register_task) registers drawing functions (described in the *Custom Visualizer* section)
-
-**(2) Drawing interfaces**
-
-- [draw](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw) the abstract drawing interface used by users
-- [draw_featmap](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_featmap) draws feature maps
-- [draw_bboxes](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_bboxes) draws one or more bounding boxes
-- [draw_texts](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_texts) draws one or more text boxes
-- [draw_lines](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.lines) draws one or more line segments
-- [draw_circles](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_circles) draws one or more circles
-- [draw_polygons](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_polygons) draws one or more polygons
-- [draw_binary_masks](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_binary_mask) draws one or more binary masks
-
-Besides calling the basic drawing interfaces of `Visualizer` individually, users can also chain them and visualize feature maps. The `draw` function is an abstract interface with no implementation; classes that inherit Visualizer can implement it to provide a unified drawing entry point, while the `draw_xxx` methods provide the most basic drawing primitives, which users normally do not need to override.
-
-**(1) Chained calls**
-
-For example, to draw bounding boxes first and then draw texts and line segments on top of them:
-
-```python
-visualizer.set_image(image)
-visualizer.draw_bboxes(...).draw_texts(...).draw_lines(...)
-visualizer.show()  # visualize the drawing result
-```
-
-**(2) Visualizing feature maps**
-
-Feature map visualization is a common need. A feature map can be visualized directly by calling `draw_featmap`, whose parameters are defined as:
-
-```python
-@staticmethod
-def draw_featmap(tensor_chw: torch.Tensor,  # the input must be in CHW format
-                 image: Optional[np.ndarray] = None,  # if an image is also given, the feature map is overlaid on it
-                 mode: Optional[str] = 'mean',  # the strategy for compressing multiple channels into a single channel
-                 topk: int = 10,  # show the topk feature maps with the highest activations
-                 arrangement: Tuple[int, int] = (5, 2),  # the layout when multiple channels are shown as multiple images
-                 alpha: float = 0.3) -> np.ndarray:  # the blending ratio between the image and the feature map
-```
-
-Feature map visualization offers several options; batch input is currently not supported.
-
-- If mode is not None, topk is ignored: the channels are compressed into a single channel with the given mode function and shown as a single image. Currently mode only supports None, 'mean', 'max', and 'min'
-- If mode is None and topk is not -1, the topk channels with the highest activation are selected and shown, and the layout can be specified via the arrangement parameter
-- If mode is None and `topk = -1`, the number of channels C must be 1 or 3, meaning the input is an image that can be shown directly; otherwise an error is raised asking the user to set mode to compress the channels
-
-```python
-featmap = visualizer.draw_featmap(tensor_chw, image)
-```
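The calls below are an illustrative sketch based only on the signature and the rules described in the deleted section above; `visualizer` and `img` are assumed objects, and the exact behaviour belongs to the interface this PR removes.

```python
import torch

feat = torch.rand(256, 32, 32)   # a CHW feature map with 256 channels

# compress all channels into one map with the mean and overlay it on the image
overlay = visualizer.draw_featmap(feat, img, mode='mean', alpha=0.3)

# no compression: show the 10 most activated channels in a 5x2 grid
topk_grid = visualizer.draw_featmap(feat, mode=None, topk=10, arrangement=(5, 2))
```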
-
-### Custom Visualizer
-
-In most cases a custom Visualizer only needs to implement the `get_image` and `draw` interfaces. `draw` is the top-level user-facing interface and is responsible for all drawing, such as detection boxes, instance masks, and semantic segmentation maps. Depending on the task, the complexity of the `draw` implementation varies.
-
-Take object detection as an example: bounding boxes (bbox), masks, and semantic segmentation maps (seg_map) may need to be drawn at the same time. If all of this were written into the `draw` method, it would be hard to understand and maintain. To solve this, `Visualizer` supports the `register_task` function based on the OpenMMLab 2.0 abstract data interface conventions. Suppose MMDetection needs to draw both the instances and the sem_seg in the predictions: MMDetection's `DetVisualizer` can implement two methods, `draw_instances` and `draw_sem_seg`, which draw the predicted instances and the predicted semantic segmentation map respectively. We want these two functions to be called automatically whenever instances or sem_seg are present in the input data, without the user calling them manually. This is achieved by decorating `draw_instances` and `draw_sem_seg` with `@Visualizer.register_task`, which stores the string-to-function mapping in `task_dict`; inside the `draw` method, the registered functions can then be obtained through `self.task_dict`. A simplified implementation is shown below
-
-```python
-class DetVisualizer(Visualizer):
-
-    def draw(self, image, gt_sample=None, pred_sample=None, draw_gt=True, draw_pred=True):
-        # bind the image to the matplotlib layout
-        self.set_image(image)
-
-        if draw_gt:
-            # self.task_dict stores the following information:
-            # dict(instances=draw_instance method, sem_seg=draw_sem_seg method)
-            for task in self.task_dict:
-                task_attr = 'gt_' + task
-                if task_attr in gt_sample:
-                    self.task_dict[task](self, gt_sample[task_attr], 'gt')
-        if draw_pred:
-            for task in self.task_dict:
-                task_attr = 'pred_' + task
-                if task_attr in pred_sample:
-                    self.task_dict[task](self, pred_sample[task_attr], 'pred')
-
-    # data_type distinguishes whether the drawn content is annotation or prediction
-    @Visualizer.register_task('instances')
-    def draw_instance(self, instances, data_type):
-        ...
-
-    # data_type distinguishes whether the drawn content is annotation or prediction
-    @Visualizer.register_task('sem_seg')
-    def draw_sem_seg(self, pixel_data, data_type):
-        ...
-```
-
-Note: using the `register_task` decorator is not mandatory. If the custom Visualizer's `draw` implementation is very simple, there is no need to consider `register_task`.
-
-When there is no need to write data to a specific backend, e.g., in a Jupyter notebook, users can instantiate the visualizer by themselves. A simple example is as follows
-
-```python
-# instantiate the visualizer
-visualizer = dict(type='DetVisualizer')
-visualizer = VISUALIZERS.build(visualizer)
-visualizer.draw(image, datasample)
-visualizer.show()  # visualize the drawing result
-```
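For readers wondering how `register_task` can populate `task_dict`, one plausible implementation is a class-level decorator that records each drawing function under its task name. This is a hypothetical sketch for illustration only, not the actual MMEngine source.

```python
# Hypothetical sketch of the registration mechanism described above: the
# decorator stores the function in a class-level dict keyed by task name,
# which `draw()` later iterates over via `self.task_dict`.
class Visualizer:
    task_dict: dict = {}

    @classmethod
    def register_task(cls, task_name):
        def wrapper(func):
            cls.task_dict[task_name] = func   # e.g. {'instances': draw_instance}
            return func
        return wrapper
```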
-
-## Writer
-
-The Visualizer only draws on a single image, but during training or testing it is also important to record key metrics and training hyperparameters; this is handled by the Writer. To unify the calling interface, MMEngine provides the abstract base class `BaseWriter` and several common writers: `LocalWriter`, `TensorboardWriter`, and `WandbWriter`.
-
-### BaseWriter
-
-BaseWriter defines the external interface conventions. Its main interfaces and properties are:
-
-- [add_params](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_params) writes hyperparameters to the backend, such as the initial learning rate, weight decay, and batch size
-- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_graph) writes the model graph to the backend
-- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_image) writes an image to the backend
-- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalar) writes a scalar to the backend
-- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalars) writes several scalars to the backend at once
-- [visualizer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.visualizer) the drawing object
-- [experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.experiment) the backend object, e.g., the Wandb object or the Tensorboard object
-
-`BaseWriter` defines 5 common data-writing interfaces. Some backends are very powerful, e.g., Wandb can also write tables and videos; for such needs users can obtain the experiment object directly and call the backend's own API.
-
-### LocalWriter, TensorboardWriter, and WandbWriter
-
-`LocalWriter` writes data to the local disk. If users need to write images to disk, **a Visualizer object must be provided as an initialization argument**. A typical usage is:
-
-```python
-# config file
-writer = dict(type='LocalWriter', save_dir='demo_dir', visualizer=dict(type='DetVisualizer'))
-# instantiation and calls
-local_writer = WRITERS.build(writer)
-# write model accuracy values
-local_writer.add_scalar('mAP', 0.9)
-local_writer.add_scalars({'loss': 1.2, 'acc': 0.8})
-# write hyperparameters
-local_writer.add_params(dict(lr=0.1, mode='linear'))
-# write an image
-local_writer.add_image('demo_image', image, datasample)
-```
-
-If users have custom drawing needs, they can access the internal visualizer attribute, as shown below
-
-```python
-# config file
-writer = dict(type='LocalWriter', save_dir='demo_dir', visualizer=dict(type='DetVisualizer'))
-# instantiation and calls
-local_writer = WRITERS.build(writer)
-# write an image
-local_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]))
-local_writer.add_image('img', local_writer.visualizer.get_image())
-
-# draw a feature map and save it locally
-featmap_image = local_writer.visualizer.draw_featmap(tensor_chw)
-local_writer.add_image('featmap', featmap_image)
-```
-
-`TensorboardWriter` writes various data to Tensorboard; its usage is very similar to LocalWriter. Note that if users need to write images to Tensorboard, **a Visualizer object must be provided as an initialization argument**.
-
-`WandbWriter` writes various data to Wandb. Since Wandb itself has powerful image features, the Visualizer object is optional when calling `add_image` of `WandbWriter`: if a Visualizer object is specified, its drawing methods are called; otherwise Wandb's own image handling is used directly.
-
-## Composed writer: ComposedWriter
-
-Since multiple Writers may be needed at the same time during training or testing, e.g., writing both locally and to Wandb, the `ComposedWriter` class is designed as the external interface. During training or testing, `ComposedWriter` calls the interfaces of each Writer in turn. Its interfaces are consistent with `BaseWriter`; the main ones are:
-
-- [add_params](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_params) writes hyperparameters to all added backends, such as the initial learning rate, weight decay, and batch size
-- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_graph) writes the model graph to all added backends
-- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_image) writes an image to all added backends
-- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_scalar) writes a scalar to all added backends
-- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_scalars) writes several scalars to all added backends at once
-- [get_writer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_writer) gets the Writer at the given index; every Writer contains the experiment and visualizer attributes
-- [get_experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_experiment) gets the experiment at the given index
-- [get_visualizer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_visualizer) gets the visualizer at the given index
-- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.close) calls the close method of all Writers
-
-To let users visualize data anywhere in the code, `ComposedWriter` inherits from the [globally accessible base class BaseGlobalAccessible](./logging.md/#全局可访问基类baseglobalaccessible). Once it inherits from the globally accessible base class, users can obtain the global object by calling `get_instance` of `ComposedWriter`. Basic usage is as follows
-
-```python
-# create the instance
-writers = [dict(type='LocalWriter', save_dir='temp_dir', visualizer=dict(type='DetVisualizer')), dict(type='WandbWriter')]
-
-ComposedWriter.create_instance('composed_writer', writers=writers)
-```
-
-Once the instance is created, the `ComposedWriter` object can be obtained anywhere in the code
-
-```python
-composed_writer = ComposedWriter.get_instance('composed_writer')
-
-# write model accuracy values
-composed_writer.add_scalar('mAP', 0.9)
-composed_writer.add_scalars({'loss': 1.2, 'acc': 0.8})
-# write hyperparameters
-composed_writer.add_params(dict(lr=0.1, mode='linear'))
-# write an image
-composed_writer.add_image('demo_image', image, datasample)
-# write the model graph
-composed_writer.add_graph(model, input_array)
-```
-
-For custom drawing needs or needs that the above interfaces cannot cover, users can obtain the concrete objects through the `get_xxx` methods
-
-```python
-composed_writer = ComposedWriter.get_instance('composed_writer')
-
-# draw a feature map with the visualizer of the LocalWriter
-visualizer = composed_writer.get_visualizer(0)
-featmap_image = visualizer.draw_featmap(tensor_chw)
-composed_writer.add_image('featmap', featmap_image)
-
-# extend the add functions, e.g., draw a table with the Wandb object
-wandb = composed_writer.get_experiment(1)
-val_table = wandb.Table(data=my_data, columns=column_names)
-wandb.log({'my_val_table': val_table})
-
-# the config contains multiple Writers, but only the LocalWriter is used without changing the config
-local_writer = composed_writer.get_writer(0)
-local_writer.add_image('demo_image', image, datasample)
-```
diff --git a/mmengine/data/__init__.py b/mmengine/data/__init__.py
index 
b867465c3b522022a81868026e3d1f33fbc1fc9f..801b839b30e6c32c44a595f345cb5feac150d5d0 100644 --- a/mmengine/data/__init__.py +++ b/mmengine/data/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base_data_element import BaseDataElement +from .instance_data import InstanceData from .sampler import DefaultSampler, InfiniteSampler from .utils import pseudo_collate, worker_init_fn __all__ = [ 'BaseDataElement', 'DefaultSampler', 'InfiniteSampler', 'worker_init_fn', - 'pseudo_collate' + 'pseudo_collate', 'InstanceData' ] diff --git a/mmengine/data/instance_data.py b/mmengine/data/instance_data.py new file mode 100644 index 0000000000000000000000000000000000000000..2c4932f7e3fe6bd165aa8092a678cbf5695fcdf9 --- /dev/null +++ b/mmengine/data/instance_data.py @@ -0,0 +1,209 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import itertools +from typing import List, Union + +import numpy as np +import torch + +from .base_data_element import BaseDataElement + +IndexType = Union[str, slice, int, torch.LongTensor, torch.cuda.LongTensor, + torch.BoolTensor, torch.cuda.BoolTensor, np.long, np.bool] + + +# Modified from +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa +class InstanceData(BaseDataElement): + """Data structure for instance-level annnotations or predictions. + + Subclass of :class:`BaseDataElement`. All value in `data_fields` + should have the same length. This design refer to + https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 + + Examples: + >>> from mmengine.data import InstanceData + >>> import numpy as np + >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) + >>> instance_data = InstanceData(metainfo=img_meta) + >>> 'img_shape' in instance_data + True + >>> instance_data.det_labels = torch.LongTensor([2, 3]) + >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) + >>> instance_data.bboxes = torch.rand((2, 4)) + >>> len(instance_data) + 4 + >>> print(instance_data) + <InstanceData( + + META INFORMATION + pad_shape: (800, 1196, 3) + img_shape: (800, 1216, 3) + + DATA FIELDS + det_labels: tensor([2, 3]) + det_scores: tensor([0.8, 0.7000]) + bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], + [0.8101, 0.3105, 0.5123, 0.6263]]) + ) at 0x7fb492de6280> + >>> sorted_results = instance_data[instance_data.det_scores.sort().indices] + >>> sorted_results.det_scores + tensor([0.7000, 0.8000]) + >>> print(instance_data[instance_data.det_scores > 0.75]) + <InstanceData( + + META INFORMATION + pad_shape: (800, 1216, 3) + img_shape: (800, 1196, 3) + + DATA FIELDS + det_labels: tensor([0]) + bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]]) + det_scores: tensor([0.8000]) + ) at 0x7fb5cf6e2790> + >>> instance_data[instance_data.det_scores > 0.75].det_labels + tensor([0]) + >>> instance_data[instance_data.det_scores > 0.75].det_scores + tensor([0.8000]) + """ + + def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray, + list]): + if name in ('_metainfo_fields', '_data_fields'): + if not hasattr(self, name): + super().__setattr__(name, value) + else: + raise AttributeError( + f'{name} has been used as a ' + f'private attribute, which is immutable. 
') + + else: + assert isinstance(value, (torch.Tensor, np.ndarray, list)), \ + f'Can set {type(value)}, only support' \ + f' {(torch.Tensor, np.ndarray, list)}' + + if len(self) > 0: + assert len(value) == len(self), f'the length of ' \ + f'values {len(value)} is ' \ + f'not consistent with' \ + f' the length of this ' \ + f':obj:`InstanceData` ' \ + f'{len(self)} ' + super().__setattr__(name, value) + + def __getitem__(self, item: IndexType) -> 'InstanceData': + """ + Args: + item (str, obj:`slice`, + obj`torch.LongTensor`, obj:`torch.BoolTensor`): + get the corresponding values according to item. + + Returns: + obj:`InstanceData`: Corresponding values. + """ + assert len(self) > 0, ' This is a empty instance' + + assert isinstance( + item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor, + torch.BoolTensor, torch.cuda.BoolTensor, np.bool, np.long)) + + if isinstance(item, str): + return getattr(self, item) + + if type(item) == int: + if item >= len(self) or item < -len(self): # type:ignore + raise IndexError(f'Index {item} out of range!') + else: + # keep the dimension + item = slice(item, None, len(self)) + + new_data = self.new(data={}) + if isinstance(item, torch.Tensor): + assert item.dim() == 1, 'Only support to get the' \ + ' values along the first dimension.' + if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)): + assert len(item) == len(self), f'The shape of the' \ + f' input(BoolTensor)) ' \ + f'{len(item)} ' \ + f' does not match the shape ' \ + f'of the indexed tensor ' \ + f'in results_filed ' \ + f'{len(self)} at ' \ + f'first dimension. ' + + for k, v in self.items(): + if isinstance(v, torch.Tensor): + new_data[k] = v[item] + elif isinstance(v, np.ndarray): + new_data[k] = v[item.cpu().numpy()] + elif isinstance(v, list): + r_list = [] + # convert to indexes from boolTensor + if isinstance(item, + (torch.BoolTensor, torch.cuda.BoolTensor)): + indexes = torch.nonzero(item).view(-1) + else: + indexes = item + for index in indexes: + r_list.append(v[index]) + new_data[k] = r_list + else: + # item is a slice + for k, v in self.items(): + new_data[k] = v[item] + return new_data # type:ignore + + @staticmethod + def cat(instances_list: List['InstanceData']) -> 'InstanceData': + """Concat the instances of all :obj:`InstanceData` in the list. + + Note: To ensure that cat returns as expected, make sure that + all elements in the list must have exactly the same keys. + + Args: + instances_list (list[:obj:`InstanceData`]): A list + of :obj:`InstanceData`. + + Returns: + obj:`InstanceData` + """ + assert all( + isinstance(results, InstanceData) for results in instances_list) + assert len(instances_list) > 0 + if len(instances_list) == 1: + return instances_list[0] + + # metainfo and data_fields must be exactly the + # same for each element to avoid exceptions. + field_keys_list = [ + instances.all_keys() for instances in instances_list + ] + assert len(set([len(field_keys) for field_keys in field_keys_list])) \ + == 1 and len(set(itertools.chain(*field_keys_list))) \ + == len(field_keys_list[0]), 'There are different keys in ' \ + '`instances_list`, which may ' \ + 'cause the cat operation ' \ + 'to fail. 
Please make sure all ' \ + 'elements in `instances_list` ' \ + 'have the exact same key ' + + new_data = instances_list[0].new(data={}) + for k in instances_list[0].keys(): + values = [results[k] for results in instances_list] + v0 = values[0] + if isinstance(v0, torch.Tensor): + values = torch.cat(values, dim=0) + elif isinstance(v0, np.ndarray): + values = np.concatenate(values, axis=0) + elif isinstance(v0, list): + values = list(itertools.chain(*values)) + else: + raise ValueError( + f'Can not concat the {k} which is a {type(v0)}') + new_data[k] = values + return new_data # type:ignore + + def __len__(self) -> int: + if len(self._data_fields) > 0: + return len(self.values()[0]) + else: + return 0 diff --git a/mmengine/data/sampler.py b/mmengine/data/sampler.py index 47b2c3b4c1e3e5ae5a5f1941263f598c3a6af2f2..ff1d13ec9954e315fddd1005ce02488ccce50fa1 100644 --- a/mmengine/data/sampler.py +++ b/mmengine/data/sampler.py @@ -2,18 +2,13 @@ import itertools import math from typing import Iterator, Optional, Sized -# from mmengine.dist import get_dist_info, sync_random_seed -from unittest.mock import MagicMock import torch from torch.utils.data import Sampler +from mmengine.dist import get_dist_info, sync_random_seed from mmengine.registry import DATA_SAMPLERS -# TODO, need to remove those lines after implementing dist module -get_dist_info = MagicMock(return_value=(0, 1)) -sync_random_seed = MagicMock(return_value=0) - @DATA_SAMPLERS.register_module() class DefaultSampler(Sampler): diff --git a/mmengine/data/utils.py b/mmengine/data/utils.py index 0f569d39803a4f24a6564eb7301749a7ae5b2dc6..c284a336dd47f2cb8b56de89dd6637e763a20219 100644 --- a/mmengine/data/utils.py +++ b/mmengine/data/utils.py @@ -1,13 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import random -from typing import Any, Sequence, Tuple +from typing import Sequence import numpy as np import torch -from .base_data_element import BaseDataElement - -DATA_BATCH = Sequence[Tuple[Any, BaseDataElement]] +DATA_BATCH = Sequence[dict] def worker_init_fn(worker_id: int, num_workers: int, rank: int, @@ -36,10 +34,10 @@ def pseudo_collate(data_batch: DATA_BATCH) -> DATA_BATCH: nothing just returns ``data_batch``. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data from + data_batch (Sequence[dict]): Batch of data from dataloader. Returns: - Sequence[Tuple[Any, BaseDataElement]]: Return input ``data_batch``. + Sequence[dict]: Return input ``data_batch``. """ return data_batch diff --git a/mmengine/evaluator/evaluator.py b/mmengine/evaluator/evaluator.py index c653fb563aa90ce447ad0be964b14901b9aed657..34bb02e51a4953a31eb5d930de23ef67f339f7de 100644 --- a/mmengine/evaluator/evaluator.py +++ b/mmengine/evaluator/evaluator.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Iterator, List, Optional, Sequence, Union from mmengine.data import BaseDataElement from ..registry.root import METRICS @@ -37,23 +37,25 @@ class Evaluator: for metric in self.metrics: metric.dataset_meta = dataset_meta - def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], + def process(self, data_batch: Sequence[dict], predictions: Sequence[BaseDataElement]): """Convert ``BaseDataSample`` to dict and invoke process method of each metric. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data - from the dataloader. + data_batch (Sequence[dict]): A batch of data from the dataloader. 
predictions (Sequence[BaseDataElement]): A batch of outputs from the model. """ _data_batch = [] - for input, data in data_batch: - if isinstance(data, BaseDataElement): - _data_batch.append((input, data.to_dict())) + for data in data_batch: + if isinstance(data['data_sample'], BaseDataElement): + _data_batch.append( + dict( + inputs=data['inputs'], + data_sample=data['data_sample'].to_dict())) else: - _data_batch.append((input, data)) + _data_batch.append(data) _predictions = [] for pred in predictions: if isinstance(pred, BaseDataElement): diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py index e8a71488023a344792acfc22726c3b9582e9e276..4bcf163be41c7f625aacc879675742286a0b6e1d 100644 --- a/mmengine/evaluator/metric.py +++ b/mmengine/evaluator/metric.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings from abc import ABCMeta, abstractmethod -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Union from mmengine.dist import (broadcast_object_list, collect_results, is_main_process) @@ -50,15 +50,14 @@ class BaseMetric(metaclass=ABCMeta): self._dataset_meta = dataset_meta @abstractmethod - def process(self, data_batch: Sequence[Tuple[Any, dict]], + def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]) -> None: """Process one batch of data samples and predictions. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: - data_batch (Sequence[Tuple[Any, dict]]): A batch of data - from the dataloader. + data_batch (Sequence[dict]): A batch of data from the dataloader. predictions (Sequence[dict]): A batch of outputs from the model. """ diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index a373bbb5cad27e423c556ece9b045ac87bc3c3cb..017784b9a7cdfb49c81e1ae7a97840406ef59060 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -2,15 +2,14 @@ import os.path as osp import warnings from pathlib import Path -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union -from mmengine.data import BaseDataElement from mmengine.dist import master_only from mmengine.fileio import FileClient from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -185,8 +184,8 @@ class CheckpointHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index c793f01b1ea62189e92459cb9b742dc26a985fe2..be6b5c2c3dded63b45c39c95fa71ec4e8d994e9e 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union import torch @@ -7,7 +7,7 @@ from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -46,8 +46,8 @@ class EmptyCacheHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 84060334f1bcd6bd382b51f0e04cb89882728c59..49995334c6ea6fa78464bc543482a56e5db0e47c 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union from mmengine.data import BaseDataElement -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] class Hook: @@ -174,8 +174,8 @@ class Hook: Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. """ self._before_iter( runner, batch_idx=batch_idx, data_batch=data_batch, mode='train') @@ -190,8 +190,8 @@ class Hook: Args: runner (Runner): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. """ self._before_iter( runner, batch_idx=batch_idx, data_batch=data_batch, mode='val') @@ -206,8 +206,8 @@ class Hook: Args: runner (Runner): The runner of the testing process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. """ self._before_iter( runner, batch_idx=batch_idx, data_batch=data_batch, mode='test') @@ -223,8 +223,8 @@ class Hook: Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ @@ -247,8 +247,8 @@ class Hook: Args: runner (Runner): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. 
""" @@ -271,8 +271,8 @@ class Hook: Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ @@ -317,8 +317,8 @@ class Hook: runner (Runner): The runner of the training, validation or testing process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. """ pass @@ -337,8 +337,8 @@ class Hook: runner (Runner): The runner of the training, validation or testing process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): - Data from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (Sequence[BaseDataElement], optional): Outputs from model. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. @@ -387,19 +387,19 @@ class Hook: """ return (runner.iter + 1) % n == 0 if n > 0 else False - def end_of_epoch(self, runner, batch_idx: int) -> bool: + def end_of_epoch(self, dataloader, batch_idx: int) -> bool: """Check whether the current iteration reaches the last iteration of - current dataloader. + the dataloader. Args: - runner (Runner): The runner of the training, validation or testing - process. + dataloader (Dataloader): The dataloader of the training, + validation or testing process. batch_idx (int): The index of the current batch in the loop. Returns: bool: Whether reaches the end of current epoch or not. """ - return batch_idx + 1 == len(runner.cur_dataloader) + return batch_idx + 1 == len(dataloader) def is_last_train_epoch(self, runner) -> bool: """Test whether current epoch is the last train epoch. diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index d281745da4cc1de446f0897d45057dc53fe31cb2..8791dc96d86696a42eb0ec4773ca166dcf33f2a9 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import time -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -37,8 +37,8 @@ class IterTimerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. """ # TODO: update for new logging system @@ -57,8 +57,8 @@ class IterTimerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. 
Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index aed1d0e08e47863bcbd10d4beaf864f407964bfd..cd56624429d43eeaed52d4c78196ee3c52a6c9fa 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -5,18 +5,16 @@ import os import os.path as osp from collections import OrderedDict from pathlib import Path -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Union import torch -from mmengine.data import BaseDataElement -from mmengine.dist import master_only from mmengine.fileio import FileClient from mmengine.hooks import Hook from mmengine.registry import HOOKS from mmengine.utils import is_tuple_of, scandir -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -167,11 +165,7 @@ class LoggerHook(Hook): self.json_log_path = osp.join(runner.work_dir, f'{runner.timestamp}.log.json') - self.yaml_log_path = osp.join(runner.work_dir, - f'{runner.timestamp}.log.json') self.start_iter = runner.iter - if runner.meta is not None: - runner.writer.add_params(runner.meta, file_path=self.yaml_log_path) def after_train_iter(self, runner, @@ -183,15 +177,16 @@ class LoggerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[BaseDataElement], optional): Data from - dataloader. Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ self._inner_iter = batch_idx if runner.meta is not None and 'exp_name' in runner.meta: if (self.every_n_iters(runner, self.interval_exp_name)) or ( - self.by_epoch and self.end_of_epoch(runner, batch_idx)): + self.by_epoch and self.end_of_epoch( + runner.train_loop.dataloader, batch_idx)): exp_info = f'Exp name: {runner.meta["exp_name"]}' runner.logger.info(exp_info) if self.by_epoch and self.every_n_inner_iters(batch_idx, @@ -199,7 +194,8 @@ class LoggerHook(Hook): self._log_train(runner) elif not self.by_epoch and self.every_n_iters(runner, self.interval): self._log_train(runner) - elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last: + elif self.end_of_epoch(runner.train_loop.dataloader, + batch_idx) and not self.ignore_last: # `runner.max_iters` may not be divisible by `self.interval`. if # `self.ignore_last==True`, the log of remaining iterations will # be recorded (Epoch [4][1000/1007], the logs of 998-1007 @@ -238,7 +234,6 @@ class LoggerHook(Hook): runner.logger.info((f'{local_filepath} was removed due to the ' '`self.keep_local=False`')) - @master_only def _log_train(self, runner) -> None: """Collect and record training logs which start named with "train/*". @@ -271,9 +266,9 @@ class LoggerHook(Hook): # by iter: Iter [100/100000] if self.by_epoch: log_str = f'Epoch [{cur_epoch}]' \ - f'[{cur_iter}/{len(runner.cur_dataloader)}]\t' + f'[{cur_iter}/{len(runner.train_loop.dataloader)}] ' else: - log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t' + log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}] ' log_str += f'{lr_momentum_str}, ' # Calculate eta time. 
self.time_sec_tot += (tag['time'] * self.interval) @@ -299,10 +294,9 @@ class LoggerHook(Hook): log_str += ', '.join(log_items) runner.logger.info(log_str) # Write logs to local, tensorboad, and wandb. - runner.writer.add_scalars( + runner.visualizer.add_scalars( tag, step=runner.iter + 1, file_path=self.json_log_path) - @master_only def _log_val(self, runner) -> None: """Collect and record training logs which start named with "val/*". @@ -311,7 +305,7 @@ class LoggerHook(Hook): """ tag = self._collect_info(runner, 'val') # Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501 - eval_iter = len(runner.cur_dataloader) + eval_iter = len(runner.val_loop.dataloader) cur_iter = self._get_iter(runner) cur_epoch = self._get_epoch(runner, 'val') # val/test time @@ -320,9 +314,9 @@ class LoggerHook(Hook): # by iter: Iter[val] [1000] if self.by_epoch: # runner.epoch += 1 has been done before val workflow - log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}]\t' + log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}] ' else: - log_str = f'Iter(val) [{eval_iter}]\t' + log_str = f'Iter(val) [{eval_iter}] ' log_items = [] for name, val in tag.items(): @@ -332,7 +326,7 @@ class LoggerHook(Hook): log_str += ', '.join(log_items) runner.logger.info(log_str) # Write tag. - runner.writer.add_scalars( + runner.visualizer.add_scalars( tag, step=cur_iter, file_path=self.json_log_path) def _get_window_size(self, runner, window_size: Union[int, str]) \ diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py index 2e05fc5970dd8637a3487d99a984085a07077537..2819563a500274d83ffc5effa22666852628b3d5 100644 --- a/mmengine/hooks/naive_visualization_hook.py +++ b/mmengine/hooks/naive_visualization_hook.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from typing import Any, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple import cv2 import numpy as np @@ -11,6 +11,8 @@ from mmengine.registry import HOOKS from mmengine.utils.misc import tensor2imgs +# TODO: Due to interface changes, the current class +# functions incorrectly @HOOKS.register_module() class NaiveVisualizationHook(Hook): """Show or Write the predicted results during the process of testing. @@ -41,26 +43,25 @@ class NaiveVisualizationHook(Hook): self, runner, batch_idx: int, - data_batch: Optional[Sequence[Tuple[Any, BaseDataElement]]] = None, + data_batch: Optional[Sequence[dict]] = None, outputs: Optional[Sequence[BaseDataElement]] = None) -> None: """Show or Write the predicted results. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data + data_batch (Sequence[dict], optional): Data from dataloader. Defaults to None. outputs (Sequence[BaseDataElement], optional): Outputs from model. Defaults to None. 
""" if self.every_n_iters(runner, self._interval): - inputs, data_samples = data_batch # type: ignore - inputs = tensor2imgs(inputs, - **data_samples[0].get('img_norm_cfg', dict())) - for input, data_sample, output in zip( - inputs, - data_samples, # type: ignore - outputs): # type: ignore + for data, output in zip(data_batch, outputs): # type: ignore + input = data['inputs'] + data_sample = data['data_sample'] + input = tensor2imgs(input, + **data_sample.get('img_norm_cfg', + dict()))[0] # TODO We will implement a function to revert the augmentation # in the future. ori_shape = (data_sample.ori_width, data_sample.ori_height) @@ -69,5 +70,6 @@ class NaiveVisualizationHook(Hook): data_sample.get('scale', ori_shape)) origin_image = cv2.resize(input, ori_shape) name = osp.basename(data_sample.img_path) - runner.writer.add_image(name, origin_image, data_sample, - output, self.draw_gt, self.draw_pred) + runner.visualizer.add_datasample(name, origin_image, + data_sample, output, + self.draw_gt, self.draw_pred) diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index ff33b54a7de343bb2681398d2ca661a5e59f55c5..9107dbf02500e24d471271fd99a7fc1b29ad12fe 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -1,16 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging -from typing import Any, List, Optional, Sequence, Tuple +from typing import List, Optional, Sequence import torch from torch.nn.parameter import Parameter from torch.nn.utils import clip_grad -from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -77,10 +76,9 @@ class OptimizerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. In order to keep this interface consistent - with other hooks, we keep ``data_batch`` here. - Defaults to None. + data_batch (Sequence[dict], optional): Data from dataloader. + In order to keep this interface consistent with other hooks, + we keep ``data_batch`` here. Defaults to None. outputs (dict, optional): Outputs from model. In order to keep this interface consistent with other hooks, we keep ``outputs`` here. Defaults to None. diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index 9522abcf8ae4321d16d6bed102d101c31449677e..c4e7af58df9b38978eedb4b45dc0b803efc0a964 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -1,11 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Optional, Sequence, Tuple +from typing import Optional, Sequence -from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] +DATA_BATCH = Optional[Sequence[dict]] @HOOKS.register_module() @@ -25,10 +24,9 @@ class ParamSchedulerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data - from dataloader. In order to keep this interface consistent - with other hooks, we keep ``data_batch`` here. - Defaults to None. 
+ data_batch (Sequence[dict], optional): Data from dataloader. + In order to keep this interface consistent with other hooks, + we keep ``data_batch`` here. Defaults to None. outputs (dict, optional): Outputs from model. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py index eed3fa90d46f5abd98bec0bb2c60978c82795855..9ddfd7ab7aa8d657b327a336ac1c91a18ca1c82b 100644 --- a/mmengine/hooks/sampler_seed_hook.py +++ b/mmengine/hooks/sampler_seed_hook.py @@ -20,9 +20,17 @@ class DistSamplerSeedHook(Hook): Args: runner (Runner): The runner of the training process. """ - if hasattr(runner.cur_dataloader.sampler, 'set_epoch'): - # in case the data loader uses `SequentialSampler` in Pytorch - runner.cur_dataloader.sampler.set_epoch(runner.epoch) - elif hasattr(runner.cur_dataloader.batch_sampler.sampler, 'set_epoch'): + if hasattr(runner.train_loop.dataloader, 'sampler') and hasattr( + runner.train_loop.dataloader.sampler, 'set_epoch'): + # In case the` _SingleProcessDataLoaderIter` has no sampler, + # or data loader uses `SequentialSampler` in Pytorch. + runner.train_loop.dataloader.sampler.set_epoch(runner.epoch) + + elif hasattr(runner.train_loop.dataloader, + 'batch_sampler') and hasattr( + runner.train_loop.dataloader.batch_sampler.sampler, + 'set_epoch'): + # In case the` _SingleProcessDataLoaderIter` has no batch sampler. # batch sampler in pytorch warps the sampler as its attributes. - runner.cur_dataloader.batch_sampler.sampler.set_epoch(runner.epoch) + runner.train_loop.dataloader.batch_sampler.sampler.set_epoch( + runner.epoch) diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index f6f393a0c32b424a9c5fee406e2a7f337e46a074..f399883eca047e668768e8714709e00ee2cbe583 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy from collections import OrderedDict from typing import Any, Optional, Union @@ -229,7 +228,8 @@ class MessageHub(ManagerMixin): Returns: OrderedDict: A copy of all runtime information. """ - return copy.deepcopy(self._runtime_info) + # return copy.deepcopy(self._runtime_info) + return self._runtime_info def get_scalar(self, key: str) -> HistoryBuffer: """Get ``HistoryBuffer`` instance by key. 
@@ -263,7 +263,10 @@ class MessageHub(ManagerMixin): if key not in self.runtime_info: raise KeyError(f'{key} is not found in Messagehub.log_buffers: ' f'instance name is: {MessageHub.instance_name}') - return copy.deepcopy(self._runtime_info[key]) + + # TODO: There are restrictions on objects that can be saved + # return copy.deepcopy(self._runtime_info[key]) + return self._runtime_info[key] def _get_valid_value(self, key: str, value: Union[torch.Tensor, np.ndarray, int, float]) \ diff --git a/mmengine/model/wrappers/data_parallel.py b/mmengine/model/wrappers/data_parallel.py index c2967ceae536a539ddef80810272f9381be80664..d31b009c3352baa15028d594b4114170563f8a83 100644 --- a/mmengine/model/wrappers/data_parallel.py +++ b/mmengine/model/wrappers/data_parallel.py @@ -9,6 +9,9 @@ from torch.nn.parallel.distributed import (DistributedDataParallel, from mmengine.registry import MODEL_WRAPPERS from mmengine.utils import TORCH_VERSION, digit_version +MODEL_WRAPPERS.register_module(module=DataParallel) +MODEL_WRAPPERS.register_module(module=DistributedDataParallel) + @MODEL_WRAPPERS.register_module() class MMDataParallel(DataParallel): diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index ead8cb0afd7ba9e4a4800006fa188b866c7ceb8f..56c65b80628eeae6d2a823205c0b1f1c1c45dae0 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -4,12 +4,12 @@ from .registry import Registry, build_from_cfg from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, - TRANSFORMS, VISUALIZERS, WEIGHT_INITIALIZERS, WRITERS) + TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS) __all__ = [ 'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', - 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS', + 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS', 'DefaultScope' ] diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 3ee7d4d62a367f260a53c51470c9f9c72f7e14f8..bb0640456e30815414067034af2a86d6c5361f90 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -6,7 +6,7 @@ from collections.abc import Callable from typing import Any, Dict, List, Optional, Tuple, Type, Union from ..config import Config, ConfigDict -from ..utils import is_seq_of +from ..utils import ManagerMixin, is_seq_of from .default_scope import DefaultScope @@ -88,7 +88,13 @@ def build_from_cfg( f'type must be a str or valid type, but got {type(obj_type)}') try: - return obj_cls(**args) # type: ignore + # If `obj_cls` inherits from `ManagerMixin`, it should be instantiated + # by `ManagerMixin.get_instance` to ensure that it can be accessed + # globally. + if issubclass(obj_cls, ManagerMixin): # type: ignore + return obj_cls.get_instance(**args) # type: ignore + else: + return obj_cls(**args) # type: ignore except Exception as e: # Normal TypeError does not print class name. 
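The `ManagerMixin` branch added to `build_from_cfg` above means that any registered class inheriting `ManagerMixin` is instantiated through `get_instance`, so the object built from a config stays globally retrievable by name. A minimal sketch of that behaviour, using `MessageHub` purely as an illustrative `ManagerMixin` subclass:

```python
# Minimal sketch of the get_instance path that build_from_cfg now takes for
# ManagerMixin subclasses. MessageHub is only an illustrative example; any
# ManagerMixin-derived class behaves the same way.
from mmengine.logging import MessageHub

hub = MessageHub.get_instance('demo_hub')           # what the registry would create
assert MessageHub.get_instance('demo_hub') is hub   # retrievable anywhere by name
```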
raise type(e)(f'{obj_cls.__name__}: {e}') # type: ignore diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index 571d55cbbc9454319945d465e3799edab1921c95..62d72f705039a6d9f04da79d3422a3089e27b67e 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -43,5 +43,5 @@ TASK_UTILS = Registry('task util') # manage visualizer VISUALIZERS = Registry('visualizer') -# manage writer -WRITERS = Registry('writer') +# manage visualizer backend +VISBACKENDS = Registry('vis_backend') diff --git a/mmengine/runner/base_loop.py b/mmengine/runner/base_loop.py index 0a0e3ca772415dd3e9f61ee13c4bbde523be4281..ec3f880dd42b539ecf484f55152525154292bb5c 100644 --- a/mmengine/runner/base_loop.py +++ b/mmengine/runner/base_loop.py @@ -25,9 +25,6 @@ class BaseLoop(metaclass=ABCMeta): else: self.dataloader = dataloader - # TODO, used by `end_of_epoch` of `Hook` - self._runner.data_loader = self.dataloader - @property def runner(self): return self._runner diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 4de30628915878a6eae17f433dd170ab7e409dbf..116560a9460be6a72d8cfff87e20d621b63f96a2 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, List, Sequence, Tuple, Union +import warnings +from typing import Dict, List, Sequence, Union import torch from torch.utils.data import DataLoader -from mmengine.data import BaseDataElement from mmengine.evaluator import Evaluator from mmengine.registry import LOOPS from mmengine.utils import is_list_of @@ -27,6 +27,14 @@ class EpochBasedTrainLoop(BaseLoop): super().__init__(runner, dataloader) self._max_epochs = max_epochs self._max_iters = max_epochs * len(self.dataloader) + if hasattr(self.dataloader.dataset, 'metainfo'): + self.runner.visualizer.dataset_meta = \ + self.dataloader.dataset.metainfo + else: + warnings.warn( + f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' + 'metainfo. ``dataset_meta`` in visualizer will be ' + 'None.') @property def max_epochs(self): @@ -40,7 +48,6 @@ class EpochBasedTrainLoop(BaseLoop): def run(self) -> None: """Launch training.""" - self.runner.cur_dataloader = self.dataloader self.runner.call_hook('before_train') while self.runner._epoch < self._max_epochs: @@ -62,13 +69,11 @@ class EpochBasedTrainLoop(BaseLoop): self.runner.call_hook('after_train_epoch') self.runner.epoch += 1 - def run_iter(self, idx, - data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None: + def run_iter(self, idx, data_batch: Sequence[dict]) -> None: """Iterate one min-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data - from dataloader. + data_batch (Sequence[dict]): Batch of data from dataloader. """ self.runner.call_hook( 'before_train_iter', batch_idx=idx, data_batch=data_batch) @@ -103,6 +108,14 @@ class IterBasedTrainLoop(BaseLoop): max_iters: int) -> None: super().__init__(runner, dataloader) self._max_iters = max_iters + if hasattr(self.dataloader.dataset, 'metainfo'): + self.runner.visualizer.dataset_meta = \ + self.dataloader.dataset.metainfo + else: + warnings.warn( + f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' + 'metainfo. 
``dataset_meta`` in visualizer will be ' + 'None.') self.dataloader = iter(self.dataloader) @property @@ -112,7 +125,6 @@ class IterBasedTrainLoop(BaseLoop): def run(self) -> None: """Launch training.""" - self.runner.cur_dataloader = self.dataloader self.runner.call_hook('before_train') # In iteration-based training loop, we treat the whole training process # as a big epoch and execute the corresponding hook. @@ -130,13 +142,11 @@ class IterBasedTrainLoop(BaseLoop): self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train') - def run_iter(self, data_batch: Sequence[Tuple[Any, - BaseDataElement]]) -> None: + def run_iter(self, data_batch: Sequence[dict]) -> None: """Iterate one mini-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data - from dataloader. + data_batch (Sequence[dict]): Batch of data from dataloader. """ self.runner.call_hook( 'before_train_iter', @@ -180,12 +190,19 @@ class ValLoop(BaseLoop): self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator # type: ignore - + if hasattr(self.dataloader.dataset, 'metainfo'): + self.evaluator.dataset_meta = self.dataloader.dataset.metainfo + self.runner.visualizer.dataset_meta = \ + self.dataloader.dataset.metainfo + else: + warnings.warn( + f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' + 'metainfo. ``dataset_meta`` in evaluator, metric and ' + 'visualizer will be None.') self.interval = interval def run(self): """Launch validation.""" - self.runner.cur_dataloader = self.dataloader self.runner.call_hook('before_val') self.runner.call_hook('before_val_epoch') self.runner.model.eval() @@ -201,11 +218,11 @@ class ValLoop(BaseLoop): self.runner.call_hook('after_val') @torch.no_grad() - def run_iter(self, idx, data_batch: Sequence[Tuple[Any, BaseDataElement]]): + def run_iter(self, idx, data_batch: Sequence[dict]): """Iterate one mini-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data + data_batch (Sequence[dict]): Batch of data from dataloader. """ self.runner.call_hook( @@ -239,10 +256,18 @@ class TestLoop(BaseLoop): self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator # type: ignore + if hasattr(self.dataloader.dataset, 'metainfo'): + self.evaluator.dataset_meta = self.dataloader.dataset.metainfo + self.runner.visualizer.dataset_meta = \ + self.dataloader.dataset.metainfo + else: + warnings.warn( + f'Dataset {self.dataloader.dataset.__class__.__name__} has no ' + 'metainfo. ``dataset_meta`` in evaluator, metric and ' + 'visualizer will be None.') def run(self) -> None: """Launch test.""" - self.runner.cur_dataloader = self.dataloader self.runner.call_hook('before_test') self.runner.call_hook('before_test_epoch') self.runner.model.eval() @@ -258,13 +283,11 @@ class TestLoop(BaseLoop): self.runner.call_hook('after_test') @torch.no_grad() - def run_iter(self, idx, - data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None: + def run_iter(self, idx, data_batch: Sequence[dict]) -> None: """Iterate one mini-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data - from dataloader. + data_batch (Sequence[dict]): Batch of data from dataloader. 
""" self.runner.call_hook( 'before_test_iter', batch_idx=idx, data_batch=data_batch) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index a1b11511791a85d91285b0fe961f437a29e036af..7e8990be4c8242c6e6edd4714e880a7bb5a35377 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -30,10 +30,10 @@ from mmengine.model import is_model_wrapper from mmengine.optim import _ParamScheduler, build_optimizer from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, - DefaultScope) + VISUALIZERS, DefaultScope) from mmengine.utils import (TORCH_VERSION, digit_version, find_latest_checkpoint, is_list_of, symlink) -from mmengine.visualization import ComposedWriter +from mmengine.visualization import Visualizer from .base_loop import BaseLoop from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, get_state_dict, save_checkpoint, weights_to_cpu) @@ -129,8 +129,8 @@ class Runner: dict(dist_cfg=dict(backend='nccl')). log_level (int or str): The log level of MMLogger handlers. Defaults to 'INFO'. - writer (ComposedWriter or dict, optional): A ComposedWriter object or a - dict build ComposedWriter object. Defaults to None. If not + visualizer (Visualizer or dict, optional): A Visualizer object or a + dict build Visualizer object. Defaults to None. If not specified, default config will be used. default_scope (str, optional): Used to reset registries location. Defaults to None. @@ -184,9 +184,9 @@ class Runner: param_scheduler=dict(type='ParamSchedulerHook')), launcher='none', env_cfg=dict(dist_cfg=dict(backend='nccl')), - writer=dict( - name='composed_writer', - writers=[dict(type='LocalWriter', save_dir='temp_dir')]) + visualizer=dict(type='Visualizer', + vis_backends=[dict(type='LocalVisBackend', + save_dir='temp_dir')]) ) >>> runner = Runner.from_cfg(cfg) >>> runner.train() @@ -218,7 +218,7 @@ class Runner: launcher: str = 'none', env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')), log_level: str = 'INFO', - writer: Optional[Union[ComposedWriter, Dict]] = None, + visualizer: Optional[Union[Visualizer, Dict]] = None, default_scope: Optional[str] = None, randomness: Dict = dict(seed=None), experiment_name: Optional[str] = None, @@ -310,7 +310,13 @@ class Runner: else: self._experiment_name = self.timestamp + # Used to reset registries location. See :meth:`Registry.build` for + # more details. + self.default_scope = DefaultScope.get_instance( + self._experiment_name, scope_name=default_scope) + self.logger = self.build_logger(log_level=log_level) + # Build `message_hub` for communication among components. # `message_hub` can store log scalars (loss, learning rate) and # runtime information (iter and epoch). Those components that do not @@ -321,12 +327,8 @@ class Runner: # current epoch by `cur_epoch = self.message_hub.get_info('epoch')`. # See `MessageHub` and `ManagerMixin` for more details. self.message_hub = self.build_message_hub() - # writer used for writing log or visualizing all kinds of data - self.writer = self.build_writer(writer) - # Used to reset registries location. See :meth:`Registry.build` for - # more details. 
- self.default_scope = DefaultScope.get_instance( - self._experiment_name, scope_name=default_scope) + # visualizer used for writing log or visualizing all kinds of data + self.visualizer = self.build_visualizer(visualizer) self._load_from = load_from self._resume = resume @@ -386,7 +388,7 @@ class Runner: launcher=cfg.get('launcher', 'none'), env_cfg=cfg.get('env_cfg'), # type: ignore log_level=cfg.get('log_level', 'INFO'), - writer=cfg.get('writer'), + visualizer=cfg.get('visualizer'), default_scope=cfg.get('default_scope'), randomness=cfg.get('randomness', dict(seed=None)), experiment_name=cfg.get('experiment_name'), @@ -646,41 +648,50 @@ class Runner: return MessageHub.get_instance(**message_hub) - def build_writer( - self, - writer: Optional[Union[ComposedWriter, - Dict]] = None) -> ComposedWriter: - """Build a global asscessable ComposedWriter. + def build_visualizer( + self, + visualizer: Optional[Union[Visualizer, + Dict]] = None) -> Visualizer: + """Build a global asscessable Visualizer. Args: - writer (ComposedWriter or dict, optional): A ComposedWriter object - or a dict to build ComposedWriter object. If ``writer`` is a - ComposedWriter object, just returns itself. If not specified, - default config will be used to build ComposedWriter object. + visualizer (Visualizer or dict, optional): A Visualizer object + or a dict to build Visualizer object. If ``visualizer`` is a + Visualizer object, just returns itself. If not specified, + default config will be used to build Visualizer object. Defaults to None. Returns: - ComposedWriter: A ComposedWriter object build from ``writer``. + Visualizer: A Visualizer object build from ``visualizer``. """ - if isinstance(writer, ComposedWriter): - return writer - elif writer is None: - writer = dict( + if visualizer is None: + visualizer = dict( name=self._experiment_name, - writers=[dict(type='LocalWriter', save_dir=self._work_dir)]) - elif isinstance(writer, dict): - # ensure writer containing name key - writer.setdefault('name', self._experiment_name) + vis_backends=[ + dict(type='LocalVisBackend', save_dir=self._work_dir) + ]) + return Visualizer.get_instance(**visualizer) + + if isinstance(visualizer, Visualizer): + return visualizer + + if isinstance(visualizer, dict): + # ensure visualizer containing name key + visualizer.setdefault('name', self._experiment_name) + visualizer.setdefault('save_dir', self._work_dir) + return VISUALIZERS.build(visualizer) else: raise TypeError( - 'writer should be ComposedWriter object, a dict or None, ' - f'but got {writer}') - - return ComposedWriter.get_instance(**writer) + 'visualizer should be Visualizer object, a dict or None, ' + f'but got {visualizer}') def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module: """Build model. + If ``model`` is a dict, it will be used to build a nn.Module object + and initialize the weights if it has ``init_weights`` method. + Else, if ``model`` is a nn.Module object it will be returned directly. 
+ An example of ``model``:: model = dict(type='ResNet') @@ -696,7 +707,11 @@ class Runner: if isinstance(model, nn.Module): return model elif isinstance(model, dict): - return MODELS.build(model) + model = MODELS.build(model) + # init weights + if hasattr(model, 'init_weights'): + model.init_weights() + return model else: raise TypeError('model should be a nn.Module object or dict, ' f'but got {model}') @@ -906,6 +921,21 @@ class Runner: # if `sampler_cfg` is not a valid type sampler = sampler_cfg + # build batch sampler + batch_sampler_cfg = dataloader_cfg.pop('batch_sampler', None) + if batch_sampler_cfg is None: + batch_sampler = None + elif isinstance(batch_sampler_cfg, dict): + batch_sampler = DATA_SAMPLERS.build( + batch_sampler_cfg, + default_args=dict( + sampler=sampler, + batch_size=dataloader_cfg.pop('batch_size'))) + else: + # fallback to raise error in dataloader + # if `batch_sampler_cfg` is not a valid type + batch_sampler = batch_sampler_cfg + # build dataloader init_fn: Optional[partial] if self.seed is not None: @@ -924,8 +954,8 @@ class Runner: # in model. data_loader = DataLoader( dataset=dataset, - sampler=sampler, - batch_sampler=None, + sampler=sampler if batch_sampler is None else None, + batch_sampler=batch_sampler, collate_fn=pseudo_collate, worker_init_fn=init_fn, **dataloader_cfg) @@ -1084,7 +1114,13 @@ class Runner: # decide to load from checkpoint or resume from checkpoint resume_from = None if self._resume and self._load_from is None: + # auto resume from the latest checkpoint resume_from = find_latest_checkpoint(self.work_dir) + self.logger.info( + f'Auto resumed from the latest checkpoint {resume_from}.') + elif self._resume and self._load_from is not None: + # resume from the specified checkpoint + resume_from = self._load_from if resume_from is not None: self.resume(resume_from) @@ -1230,6 +1266,8 @@ class Runner: +----------------------+-------------------------+ | IterTimerHook | NORMAL (40) | +----------------------+-------------------------+ + | DistSamplerSeedHook | NORMAL (40) | + +----------------------+-------------------------+ | LoggerHook | BELOW_NORMAL (60) | +----------------------+-------------------------+ | ParamSchedulerHook | LOW (70) | @@ -1243,6 +1281,7 @@ class Runner: default_hooks = dict( optimizer=dict(type='OptimizerHook', grad_clip=None), timer=dict(type='IterTimerHook'), + sampler_seed=dict(type='DistSamplerSeedHook'), logger=dict(type='LoggerHook'), param_scheduler=dict(type='ParamSchedulerHook'), checkpoint=dict(type='CheckpointHook', interval=1), @@ -1267,6 +1306,7 @@ class Runner: logger=dict(type='LoggerHook'), param_scheduler=dict(type='ParamSchedulerHook'), checkpoint=dict(type='CheckpointHook', interval=1), + sampler_seed=dict(type='DistSamplerSeedHook'), ) if hooks is not None: for name, hook in hooks.items(): @@ -1408,8 +1448,13 @@ class Runner: # Add comments to describe the usage of `after_load_ckpt` self.call_hook('after_load_ckpt', checkpoint=checkpoint) + if is_model_wrapper(self.model): + model = self.model.module + else: + model = self.model + checkpoint = _load_checkpoint_to_model( - self.model, checkpoint, strict, revise_keys=revise_keys) + model, checkpoint, strict, revise_keys=revise_keys) self._has_loaded = True diff --git a/mmengine/visualization/__init__.py b/mmengine/visualization/__init__.py index 892c3daae108284340c73b0b641ae91ca5a92f9c..6c8b0bb5464fcb96b0b865a2328b088d68190ab3 100644 --- a/mmengine/visualization/__init__.py +++ b/mmengine/visualization/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) 
OpenMMLab. All rights reserved. +from .vis_backend import (BaseVisBackend, LocalVisBackend, + TensorboardVisBackend, WandbVisBackend) from .visualizer import Visualizer -from .writer import (BaseWriter, ComposedWriter, LocalWriter, - TensorboardWriter, WandbWriter) __all__ = [ - 'Visualizer', 'BaseWriter', 'LocalWriter', 'WandbWriter', - 'TensorboardWriter', 'ComposedWriter' + 'Visualizer', 'BaseVisBackend', 'LocalVisBackend', 'WandbVisBackend', + 'TensorboardVisBackend' ] diff --git a/mmengine/visualization/utils.py b/mmengine/visualization/utils.py index 97803ce2230f4bf13670025cecca05ea07014f1a..a0033dac151a623620d9d866fe299f013dbb6dce 100644 --- a/mmengine/visualization/utils.py +++ b/mmengine/visualization/utils.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, List, Tuple, Type, Union +from typing import Any, List, Optional, Tuple, Type, Union + +import cv2 +import matplotlib import numpy as np import torch @@ -84,3 +87,60 @@ def check_type_and_length(name: str, value: Any, """ check_type(name, value, valid_type) check_length(name, value, valid_length) + + +def color_val_matplotlib(colors): + """Convert various input in RGB order to normalized RGB matplotlib color + tuples, + Args: + color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs + Returns: + tuple[float]: A tuple of 3 normalized floats indicating RGB channels. + """ + if isinstance(colors, str): + return colors + elif isinstance(colors, tuple): + assert len(colors) == 3 + for channel in colors: + assert 0 <= channel <= 255 + colors = [channel / 255 for channel in colors] + return tuple(colors) + elif isinstance(colors, list): + colors = [color_val_matplotlib(color) for color in colors] + return colors + else: + raise TypeError(f'Invalid type for color: {type(colors)}') + + +def str_color_to_rgb(color): + color = matplotlib.colors.to_rgb(color) + color = tuple([int(c * 255) for c in color]) + return color + + +def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor], + img: Optional[np.ndarray] = None, + alpha: float = 0.5) -> np.ndarray: + """Convert feat_map to heatmap and overlay on image, if image is not None. + + Args: + feat_map (np.ndarray, torch.Tensor): The feat_map to convert + with of shape (H, W), where H is the image height and W is + the image width. + img (np.ndarray, optional): The origin image. The format + should be RGB. Defaults to None. + alpha (float): The transparency of origin image. Defaults to 0.5. + + Returns: + np.ndarray: heatmap + """ + if isinstance(feat_map, torch.Tensor): + feat_map = feat_map.detach().cpu().numpy() + norm_img = np.zeros(feat_map.shape) + norm_img = cv2.normalize(feat_map, norm_img, 0, 255, cv2.NORM_MINMAX) + norm_img = np.asarray(norm_img, dtype=np.uint8) + heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET) + heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB) + if img is not None: + heat_img = cv2.addWeighted(img, alpha, heat_img, 1 - alpha, 0) + return heat_img diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..13de36d80ceb762cd222f92cb0a6cbc1f790a51a --- /dev/null +++ b/mmengine/visualization/vis_backend.py @@ -0,0 +1,494 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
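Since the `WRITERS` registry is renamed to `VISBACKENDS` earlier in this patch, a downstream backend would now register itself as in the sketch below. `NullVisBackend` is a hypothetical class used only to illustrate the pattern; `experiment` is the sole abstract member of `BaseVisBackend` as defined in this file, so implementing it is enough to make the class instantiable:

```python
# Hypothetical example of registering a custom backend with the renamed
# VISBACKENDS registry. Only `experiment` is abstract in BaseVisBackend, so a
# do-nothing backend needs nothing more than this.
from mmengine.registry import VISBACKENDS
from mmengine.visualization import BaseVisBackend


@VISBACKENDS.register_module()
class NullVisBackend(BaseVisBackend):
    """A backend that accepts records and silently discards them."""

    @property
    def experiment(self) -> 'NullVisBackend':
        return self
```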
+import os +import os.path as osp +import time +from abc import ABCMeta, abstractmethod +from typing import Any, Optional, Sequence, Union + +import cv2 +import numpy as np +import torch + +from mmengine.config import Config +from mmengine.fileio import dump +from mmengine.registry import VISBACKENDS +from mmengine.utils import TORCH_VERSION + + +class BaseVisBackend(metaclass=ABCMeta): + """Base class for vis backend. + + All backends must inherit ``BaseVisBackend`` and implement + the required functions. + + Args: + save_dir (str, optional): The root directory to save + the files produced by the backend. Default to None. + """ + + def __init__(self, save_dir: Optional[str] = None): + self._save_dir = save_dir + if self._save_dir: + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + self._save_dir = osp.join(self._save_dir, + f'vis_data_{timestamp}') # type: ignore + + @property + @abstractmethod + def experiment(self) -> Any: + """Return the experiment object associated with this writer. + + The experiment attribute can get the visualizer backend, such as wandb, + tensorboard. If you want to write other data, such as writing a table, + you can directly get the visualizer backend through experiment. + """ + pass + + def add_config(self, config: Config, **kwargs) -> None: + """Record a set of parameters. + + Args: + config (Config): The Config object + """ + pass + + def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], + **kwargs) -> None: + """Record graph. + + Args: + model (torch.nn.Module): Model to draw. + data_batch (Sequence[dict]): Batch of data from dataloader. + """ + pass + + def add_image(self, + name: str, + image: np.ndarray, + step: int = 0, + **kwargs) -> None: + """Record image. + + Args: + name (str): The unique identifier for the image to save. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Default to None. + step (int): Global step value to record. Default to 0. + """ + pass + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + """Record scalar. + + Args: + name (str): The unique identifier for the scalar to save. + value (float, int): Value to save. + step (int): Global step value to record. Default to 0. + """ + pass + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + """Record scalars' data. + + Args: + scalar_dict (dict): Key-value pair storing the tag and + corresponding values. + step (int): Global step value to record. Default to 0. + file_path (str, optional): The scalar's data will be + saved to the `file_path` file at the same time + if the `file_path` parameter is specified. + Default to None. + """ + pass + + def close(self) -> None: + """close an opened object.""" + pass + + +@VISBACKENDS.register_module() +class LocalVisBackend(BaseVisBackend): + """Local vis backend class. + + It can write image, config, scalars, etc. + to the local hard disk. You can get the drawing backend + through the visualizer property for custom drawing. 
+ + Examples: + >>> from mmengine.visualization import LocalVisBackend + >>> import numpy as np + >>> local_vis_backend = LocalVisBackend(save_dir='temp_dir') + >>> img=np.random.randint(0, 256, size=(10, 10, 3)) + >>> local_vis_backend.add_image('img', img) + >>> local_vis_backend.add_scaler('mAP', 0.6) + >>> local_vis_backend.add_scalars({'loss': [1, 2, 3], 'acc': 0.8}) + >>> local_vis_backend.add_image('img', image) + + Args: + save_dir (str, optional): The root directory to save the files + produced by the writer. If it is none, it means no data + is stored. Default None. + img_save_dir (str): The directory to save images. + Default to 'writer_image'. + config_save_file (str): The file to save parameters. + Default to 'parameters.yaml'. + scalar_save_file (str): The file to save scalar values. + Default to 'scalars.json'. + """ + + def __init__(self, + save_dir: Optional[str] = None, + img_save_dir: str = 'vis_image', + config_save_file: str = 'config.py', + scalar_save_file: str = 'scalars.json'): + assert config_save_file.split('.')[-1] == 'py' + assert scalar_save_file.split('.')[-1] == 'json' + super(LocalVisBackend, self).__init__(save_dir) + if self._save_dir is not None: + os.makedirs(self._save_dir, exist_ok=True) # type: ignore + self._img_save_dir = osp.join( + self._save_dir, # type: ignore + img_save_dir) + self._scalar_save_file = osp.join( + self._save_dir, # type: ignore + scalar_save_file) + self._config_save_file = osp.join( + self._save_dir, # type: ignore + config_save_file) + + @property + def experiment(self) -> 'LocalVisBackend': + """Return the experiment object associated with this visualizer + backend.""" + return self + + def add_config(self, config: Config, **kwargs) -> None: + # TODO + assert isinstance(config, Config) + + def add_image(self, + name: str, + image: np.ndarray = None, + step: int = 0, + **kwargs) -> None: + """Record image to disk. + + Args: + name (str): The unique identifier for the image to save. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Default to None. + step (int): Global step value to record. Default to 0. + """ + + drawn_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + os.makedirs(self._img_save_dir, exist_ok=True) + save_file_name = f'{name}_{step}.png' + cv2.imwrite(osp.join(self._img_save_dir, save_file_name), drawn_image) + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + """Add scalar data to disk. + + Args: + name (str): The unique identifier for the scalar to save. + value (float, int): Value to save. + step (int): Global step value to record. Default to 0. + """ + self._dump({name: value, 'step': step}, self._scalar_save_file, 'json') + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + """Record scalars. The scalar dict will be written to the default and + specified files if ``file_name`` is specified. + + Args: + scalar_dict (dict): Key-value pair storing the tag and + corresponding values. + step (int): Global step value to record. Default to 0. + file_path (str, optional): The scalar's data will be + saved to the ``file_path`` file at the same time + if the ``file_path`` parameter is specified. + Default to None. 
+ """ + assert isinstance(scalar_dict, dict) + scalar_dict.setdefault('step', step) + if file_path is not None: + assert file_path.split('.')[-1] == 'json' + new_save_file_path = osp.join( + self._save_dir, # type: ignore + file_path) + assert new_save_file_path != self._scalar_save_file, \ + '"file_path" and "scalar_save_file" have the same name, ' \ + 'please set "file_path" to another value' + self._dump(scalar_dict, new_save_file_path, 'json') + self._dump(scalar_dict, self._scalar_save_file, 'json') + + def _dump(self, value_dict: dict, file_path: str, + file_format: str) -> None: + """dump dict to file. + + Args: + value_dict (dict) : Save dict data. + file_path (str): The file path to save data. + file_format (str): The file format to save data. + """ + with open(file_path, 'a+') as f: + dump(value_dict, f, file_format=file_format) + f.write('\n') + + +@VISBACKENDS.register_module() +class WandbVisBackend(BaseVisBackend): + """Write various types of data to wandb. + + Examples: + >>> from mmengine.visualization import WandbVisBackend + >>> import numpy as np + >>> wandb_vis_backend = WandbVisBackend() + >>> img=np.random.randint(0, 256, size=(10, 10, 3)) + >>> wandb_vis_backend.add_image('img', img) + >>> wandb_vis_backend.add_scaler('mAP', 0.6) + >>> wandb_vis_backend.add_scalars({'loss': [1, 2, 3],'acc': 0.8}) + >>> wandb_vis_backend.add_image('img', img) + + Args: + init_kwargs (dict, optional): wandb initialization + input parameters. Default to None. + commit: (bool, optional) Save the metrics dict to the wandb server + and increment the step. If false `wandb.log` just + updates the current metrics dict with the row argument + and metrics won't be saved until `wandb.log` is called + with `commit=True`. Default to True. + save_dir (str, optional): The root directory to save the files + produced by the writer. Default to None. + """ + + def __init__(self, + init_kwargs: Optional[dict] = None, + commit: Optional[bool] = True, + save_dir: Optional[str] = None): + super(WandbVisBackend, self).__init__(save_dir) + self._commit = commit + self._wandb = self._setup_env(init_kwargs) + + @property + def experiment(self): + """Return wandb object. + + The experiment attribute can get the wandb backend, If you want to + write other data, such as writing a table, you can directly get the + wandb backend through experiment. + """ + return self._wandb + + def _setup_env(self, init_kwargs: Optional[dict] = None) -> Any: + """Setup env. + + Args: + init_kwargs (dict): The init args. + + Return: + :obj:`wandb` + """ + try: + import wandb + except ImportError: + raise ImportError( + 'Please run "pip install wandb" to install wandb') + if init_kwargs: + wandb.init(**init_kwargs) + else: + wandb.init() + + return wandb + + def add_config(self, config: Config, **kwargs) -> None: + # TODO + pass + + def add_image(self, + name: str, + image: np.ndarray = None, + step: int = 0, + **kwargs) -> None: + """Record image to wandb. + + Args: + name (str): The unique identifier for the image to save. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Default to None. + step (int): Global step value to record. Default to 0. + """ + self._wandb.log({name: image}, commit=self._commit, step=step) + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + """Record scalar data to wandb. + + Args: + name (str): The unique identifier for the scalar to save. + value (float, int): Value to save. + step (int): Global step value to record. 
Default to 0. + """ + self._wandb.log({name: value}, commit=self._commit, step=step) + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + """Record scalar's data to wandb. + + Args: + scalar_dict (dict): Key-value pair storing the tag and + corresponding values. + step (int): Global step value to record. Default to 0. + file_path (str, optional): Useless parameter. Just for + interface unification. Default to None. + """ + self._wandb.log(scalar_dict, commit=self._commit, step=step) + + def close(self) -> None: + """close an opened wandb object.""" + if hasattr(self, '_wandb'): + self._wandb.join() + + +@VISBACKENDS.register_module() +class TensorboardVisBackend(BaseVisBackend): + """Tensorboard class. It can write images, config, scalars, etc. to a + tensorboard file. + + Its drawing function is provided by Visualizer. + + Examples: + >>> from mmengine.visualization import TensorboardVisBackend + >>> import numpy as np + >>> tensorboard_visualizer = TensorboardVisBackend(save_dir='temp_dir') + >>> img=np.random.randint(0, 256, size=(10, 10, 3)) + >>> tensorboard_visualizer.add_image('img', img) + >>> tensorboard_visualizer.add_scaler('mAP', 0.6) + >>> tensorboard_visualizer.add_scalars({'loss': 0.1,'acc':0.8}) + >>> tensorboard_visualizer.add_image('img', image) + + Args: + save_dir (str): The root directory to save the files + produced by the backend. + log_dir (str): Save directory location. Default to 'tf_logs'. + """ + + def __init__(self, + save_dir: Optional[str] = None, + log_dir: str = 'tf_logs'): + super(TensorboardVisBackend, self).__init__(save_dir) + if save_dir is not None: + self._tensorboard = self._setup_env(log_dir) + + def _setup_env(self, log_dir: str): + """Setup env. + + Args: + log_dir (str): Save directory location. + + Return: + :obj:`SummaryWriter` + """ + if TORCH_VERSION == 'parrots': + try: + from tensorboardX import SummaryWriter + except ImportError: + raise ImportError('Please install tensorboardX to use ' + 'TensorboardLoggerHook.') + else: + try: + from torch.utils.tensorboard import SummaryWriter + except ImportError: + raise ImportError( + 'Please run "pip install future tensorboard" to install ' + 'the dependencies to use torch.utils.tensorboard ' + '(applicable to PyTorch 1.1 or higher)') + if self._save_dir is None: + return SummaryWriter(f'./{log_dir}') + else: + self.log_dir = osp.join(self._save_dir, log_dir) # type: ignore + return SummaryWriter(self.log_dir) + + @property + def experiment(self): + """Return Tensorboard object.""" + return self._tensorboard + + def add_config(self, config: Config, **kwargs) -> None: + # TODO + pass + + def add_image(self, + name: str, + image: np.ndarray, + step: int = 0, + **kwargs) -> None: + """Record image to tensorboard. + + Args: + name (str): The unique identifier for the image to save. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Default to None. + step (int): Global step value to record. Default to 0. + """ + self._tensorboard.add_image(name, image, step, dataformats='HWC') + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + """Record scalar data to summary. + + Args: + name (str): The unique identifier for the scalar to save. + value (float, int): Value to save. + step (int): Global step value to record. Default to 0. 
+ """ + self._tensorboard.add_scalar(name, value, step) + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + """Record scalar's data to summary. + + Args: + scalar_dict (dict): Key-value pair storing the tag and + corresponding values. + step (int): Global step value to record. Default to 0. + file_path (str, optional): Useless parameter. Just for + interface unification. Default to None. + """ + assert isinstance(scalar_dict, dict) + assert 'step' not in scalar_dict, 'Please set it directly ' \ + 'through the step parameter' + for key, value in scalar_dict.items(): + self.add_scalar(key, value, step) + + def close(self): + """close an opened tensorboard object.""" + if hasattr(self, '_tensorboard'): + self._tensorboard.close() diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index ae6ff113b2caeedf75e2ea892dcb8019190c20a3..d80256167b0a85df0c6abd1a65d10aecd887b590 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -1,25 +1,32 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings -from typing import Callable, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import cv2 import matplotlib.pyplot as plt import numpy as np import torch +from matplotlib.backend_bases import CloseEvent from matplotlib.backends.backend_agg import FigureCanvasAgg from matplotlib.collections import (LineCollection, PatchCollection, PolyCollection) from matplotlib.figure import Figure from matplotlib.patches import Circle +from mmengine.config import Config from mmengine.data import BaseDataElement -from mmengine.registry import VISUALIZERS -from .utils import (check_type, check_type_and_length, tensor2ndarray, - value2list) +from mmengine.registry import VISBACKENDS, VISUALIZERS +from mmengine.utils import ManagerMixin +from mmengine.visualization.utils import (check_type, check_type_and_length, + color_val_matplotlib, + convert_overlay_heatmap, + str_color_to_rgb, tensor2ndarray, + value2list) +from mmengine.visualization.vis_backend import BaseVisBackend @VISUALIZERS.register_module() -class Visualizer: +class Visualizer(ManagerMixin): """MMEngine provides a Visualizer class that uses the ``Matplotlib`` library as the backend. 
It has the following functions: @@ -67,15 +74,15 @@ class Visualizer: >>> # Basic drawing methods >>> vis = Visualizer(metadata=metadata, image=image) - >>> vis.draw_bboxes(np.array([0, 0, 1, 1]), edgecolors='g') + >>> vis.draw_bboxes(np.array([0, 0, 1, 1]), edge_colors='g') >>> vis.draw_bboxes(bbox=np.array([[1, 1, 2, 2], [2, 2, 3, 3]]), - edgecolors=['g', 'r'], is_filling=True) + edge_colors=['g', 'r'], is_filling=True) >>> vis.draw_lines(x_datas=np.array([1, 3]), y_datas=np.array([1, 3]), - colors='r', linewidths=1) + colors='r', line_widths=1) >>> vis.draw_lines(x_datas=np.array([[1, 3], [2, 4]]), y_datas=np.array([[1, 3], [2, 4]]), - colors=['r', 'r'], linewidths=[1, 2]) + colors=['r', 'r'], line_widths=[1, 2]) >>> vis.draw_texts(text='MMEngine', position=np.array([2, 2]), colors='b') @@ -87,10 +94,10 @@ class Visualizer: radius=np.array[1, 2], colors=['g', 'r'], is_filling=True) >>> vis.draw_polygons(np.array([0, 0, 1, 0, 1, 1, 0, 1]), - edgecolors='g') + edge_colors='g') >>> vis.draw_polygons(bbox=[np.array([0, 0, 1, 0, 1, 1, 0, 1], np.array([2, 2, 3, 2, 3, 3, 2, 3]]), - edgecolors=['g', 'r'], is_filling=True) + edge_colors=['g', 'r'], is_filling=True) >>> vis.draw_binary_masks(binary_mask, alpha=0.6) >>> # chain calls @@ -106,80 +113,99 @@ class Visualizer: >>> # inherit >>> class DetVisualizer2(Visualizer): - >>> @Visualizer.register_task('instances') - >>> def draw_instance(self, - >>> instances: 'BaseDataInstance', - >>> data_type: Type): - >>> pass - >>> def draw(self, + >>> def add_datasample(self, >>> image: Optional[np.ndarray] = None, >>> gt_sample: Optional['BaseDataElement'] = None, >>> pred_sample: Optional['BaseDataElement'] = None, >>> show_gt: bool = True, - >>> show_pred: bool = True) -> None: + >>> show_pred: bool = True, + >>> show:bool = True) -> None: >>> pass """ - task_dict: dict = {} - - def __init__(self, - image: Optional[np.ndarray] = None, - metadata: Optional[dict] = None) -> None: - self._metadata = metadata + def __init__( + self, + name='visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + fig_save_cfg=dict(frameon=False), + fig_show_cfg=dict(frameon=False, num='show') + ) -> None: + super().__init__(name) + self._dataset_meta: Union[None, dict] = None + self._vis_backends: Union[Dict, Dict[str, 'BaseVisBackend']] = dict() + + if vis_backends: + with_name = False + without_name = False + for vis_backend in vis_backends: + if 'name' in vis_backend: + with_name = True + else: + without_name = True + if with_name and without_name: + raise AssertionError + + for vis_backend in vis_backends: + name = vis_backend.pop('name', vis_backend['type']) + assert name not in self._vis_backends + vis_backend.setdefault('save_dir', save_dir) + self._vis_backends[name] = VISBACKENDS.build(vis_backend) + + self.is_inline = 'inline' in plt.get_backend() + + self.fig_save = None + self.fig_show = None + self.fig_save_num = fig_save_cfg.get('num', None) + self.fig_show_num = fig_show_cfg.get('num', None) + self.fig_save_cfg = fig_save_cfg + self.fig_show_cfg = fig_show_cfg + + (self.fig_save, self.ax_save, + self.fig_save_num) = self._initialize_fig(fig_save_cfg) + self.dpi = self.fig_save.get_dpi() if image is not None: - self._setup_fig(image) - - def draw(self, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True) -> None: - pass + self.set_image(image) - def 
show(self, wait_time: int = 0) -> None: + @property + def dataset_meta(self) -> Optional[dict]: + return self._dataset_meta + + @dataset_meta.setter + def dataset_meta(self, dataset_meta: dict) -> None: + self._dataset_meta = dataset_meta + + def show(self, + drawn_img: Optional[np.ndarray] = None, + win_name: str = 'image', + wait_time: int = 0, + continue_key=' ') -> None: """Show the drawn image. Args: wait_time (int, optional): Delay in milliseconds. 0 is the special value that means "forever". Defaults to 0. """ - if wait_time == 0: - plt.show() - else: - plt.show(block=False) - plt.pause(wait_time) - - def close(self) -> None: - """Close the figure.""" - plt.close(self.fig) - - @classmethod - def register_task(cls, task_name: str, force: bool = False) -> Callable: - """Register a function. - - A record will be added to ``task_dict``, whose key is the task_name - and value is the decorated function. - - Args: - cls (type): Module class to be registered. - task_name (str or list of str, optional): The module name to be - registered. - force (bool): Whether to override an existing function with the - same name. Defaults to False. - """ - - def _register(task_func): - - if (task_name not in cls.task_dict) or force: - cls.task_dict[task_name] = task_func - else: - raise KeyError( - f'"{task_name}" is already registered in task_dict, ' - 'add "force=True" if you want to override it') - return task_func - - return _register + if self.is_inline: + return + if self.fig_show is None or not plt.fignum_exists(self.fig_show_num): + (self.fig_show, self.ax_show, + self.fig_show_num) = self._initialize_fig(self.fig_show_cfg) + img = self.get_image() if drawn_img is None else drawn_img + # dpi = self.fig_show.get_dpi() + # height, width = img.shape[:2] + # self.fig_show.set_size_inches((width + 1e-2) / dpi, + # (height + 1e-2) / dpi) + self.ax_show.cla() + self.ax_show.axis(False) + # self.ax_show.set_title(win_name) + # self.fig_show.set_label(win_name) + + # Refresh canvas, necessary for Qt5 backend. + self.ax_show.imshow(img) + self.fig_show.canvas.draw() # type: ignore + self._wait_continue(timeout=wait_time, continue_key=continue_key) def set_image(self, image: np.ndarray) -> None: """Set the image to draw. @@ -188,7 +214,23 @@ class Visualizer: image (np.ndarray): The image to draw. """ assert image is not None - self._setup_fig(image) + image = image.astype('uint8') + self._image = image + self.width, self.height = image.shape[1], image.shape[0] + self._default_font_size = max( + np.sqrt(self.height * self.width) // 90, 10) + + # add a small 1e-2 to avoid precision lost due to matplotlib's + # truncation (https://github.com/matplotlib/matplotlib/issues/15363) + self.fig_save.set_size_inches( # type: ignore + (self.width + 1e-2) / self.dpi, (self.height + 1e-2) / self.dpi) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + self.ax_save.cla() + self.ax_save.axis(False) + self.ax_save.imshow( + image, + extent=(0, self.width, self.height, 0), + interpolation='none') def get_image(self) -> np.ndarray: """Get the drawn image. The format is RGB. @@ -197,43 +239,24 @@ class Visualizer: np.ndarray: the drawn image which channel is rgb. 
""" assert self._image is not None, 'Please set image using `set_image`' - canvas = self.canvas + canvas = self.fig_save.canvas # type: ignore s, (width, height) = canvas.print_to_buffer() buffer = np.frombuffer(s, dtype='uint8') img_rgba = buffer.reshape(height, width, 4) rgb, alpha = np.split(img_rgba, [3], axis=2) return rgb.astype('uint8') - def _setup_fig(self, image: np.ndarray) -> None: - """Set the image to draw. + def _initialize_fig(self, fig_cfg): + fig = plt.figure(**fig_cfg) + ax = fig.add_subplot() + ax.axis(False) - Args: - image (np.ndarray): The image to draw.The format - should be RGB. - """ - image = image.astype('uint8') - self._image = image - self.width, self.height = image.shape[1], image.shape[0] - self._default_font_size = max( - np.sqrt(self.height * self.width) // 90, 10) - fig = plt.figure(frameon=False) + # remove white edges by set subplot margin + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + return fig, ax, fig.number - self.dpi = fig.get_dpi() - # add a small 1e-2 to avoid precision lost due to matplotlib's - # truncation (https://github.com/matplotlib/matplotlib/issues/15363) - fig.set_size_inches((self.width + 1e-2) / self.dpi, - (self.height + 1e-2) / self.dpi) - self.canvas = fig.canvas - # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) - plt.subplots_adjust(left=0, right=1, bottom=0, top=1) - plt.axis('off') - ax = plt.gca() - self.fig = fig - self.ax = ax - self.ax.imshow( - image, - extent=(0, self.width, self.height, 0), - interpolation='none') + def get_backend(self, name) -> 'BaseVisBackend': + return self._vis_backends.get(name) # type: ignore def _is_posion_valid(self, position: np.ndarray) -> bool: """Judge whether the position is in image. @@ -251,14 +274,86 @@ class Visualizer: (position[..., 1] >= 0).all() return flag + def _wait_continue(self, timeout: int = 0, continue_key=' ') -> int: + """Show the image and wait for the user's input. + + This implementation refers to + https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py + + Args: + timeout (int): If positive, continue after ``timeout`` seconds. + Defaults to 0. + continue_key (str): The key for users to continue. Defaults to + the space key. + + Returns: + int: If zero, means time out or the user pressed ``continue_key``, + and if one, means the user closed the show figure. + """ # noqa: E501 + if self.is_inline: + # If use inline backend, interactive input and timeout is no use. + return 0 + + if self.fig_show.canvas.manager: # type: ignore + # Ensure that the figure is shown + self.fig_show.show() # type: ignore + + while True: + + # Connect the events to the handler function call. + event = None + + def handler(ev): + # Set external event variable + nonlocal event + # Qt backend may fire two events at the same time, + # use a condition to avoid missing close event. + event = ev if not isinstance(event, CloseEvent) else event + self.fig_show.canvas.stop_event_loop() + + cids = [ + self.fig_show.canvas.mpl_connect(name, handler) # type: ignore + for name in ('key_press_event', 'close_event') + ] + + try: + self.fig_show.canvas.start_event_loop(timeout) # type: ignore + finally: # Run even on exception like ctrl-c. + # Disconnect the callbacks. + for cid in cids: + self.fig_show.canvas.mpl_disconnect(cid) # type: ignore + + if isinstance(event, CloseEvent): + return 1 # Quit for close. + elif event is None or event.key == continue_key: + return 0 # Quit for continue. 
+ + def draw_points(self, + positions: Union[np.ndarray, torch.Tensor], + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + marker: Optional[str] = None, + sizes: Optional[Union[np.ndarray, torch.Tensor]] = None): + check_type('positions', positions, (np.ndarray, torch.Tensor)) + positions = tensor2ndarray(positions) + + if len(positions.shape) == 1: + positions = positions[None] + assert positions.shape[-1] == 2, ( + 'The shape of `positions` should be (N, 2), ' + f'but got {positions.shape}') + colors = color_val_matplotlib(colors) + self.ax_save.scatter( + positions[:, 0], positions[:, 1], c=colors, s=sizes, marker=marker) + return self + def draw_texts( self, texts: Union[str, List[str]], positions: Union[np.ndarray, torch.Tensor], font_sizes: Optional[Union[int, List[int]]] = None, - colors: Union[str, List[str]] = 'g', - verticalalignments: Union[str, List[str]] = 'top', - horizontalalignments: Union[str, List[str]] = 'left', + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + vertical_alignments: Union[str, List[str]] = 'top', + horizontal_alignments: Union[str, List[str]] = 'left', font_families: Union[str, List[str]] = 'sans-serif', rotations: Union[int, str, List[Union[int, str]]] = 0, bboxes: Optional[Union[dict, List[dict]]] = None) -> 'Visualizer': @@ -273,29 +368,29 @@ class Visualizer: texts. ``font_sizes`` can have the same length with texts or just single value. If ``font_sizes`` is single value, all the texts will have the same font size. Defaults to None. - colors (Union[str, List[str]]): The colors of texts. ``colors`` - can have the same length with texts or just single value. - If ``colors`` is single value, all the texts will have the same - colors. Reference to + colors (Union[str, tuple, List[str], List[tuple]]): The colors + of texts. ``colors`` can have the same length with texts or + just single value. If ``colors`` is single value, all the + texts will have the same colors. Reference to https://matplotlib.org/stable/gallery/color/named_colors.html for more details. Defaults to 'g. - verticalalignments (Union[str, List[str]]): The verticalalignment + vertical_alignments (Union[str, List[str]]): The verticalalignment of texts. verticalalignment controls whether the y positional argument for the text indicates the bottom, center or top side of the text bounding box. - ``verticalalignments`` can have the same length with - texts or just single value. If ``verticalalignments`` is single - value, all the texts will have the same verticalalignment. - verticalalignment can be 'center' or 'top', 'bottom' or - 'baseline'. Defaults to 'top'. - horizontalalignments (Union[str, List[str]]): The + ``vertical_alignments`` can have the same length with + texts or just single value. If ``vertical_alignments`` is + single value, all the texts will have the same + verticalalignment. verticalalignment can be 'center' or + 'top', 'bottom' or 'baseline'. Defaults to 'top'. + horizontal_alignments (Union[str, List[str]]): The horizontalalignment of texts. Horizontalalignment controls whether the x positional argument for the text indicates the left, center or right side of the text bounding box. - ``horizontalalignments`` can have + ``horizontal_alignments`` can have the same length with texts or just single value. - If ``horizontalalignments`` is single value, all the texts will - have the same horizontalalignment. Horizontalalignment + If ``horizontal_alignments`` is single value, all the texts + will have the same horizontalalignment. 
Horizontalalignment can be 'center','right' or 'left'. Defaults to 'left'. font_families (Union[str, List[str]]): The font family of texts. ``font_families`` can have the same length with texts or @@ -335,19 +430,22 @@ class Visualizer: if font_sizes is None: font_sizes = self._default_font_size - check_type_and_length('font_sizes', font_sizes, (int, list), num_text) - font_sizes = value2list(font_sizes, int, num_text) + check_type_and_length('font_sizes', font_sizes, (int, float, list), + num_text) + font_sizes = value2list(font_sizes, (int, float), num_text) - check_type_and_length('colors', colors, (str, list), num_text) - colors = value2list(colors, str, num_text) + check_type_and_length('colors', colors, (str, tuple, list), num_text) + colors = value2list(colors, (str, tuple), num_text) + colors = color_val_matplotlib(colors) - check_type_and_length('verticalalignments', verticalalignments, + check_type_and_length('vertical_alignments', vertical_alignments, (str, list), num_text) - verticalalignments = value2list(verticalalignments, str, num_text) + vertical_alignments = value2list(vertical_alignments, str, num_text) - check_type_and_length('horizontalalignments', horizontalalignments, + check_type_and_length('horizontal_alignments', horizontal_alignments, (str, list), num_text) - horizontalalignments = value2list(horizontalalignments, str, num_text) + horizontal_alignments = value2list(horizontal_alignments, str, + num_text) check_type_and_length('rotations', rotations, (int, list), num_text) rotations = value2list(rotations, int, num_text) @@ -363,14 +461,14 @@ class Visualizer: bboxes = value2list(bboxes, dict, num_text) for i in range(num_text): - self.ax.text( + self.ax_save.text( positions[i][0], positions[i][1], texts[i], size=font_sizes[i], # type: ignore bbox=bboxes[i], # type: ignore - verticalalignment=verticalalignments[i], - horizontalalignment=horizontalalignments[i], + verticalalignment=vertical_alignments[i], + horizontalalignment=horizontal_alignments[i], family=font_families[i], color=colors[i]) return self @@ -379,9 +477,9 @@ class Visualizer: self, x_datas: Union[np.ndarray, torch.Tensor], y_datas: Union[np.ndarray, torch.Tensor], - colors: Union[str, List[str]] = 'g', - linestyles: Union[str, List[str]] = '-', - linewidths: Union[Union[int, float], List[Union[int, float]]] = 1 + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + line_styles: Union[str, List[str]] = '-', + line_widths: Union[Union[int, float], List[Union[int, float]]] = 2 ) -> 'Visualizer': """Draw single or multiple line segments. @@ -390,24 +488,24 @@ class Visualizer: each line' start and end points. y_datas (Union[np.ndarray, torch.Tensor]): The y coordinate of each line' start and end points. - colors (Union[str, List[str]]): The colors of lines. ``colors`` - can have the same length with lines or just single value. - If ``colors`` is single value, all the lines will have the same - colors. Reference to + colors (Union[str, tuple, List[str], List[tuple]]): The colors of + lines. ``colors`` can have the same length with lines or just + single value. If ``colors`` is single value, all the lines + will have the same colors. Reference to https://matplotlib.org/stable/gallery/color/named_colors.html for more details. Defaults to 'g'. - linestyles (Union[str, List[str]]): The linestyle - of lines. ``linestyles`` can have the same length with - texts or just single value. If ``linestyles`` is single + line_styles (Union[str, List[str]]): The linestyle + of lines. 
``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single value, all the lines will have the same linestyle. Reference to https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle for more details. Defaults to '-'. - linewidths (Union[Union[int, float], List[Union[int, float]]]): The - linewidth of lines. ``linewidths`` can have + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have the same length with lines or just single value. - If ``linewidths`` is single value, all the lines will - have the same linewidth. Defaults to 1. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. """ check_type('x_datas', x_datas, (np.ndarray, torch.Tensor)) x_datas = tensor2ndarray(x_datas) @@ -421,31 +519,31 @@ class Visualizer: if len(x_datas.shape) == 1: x_datas = x_datas[None] y_datas = y_datas[None] - + colors = color_val_matplotlib(colors) lines = np.concatenate( (x_datas.reshape(-1, 2, 1), y_datas.reshape(-1, 2, 1)), axis=-1) if not self._is_posion_valid(lines): - warnings.warn( 'Warning: The line is out of bounds,' ' the drawn line may not be in the image', UserWarning) line_collect = LineCollection( lines.tolist(), colors=colors, - linestyles=linestyles, - linewidths=linewidths) - self.ax.add_collection(line_collect) + linestyles=line_styles, + linewidths=line_widths) + self.ax_save.add_collection(line_collect) return self - def draw_circles(self, - center: Union[np.ndarray, torch.Tensor], - radius: Union[np.ndarray, torch.Tensor], - alpha: Union[float, int] = 0.8, - edgecolors: Union[str, List[str]] = 'g', - linestyles: Union[str, List[str]] = '-', - linewidths: Union[Union[int, float], - List[Union[int, float]]] = 1, - is_filling: bool = False) -> 'Visualizer': + def draw_circles( + self, + center: Union[np.ndarray, torch.Tensor], + radius: Union[np.ndarray, torch.Tensor], + alpha: Union[float, int] = 0.8, + edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', + line_styles: Union[str, List[str]] = '-', + line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, + face_colors: Union[str, tuple, List[str], List[tuple]] = 'none' + ) -> 'Visualizer': """Draw single or multiple circles. Args: @@ -453,24 +551,24 @@ class Visualizer: each line' start and end points. radius (Union[np.ndarray, torch.Tensor]): The y coordinate of each line' start and end points. - edgecolors (Union[str, List[str]]): The colors of circles. - ``colors`` can have the same length with lines or just single - value. If ``colors`` is single value, all the lines will have - the same colors. Reference to + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of circles. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, + all the lines will have the same colors. Reference to https://matplotlib.org/stable/gallery/color/named_colors.html for more details. Defaults to 'g. - linestyles (Union[str, List[str]]): The linestyle - of lines. ``linestyles`` can have the same length with - texts or just single value. If ``linestyles`` is single + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single value, all the lines will have the same linestyle. 
Reference to https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle for more details. Defaults to '-'. - linewidths (Union[Union[int, float], List[Union[int, float]]]): The - linewidth of lines. ``linewidths`` can have + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have the same length with lines or just single value. - If ``linewidths`` is single value, all the lines will - have the same linewidth. Defaults to 1. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. is_filling (bool): Whether to fill all the circles. Defaults to False. """ @@ -493,57 +591,59 @@ class Visualizer: center = center.tolist() radius = radius.tolist() + edge_colors = color_val_matplotlib(edge_colors) + face_colors = color_val_matplotlib(face_colors) circles = [] for i in range(len(center)): circles.append(Circle(tuple(center[i]), radius[i])) - if is_filling: - p = PatchCollection(circles, alpha=alpha, facecolor=edgecolors) - else: - if isinstance(linewidths, (int, float)): - linewidths = [linewidths] * len(circles) - linewidths = [ - min(max(linewidth, 1), self._default_font_size / 4) - for linewidth in linewidths - ] - p = PatchCollection( - circles, - alpha=alpha, - facecolor='none', - edgecolor=edgecolors, - linewidth=linewidths, - linestyles=linestyles) - self.ax.add_collection(p) + + if isinstance(line_widths, (int, float)): + line_widths = [line_widths] * len(circles) + line_widths = [ + min(max(linewidth, 1), self._default_font_size / 4) + for linewidth in line_widths + ] + p = PatchCollection( + circles, + alpha=alpha, + facecolors=face_colors, + edgecolors=edge_colors, + linewidths=line_widths, + linestyles=line_styles) + self.ax_save.add_collection(p) return self - def draw_bboxes(self, - bboxes: Union[np.ndarray, torch.Tensor], - alpha: Union[int, float] = 0.8, - edgecolors: Union[str, List[str]] = 'g', - linestyles: Union[str, List[str]] = '-', - linewidths: Union[Union[int, float], - List[Union[int, float]]] = 1, - is_filling: bool = False) -> 'Visualizer': + def draw_bboxes( + self, + bboxes: Union[np.ndarray, torch.Tensor], + alpha: Union[int, float] = 0.8, + edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', + line_styles: Union[str, List[str]] = '-', + line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, + face_colors: Union[str, tuple, List[str], List[tuple]] = 'none' + ) -> 'Visualizer': """Draw single or multiple bboxes. Args: bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw with the format of(x1,y1,x2,y2). - edgecolors (Union[str, List[str]]): The colors of bboxes. - ``colors`` can have the same length with lines or just single - value. If ``colors`` is single value, all the lines will have - the same colors. Refer to `matplotlib.colors` for full list of - formats that are accepted.. Defaults to 'g'. - linestyles (Union[str, List[str]]): The linestyle - of lines. ``linestyles`` can have the same length with - texts or just single value. If ``linestyles`` is single + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of bboxes. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, all + the lines will have the same colors. Refer to `matplotlib. + colors` for full list of formats that are accepted. + Defaults to 'g'. + line_styles (Union[str, List[str]]): The linestyle + of lines. 
``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single value, all the lines will have the same linestyle. Reference to https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle for more details. Defaults to '-'. - linewidths (Union[Union[int, float], List[Union[int, float]]]): The - linewidth of lines. ``linewidths`` can have + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. ``line_widths`` can have the same length with lines or just single value. - If ``linewidths`` is single value, all the lines will + If ``line_widths`` is single value, all the lines will have the same linewidth. Defaults to 1. is_filling (bool): Whether to fill all the bboxes. Defaults to False. @@ -570,47 +670,51 @@ class Visualizer: return self.draw_polygons( poly, alpha=alpha, - edgecolors=edgecolors, - linestyles=linestyles, - linewidths=linewidths, - is_filling=is_filling) - - def draw_polygons(self, - polygons: Union[Union[np.ndarray, torch.Tensor], - List[Union[np.ndarray, torch.Tensor]]], - alpha: Union[int, float] = 0.8, - edgecolors: Union[str, List[str]] = 'g', - linestyles: Union[str, List[str]] = '-', - linewidths: Union[Union[int, float], - List[Union[int, float]]] = 1.0, - is_filling: bool = False) -> 'Visualizer': + edge_colors=edge_colors, + line_styles=line_styles, + line_widths=line_widths, + face_colors=face_colors) + + def draw_polygons( + self, + polygons: Union[Union[np.ndarray, torch.Tensor], + List[Union[np.ndarray, torch.Tensor]]], + alpha: Union[int, float] = 0.8, + edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', + line_styles: Union[str, List[str]] = '-', + line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, + face_colors: Union[str, tuple, List[str], List[tuple]] = 'none' + ) -> 'Visualizer': """Draw single or multiple bboxes. Args: polygons (Union[Union[np.ndarray, torch.Tensor], List[Union[np.ndarray, torch.Tensor]]]): The polygons to draw with the format of (x1,y1,x2,y2,...,xn,yn). - edgecolors (Union[str, List[str]]): The colors of polygons. - ``colors`` can have the same length with lines or just single - value. If ``colors`` is single value, all the lines will have - the same colors. Refer to `matplotlib.colors` for full list of - formats that are accepted.. Defaults to 'g. - linestyles (Union[str, List[str]]): The linestyle - of lines. ``linestyles`` can have the same length with - texts or just single value. If ``linestyles`` is single + edge_colors (Union[str, tuple, List[str], List[tuple]]): The + colors of polygons. ``colors`` can have the same length with + lines or just single value. If ``colors`` is single value, + all the lines will have the same colors. Refer to + `matplotlib.colors` for full list of formats that are accepted. + Defaults to 'g. + line_styles (Union[str, List[str]]): The linestyle + of lines. ``line_styles`` can have the same length with + texts or just single value. If ``line_styles`` is single value, all the lines will have the same linestyle. Reference to https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle for more details. Defaults to '-'. - linewidths (Union[Union[int, float], List[Union[int, float]]]): The - linewidth of lines. ``linewidths`` can have + line_widths (Union[Union[int, float], List[Union[int, float]]]): + The linewidth of lines. 
``line_widths`` can have the same length with lines or just single value. - If ``linewidths`` is single value, all the lines will - have the same linewidth. Defaults to 1. + If ``line_widths`` is single value, all the lines will + have the same linewidth. Defaults to 2. is_filling (bool): Whether to fill all the polygons. Defaults to False. """ check_type('polygons', polygons, (list, np.ndarray, torch.Tensor)) + edge_colors = color_val_matplotlib(edge_colors) + face_colors = color_val_matplotlib(face_colors) if isinstance(polygons, (np.ndarray, torch.Tensor)): polygons = [polygons] @@ -625,32 +729,29 @@ class Visualizer: warnings.warn( 'Warning: The polygon is out of bounds,' ' the drawn polygon may not be in the image', UserWarning) - if is_filling: - polygon_collection = PolyCollection( - polygons, alpha=alpha, facecolor=edgecolors) - else: - if isinstance(linewidths, (int, float)): - linewidths = [linewidths] * len(polygons) - linewidths = [ - min(max(linewidth, 1), self._default_font_size / 4) - for linewidth in linewidths - ] - polygon_collection = PolyCollection( - polygons, - alpha=alpha, - facecolor='none', - linestyles=linestyles, - edgecolors=edgecolors, - linewidths=linewidths) - - self.ax.add_collection(polygon_collection) + if isinstance(line_widths, (int, float)): + line_widths = [line_widths] * len(polygons) + line_widths = [ + min(max(linewidth, 1), self._default_font_size / 4) + for linewidth in line_widths + ] + polygon_collection = PolyCollection( + polygons, + alpha=alpha, + facecolor=face_colors, + linestyles=line_styles, + edgecolors=edge_colors, + linewidths=line_widths) + + self.ax_save.add_collection(polygon_collection) return self def draw_binary_masks( - self, - binary_masks: Union[np.ndarray, torch.Tensor], - colors: np.ndarray = np.array([0, 255, 0]), - alphas: Union[float, List[float]] = 0.5) -> 'Visualizer': + self, + binary_masks: Union[np.ndarray, torch.Tensor], + alphas: Union[float, List[float]] = 0.8, + colors: Union[str, tuple, List[str], + List[tuple]] = 'g') -> 'Visualizer': """Draw single or multiple binary masks. 
Args: @@ -677,12 +778,24 @@ class Visualizer: binary_masks = binary_masks[None] assert img.shape[:2] == binary_masks.shape[ 1:], '`binary_marks` must have the same shpe with image' - assert isinstance(colors, np.ndarray) - if colors.ndim == 1: - colors = np.tile(colors, (binary_masks.shape[0], 1)) - assert colors.shape == (binary_masks.shape[0], 3) + binary_mask_len = binary_masks.shape[0] + + check_type_and_length('colors', colors, (str, tuple, list), + binary_mask_len) + colors = value2list(colors, (str, tuple), binary_mask_len) + colors = [ + str_color_to_rgb(color) if isinstance(color, str) else color + for color in colors + ] + for color in colors: + assert len(color) == 3 + for channel in color: + assert 0 <= channel <= 255 # type: ignore + colors = np.array(colors) + if colors.ndim == 1: # type: ignore + colors = np.tile(colors, (binary_mask_len, 1)) if isinstance(alphas, float): - alphas = [alphas] * binary_masks.shape[0] + alphas = [alphas] * binary_mask_len for binary_mask, color, alpha in zip(binary_masks, colors, alphas): binary_mask_complement = cv2.bitwise_not(binary_mask) @@ -692,8 +805,8 @@ class Visualizer: img_complement = cv2.bitwise_and( img, img, mask=binary_mask_complement) rgb = rgb + img_complement - img = cv2.addWeighted(img, alpha, rgb, 1 - alpha, 0) - self.ax.imshow( + img = cv2.addWeighted(img, 1 - alpha, rgb, alpha, 0) + self.ax_save.imshow( img, extent=(0, self.width, self.height, 0), interpolation='nearest') @@ -705,7 +818,7 @@ class Visualizer: mode: str = 'mean', topk: int = 10, arrangement: Tuple[int, int] = (5, 2), - alpha: float = 0.3) -> np.ndarray: + alpha: float = 0.8) -> np.ndarray: """Draw featmap. If img is not None, the final image will be the weighted sum of img and featmap. It support the mode: @@ -738,37 +851,6 @@ class Visualizer: Returns: np.ndarray: featmap. """ - - def concat_heatmap(feat_map: Union[np.ndarray, torch.Tensor], - img: Optional[np.ndarray] = None, - alpha: float = 0.5) -> np.ndarray: - """Convert feat_map to heatmap and sum to image, if image is not - None. - - Args: - feat_map (np.ndarray, torch.Tensor): The feat_map to convert - with of shape (H, W), where H is the image height and W is - the image width. - img (np.ndarray, optional): The origin image. The format - should be RGB. Defaults to None. - alphas (Union[int, List[int]]): The transparency of origin - image. Defaults to 0.5. 
- - Returns: - np.ndarray: heatmap - """ - if isinstance(feat_map, torch.Tensor): - feat_map = feat_map.detach().cpu().numpy() - norm_img = np.zeros(feat_map.shape) - norm_img = cv2.normalize(feat_map, norm_img, 0, 255, - cv2.NORM_MINMAX) - norm_img = np.asarray(norm_img, dtype=np.uint8) - heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET) - heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB) - if img is not None: - heat_img = cv2.addWeighted(img, alpha, heat_img, 1 - alpha, 0) - return heat_img - assert isinstance( tensor_chw, torch.Tensor), (f'`tensor_chw` should be {torch.Tensor} ' @@ -785,11 +867,9 @@ class Visualizer: ], (f'Mode only support "mean", "max", "min", but got {mode}') if mode == 'max': feat_map, _ = torch.max(tensor_chw, dim=0) - elif mode == 'min': - feat_map, _ = torch.min(tensor_chw, dim=0) elif mode == 'mean': feat_map = torch.mean(tensor_chw, dim=0) - return concat_heatmap(feat_map, image, alpha) + return convert_overlay_heatmap(feat_map, image, alpha) if topk <= 0: tensor_chw_channel = tensor_chw.shape[0] @@ -801,16 +881,15 @@ class Visualizer: ' mode parameter or set topk greater than 0 to solve ' 'the error') if tensor_chw_channel == 1: - return concat_heatmap(tensor_chw[0], image, alpha) + return convert_overlay_heatmap(tensor_chw[0], image, alpha) else: tensor_chw = tensor_chw.permute(1, 2, 0).numpy() - norm_img = np.zeros(tensor_chw.shape) norm_img = cv2.normalize(tensor_chw, None, 0, 255, cv2.NORM_MINMAX) heat_img = np.asarray(norm_img, dtype=np.uint8) if image is not None: - heat_img = cv2.addWeighted(image, alpha, heat_img, - 1 - alpha, 0) + heat_img = cv2.addWeighted(image, 1 - alpha, heat_img, + alpha, 0) return heat_img else: row, col = arrangement @@ -833,9 +912,133 @@ class Visualizer: for i in range(topk): axes = fig.add_subplot(row, col, i + 1) axes.axis('off') - axes.imshow(concat_heatmap(topk_tensor[i], image, alpha)) + axes.imshow( + convert_overlay_heatmap(topk_tensor[i], image, alpha)) s, (width, height) = canvas.print_to_buffer() buffer = np.frombuffer(s, dtype='uint8') img_rgba = buffer.reshape(height, width, 4) rgb, alpha = np.split(img_rgba, [3], axis=2) return rgb.astype('uint8') + + def add_config(self, config: Config, **kwargs): + """Record parameters. + + Args: + config (Config): The Config object. + """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_config(config, **kwargs) # type: ignore + + def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], + **kwargs) -> None: + """Record graph data. + + Args: + model (torch.nn.Module): Model to draw. + data_batch (Sequence[dict]): Batch of data from dataloader. + """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_graph(model, data_batch, **kwargs) # type: ignore + + def add_image(self, name: str, image: np.ndarray, step: int = 0) -> None: + """Record image. + + Args: + name (str): The unique identifier for the image to save. + image (np.ndarray, optional): The image to be saved. The format + should be RGB. Default to None. + step (int): Global step value to record. Default to 0. + """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_image(name, image, step) # type: ignore + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + """Record scalar data. + + Args: + name (str): The unique identifier for the scalar to save. + value (float, int): Value to save. + step (int): Global step value to record. Default to 0. 
+ """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_scalar(name, value, step, **kwargs) # type: ignore + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + """Record scalars' data. + + Args: + scalar_dict (dict): Key-value pair storing the tag and + corresponding values. + step (int): Global step value to record. Default to 0. + file_path (str, optional): The scalar's data will be + saved to the `file_path` file at the same time + if the `file_path` parameter is specified. + Default to None. + """ + for vis_backend in self._vis_backends.values(): + vis_backend.add_scalars( # type: ignore + scalar_dict, step, file_path, **kwargs) + + def add_datasample(self, + name, + image: np.ndarray, + gt_sample: Optional['BaseDataElement'] = None, + pred_sample: Optional['BaseDataElement'] = None, + draw_gt: bool = True, + draw_pred: bool = True, + show: bool = False, + wait_time: int = 0, + step: int = 0) -> None: + pass + + def close(self) -> None: + """close an opened object.""" + plt.close(self.fig_save) + if self.fig_show is not None: + plt.close(self.fig_show) + for vis_backend in self._vis_backends.values(): + vis_backend.close() # type: ignore + + @classmethod + def get_instance(cls, name: str, **kwargs) -> 'Visualizer': + """Make subclass can get latest created instance by + ``Visualizer.get_current_instance()``. + + Downstream codebase may need to get the latest created instance + without knowing the specific Visualizer type. For example, mmdetection + builds visualizer in runner and some component which cannot access + runner wants to get latest created visualizer. In this case, + the component does not know which type of visualizer has been built + and cannot get target instance. Therefore, :class:`Visualizer` + overrides the :meth:`get_instance` and its subclass will register + the created instance to :attr:`_instance_dict` additionally. + :meth:`get_current_instance` will return the latest created subclass + instance. + + Examples: + >>> class DetLocalVisualizer(Visualizer): + >>> def __init__(self, name): + >>> super().__init__(name) + >>> + >>> visualizer1 = DetLocalVisualizer.get_instance('name1') + >>> visualizer2 = Visualizer.get_current_instance() + >>> visualizer3 = DetLocalVisualizer.get_current_instance() + >>> assert id(visualizer1) == id(visualizer2) == id(visualizer3) + + Args: + name (str): Name of instance. Defaults to ''. + + Returns: + object: Corresponding name instance. + """ + instance = super().get_instance(name, **kwargs) + Visualizer._instance_dict[name] = instance + return instance diff --git a/mmengine/visualization/writer.py b/mmengine/visualization/writer.py deleted file mode 100644 index c4f548ebfde25e58b16350c69af51a8c7485ac4a..0000000000000000000000000000000000000000 --- a/mmengine/visualization/writer.py +++ /dev/null @@ -1,821 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import os.path as osp -import time -from abc import ABCMeta, abstractmethod -from typing import Any, List, Optional, Union - -import cv2 -import numpy as np -import torch - -from mmengine.data import BaseDataElement -from mmengine.fileio import dump -from mmengine.registry import VISUALIZERS, WRITERS -from mmengine.utils import TORCH_VERSION, ManagerMixin -from .visualizer import Visualizer - - -class BaseWriter(metaclass=ABCMeta): - """Base class for writer. - - Each writer can inherit ``BaseWriter`` and implement - the required functions. 
- - Args: - visualizer (dict, :obj:`Visualizer`, optional): - Visualizer instance or dictionary. Default to None. - save_dir (str, optional): The root directory to save - the files produced by the writer. Default to None. - """ - - def __init__(self, - visualizer: Optional[Union[dict, 'Visualizer']] = None, - save_dir: Optional[str] = None): - self._save_dir = save_dir - if self._save_dir: - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) - self._save_dir = osp.join( - self._save_dir, f'write_data_{timestamp}') # type: ignore - self._visualizer = visualizer - if visualizer: - if isinstance(visualizer, dict): - self._visualizer = VISUALIZERS.build(visualizer) - else: - assert isinstance(visualizer, Visualizer), \ - 'visualizer should be an instance of Visualizer, ' \ - f'but got {type(visualizer)}' - - @property - def visualizer(self) -> 'Visualizer': - """Return the visualizer object. - - You can get the drawing backend through the visualizer property for - custom drawing. - """ - return self._visualizer # type: ignore - - @property - @abstractmethod - def experiment(self) -> Any: - """Return the experiment object associated with this writer. - - The experiment attribute can get the write backend, such as wandb, - tensorboard. If you want to write other data, such as writing a table, - you can directly get the write backend through experiment. - """ - pass - - def add_params(self, params_dict: dict, **kwargs) -> None: - """Record a set of parameters. - - Args: - params_dict (dict): Each key-value pair in the dictionary is the - name of the parameters and it's corresponding value. - """ - pass - - def add_graph(self, model: torch.nn.Module, - input_tensor: Union[torch.Tensor, - List[torch.Tensor]], **kwargs) -> None: - """Record graph. - - Args: - model (torch.nn.Module): Model to draw. - input_tensor (torch.Tensor, list[torch.Tensor]): A variable - or a tuple of variables to be fed. - """ - pass - - def add_image(self, - name: str, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default: True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - pass - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record scalar. - - Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. - """ - pass - - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record scalars' data. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. 
- file_path (str, optional): The scalar's data will be - saved to the `file_path` file at the same time - if the `file_path` parameter is specified. - Default to None. - """ - pass - - def close(self) -> None: - """close an opened object.""" - pass - - -@WRITERS.register_module() -class LocalWriter(BaseWriter): - """Local write class. - - It can write image, hyperparameters, scalars, etc. - to the local hard disk. You can get the drawing backend - through the visualizer property for custom drawing. - - Examples: - >>> from mmengine.visualization import LocalWriter - >>> import numpy as np - >>> local_writer = LocalWriter(dict(type='DetVisualizer'),\ - save_dir='temp_dir') - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> local_writer.add_image('img', img) - >>> local_writer.add_scaler('mAP', 0.6) - >>> local_writer.add_scalars({'loss': [1, 2, 3], 'acc': 0.8}) - >>> local_writer.add_params(dict(lr=0.1, mode='linear')) - - >>> local_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \ - edgecolors='g') - >>> local_writer.add_image('img', \ - local_writer.visualizer.get_image()) - - Args: - save_dir (str): The root directory to save the files - produced by the writer. - visualizer (dict, :obj:`Visualizer`, optional): Visualizer - instance or dictionary. Default to None - img_save_dir (str): The directory to save images. - Default to 'writer_image'. - params_save_file (str): The file to save parameters. - Default to 'parameters.yaml'. - scalar_save_file (str): The file to save scalar values. - Default to 'scalars.json'. - img_show (bool): Whether to show the image when calling add_image. - Default to False. - """ - - def __init__(self, - save_dir: str, - visualizer: Optional[Union[dict, 'Visualizer']] = None, - img_save_dir: str = 'writer_image', - params_save_file: str = 'parameters.yaml', - scalar_save_file: str = 'scalars.json', - img_show: bool = False): - assert params_save_file.split('.')[-1] == 'yaml' - assert scalar_save_file.split('.')[-1] == 'json' - super(LocalWriter, self).__init__(visualizer, save_dir) - os.makedirs(self._save_dir, exist_ok=True) # type: ignore - self._img_save_dir = osp.join( - self._save_dir, # type: ignore - img_save_dir) - self._scalar_save_file = osp.join( - self._save_dir, # type: ignore - scalar_save_file) - self._params_save_file = osp.join( - self._save_dir, # type: ignore - params_save_file) - self._img_show = img_show - - @property - def experiment(self) -> 'LocalWriter': - """Return the experiment object associated with this writer.""" - return self - - def add_params(self, params_dict: dict, **kwargs) -> None: - """Record parameters to disk. - - Args: - params_dict (dict): The dict of parameters to save. - """ - assert isinstance(params_dict, dict) - self._dump(params_dict, self._params_save_file, 'yaml') - - def add_image(self, - name: str, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image to disk. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. 
Default to True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - assert self.visualizer, 'Please instantiate the visualizer ' \ - 'object with initialization parameters.' - self.visualizer.draw(image, gt_sample, pred_sample, draw_gt, draw_pred) - if self._img_show: - self.visualizer.show() - else: - drawn_image = cv2.cvtColor(self.visualizer.get_image(), - cv2.COLOR_RGB2BGR) - os.makedirs(self._img_save_dir, exist_ok=True) - save_file_name = f'{name}_{step}.png' - cv2.imwrite( - osp.join(self._img_save_dir, save_file_name), drawn_image) - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Add scalar data to disk. - - Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. - """ - self._dump({name: value, 'step': step}, self._scalar_save_file, 'json') - - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record scalars. The scalar dict will be written to the default and - specified files if ``file_name`` is specified. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. - file_path (str, optional): The scalar's data will be - saved to the ``file_path`` file at the same time - if the ``file_path`` parameter is specified. - Default to None. - """ - assert isinstance(scalar_dict, dict) - scalar_dict.setdefault('step', step) - if file_path is not None: - assert file_path.split('.')[-1] == 'json' - new_save_file_path = osp.join( - self._save_dir, # type: ignore - file_path) - assert new_save_file_path != self._scalar_save_file, \ - '"file_path" and "scalar_save_file" have the same name, ' \ - 'please set "file_path" to another value' - self._dump(scalar_dict, new_save_file_path, 'json') - self._dump(scalar_dict, self._scalar_save_file, 'json') - - def _dump(self, value_dict: dict, file_path: str, - file_format: str) -> None: - """dump dict to file. - - Args: - value_dict (dict) : Save dict data. - file_path (str): The file path to save data. - file_format (str): The file format to save data. - """ - with open(file_path, 'a+') as f: - dump(value_dict, f, file_format=file_format) - f.write('\n') - - -@WRITERS.register_module() -class WandbWriter(BaseWriter): - """Write various types of data to wandb. - - Examples: - >>> from mmengine.visualization import WandbWriter - >>> import numpy as np - >>> wandb_writer = WandbWriter(dict(type='DetVisualizer')) - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> wandb_writer.add_image('img', img) - >>> wandb_writer.add_scaler('mAP', 0.6) - >>> wandb_writer.add_scalars({'loss': [1, 2, 3],'acc': 0.8}) - >>> wandb_writer.add_params(dict(lr=0.1, mode='linear')) - - >>> wandb_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \ - edgecolors='g') - >>> wandb_writer.add_image('img', \ - wandb_writer.visualizer.get_image()) - - >>> wandb_writer = WandbWriter() - >>> assert wandb_writer.visualizer is None - >>> wandb_writer.add_image('img', img) - - Args: - init_kwargs (dict, optional): wandb initialization - input parameters. Default to None. - commit: (bool, optional) Save the metrics dict to the wandb server - and increment the step. 
If false `wandb.log` just - updates the current metrics dict with the row argument - and metrics won't be saved until `wandb.log` is called - with `commit=True`. Default to True. - visualizer (dict, :obj:`Visualizer`, optional): - Visualizer instance or dictionary. Default to None. - save_dir (str, optional): The root directory to save the files - produced by the writer. Default to None. - """ - - def __init__(self, - init_kwargs: Optional[dict] = None, - commit: Optional[bool] = True, - visualizer: Optional[Union[dict, 'Visualizer']] = None, - save_dir: Optional[str] = None): - super(WandbWriter, self).__init__(visualizer, save_dir) - self._commit = commit - self._wandb = self._setup_env(init_kwargs) - - @property - def experiment(self): - """Return wandb object. - - The experiment attribute can get the wandb backend, If you want to - write other data, such as writing a table, you can directly get the - wandb backend through experiment. - """ - return self._wandb - - def _setup_env(self, init_kwargs: Optional[dict] = None) -> Any: - """Setup env. - - Args: - init_kwargs (dict): The init args. - - Return: - :obj:`wandb` - """ - try: - import wandb - except ImportError: - raise ImportError( - 'Please run "pip install wandb" to install wandb') - if init_kwargs: - wandb.init(**init_kwargs) - else: - wandb.init() - - return wandb - - def add_params(self, params_dict: dict, **kwargs) -> None: - """Record a set of parameters to be compared in wandb. - - Args: - params_dict (dict): Each key-value pair in the dictionary - is the name of the parameters and it's - corresponding value. - """ - assert isinstance(params_dict, dict) - self._wandb.log(params_dict, commit=self._commit) - - def add_image(self, - name: str, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image to wandb. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default: True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - if self.visualizer: - self.visualizer.draw(image, gt_sample, pred_sample, draw_gt, - draw_pred) - self._wandb.log({name: self.visualizer.get_image()}, - commit=self._commit, - step=step) - else: - self.add_image_to_wandb(name, image, gt_sample, pred_sample, - draw_gt, draw_pred, step, **kwargs) - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record scalar data to wandb. - - Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. - """ - self._wandb.log({name: value}, commit=self._commit, step=step) - - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record scalar's data to wandb. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. 
- step (int): Global step value to record. Default to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Default to None. - """ - self._wandb.log(scalar_dict, commit=self._commit, step=step) - - def add_image_to_wandb(self, - name: str, - image: np.ndarray, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image to wandb. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray): The image to be saved. The format - should be BGR. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default to True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - raise NotImplementedError() - - def close(self) -> None: - """close an opened wandb object.""" - if hasattr(self, '_wandb'): - self._wandb.join() - - -@WRITERS.register_module() -class TensorboardWriter(BaseWriter): - """Tensorboard write class. It can write images, hyperparameters, scalars, - etc. to a tensorboard file. - - Its drawing function is provided by Visualizer. - - Examples: - >>> from mmengine.visualization import TensorboardWriter - >>> import numpy as np - >>> tensorboard_writer = TensorboardWriter(dict(type='DetVisualizer'),\ - save_dir='temp_dir') - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> tensorboard_writer.add_image('img', img) - >>> tensorboard_writer.add_scaler('mAP', 0.6) - >>> tensorboard_writer.add_scalars({'loss': 0.1,'acc':0.8}) - >>> tensorboard_writer.add_params(dict(lr=0.1, mode='linear')) - - >>> tensorboard_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \ - edgecolors='g') - >>> tensorboard_writer.add_image('img', \ - tensorboard_writer.visualizer.get_image()) - - Args: - save_dir (str): The root directory to save the files - produced by the writer. - visualizer (dict, :obj:`Visualizer`, optional): Visualizer instance - or dictionary. Default to None. - log_dir (str): Save directory location. Default to 'tf_writer'. - """ - - def __init__(self, - save_dir: str, - visualizer: Optional[Union[dict, 'Visualizer']] = None, - log_dir: str = 'tf_logs'): - super(TensorboardWriter, self).__init__(visualizer, save_dir) - self._tensorboard = self._setup_env(log_dir) - - def _setup_env(self, log_dir: str): - """Setup env. - - Args: - log_dir (str): Save directory location. Default 'tf_writer'. 
- - Return: - :obj:`SummaryWriter` - """ - if TORCH_VERSION == 'parrots': - try: - from tensorboardX import SummaryWriter - except ImportError: - raise ImportError('Please install tensorboardX to use ' - 'TensorboardLoggerHook.') - else: - try: - from torch.utils.tensorboard import SummaryWriter - except ImportError: - raise ImportError( - 'Please run "pip install future tensorboard" to install ' - 'the dependencies to use torch.utils.tensorboard ' - '(applicable to PyTorch 1.1 or higher)') - - self.log_dir = osp.join(self._save_dir, log_dir) # type: ignore - return SummaryWriter(self.log_dir) - - @property - def experiment(self): - """Return Tensorboard object.""" - return self._tensorboard - - def add_graph(self, model: torch.nn.Module, - input_tensor: Union[torch.Tensor, - List[torch.Tensor]], **kwargs) -> None: - """Record graph data to tensorboard. - - Args: - model (torch.nn.Module): Model to draw. - input_tensor (torch.Tensor, list[torch.Tensor]): A variable - or a tuple of variables to be fed. - """ - if isinstance(input_tensor, list): - for array in input_tensor: - assert array.ndim == 4 - assert isinstance(array, torch.Tensor) - else: - assert isinstance(input_tensor, - torch.Tensor) and input_tensor.ndim == 4 - self._tensorboard.add_graph(model, input_tensor) - - def add_params(self, params_dict: dict, **kwargs) -> None: - """Record a set of hyperparameters to be compared in TensorBoard. - - Args: - params_dict (dict): Each key-value pair in the dictionary is the - name of the hyper parameter and it's corresponding value. - The type of the value can be one of `bool`, `string`, - `float`, `int`, or `None`. - """ - assert isinstance(params_dict, dict) - self._tensorboard.add_hparams(params_dict, {}) - - def add_image(self, - name: str, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image to tensorboard. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default to True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - assert self.visualizer, 'Please instantiate the visualizer ' \ - 'object with initialization parameters.' - self.visualizer.draw(image, gt_sample, pred_sample, draw_gt, draw_pred) - self._tensorboard.add_image( - name, self.visualizer.get_image(), step, dataformats='HWC') - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record scalar data to summary. - - Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. - """ - self._tensorboard.add_scalar(name, value, step) - - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record scalar's data to summary. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. 
- step (int): Global step value to record. Default to 0. - file_path (str, optional): Useless parameter. Just for - interface unification. Default to None. - """ - assert isinstance(scalar_dict, dict) - assert 'step' not in scalar_dict, 'Please set it directly ' \ - 'through the step parameter' - for key, value in scalar_dict.items(): - self.add_scalar(key, value, step) - - def close(self): - """close an opened tensorboard object.""" - if hasattr(self, '_tensorboard'): - self._tensorboard.close() - - -class ComposedWriter(ManagerMixin): - """Wrapper class to compose multiple a subclass of :class:`BaseWriter` - instances. By inheriting ManagerMixin, it can be accessed anywhere once - instantiated. - - Examples: - >>> from mmengine.visualization import ComposedWriter - >>> import numpy as np - >>> composed_writer= ComposedWriter.get_instance( \ - 'composed_writer', writers=[dict(type='LocalWriter', \ - visualizer=dict(type='DetVisualizer'), \ - save_dir='temp_dir'), dict(type='WandbWriter')]) - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> composed_writer.add_image('img', img) - >>> composed_writer.add_scalar('mAP', 0.6) - >>> composed_writer.add_scalars({'loss': 0.1,'acc':0.8}) - >>> composed_writer.add_params(dict(lr=0.1, mode='linear')) - - Args: - name (str): The name of the instance. Defaults: 'composed_writer'. - writers (list, optional): The writers to compose. Default to None - """ - - def __init__(self, - name: str = 'composed_writer', - writers: Optional[List[Union[dict, 'BaseWriter']]] = None): - super().__init__(name) - self._writers = [] - if writers is not None: - assert isinstance(writers, list) - for writer in writers: - if isinstance(writer, dict): - self._writers.append(WRITERS.build(writer)) - else: - assert isinstance(writer, BaseWriter), \ - f'writer should be an instance of a subclass of ' \ - f'BaseWriter, but got {type(writer)}' - self._writers.append(writer) - - def __len__(self): - return len(self._writers) - - def get_writer(self, index: int) -> 'BaseWriter': - """Returns the writer object corresponding to the specified index.""" - return self._writers[index] - - def get_experiment(self, index: int) -> Any: - """Returns the writer's experiment object corresponding to the - specified index.""" - return self._writers[index].experiment - - def get_visualizer(self, index: int) -> 'Visualizer': - """Returns the writer's visualizer object corresponding to the - specified index.""" - return self._writers[index].visualizer - - def add_params(self, params_dict: dict, **kwargs): - """Record parameters. - - Args: - params_dict (dict): The dictionary of parameters to save. - """ - for writer in self._writers: - writer.add_params(params_dict, **kwargs) - - def add_graph(self, model: torch.nn.Module, - input_array: Union[torch.Tensor, - List[torch.Tensor]], **kwargs) -> None: - """Record graph data. - - Args: - model (torch.nn.Module): Model to draw. - input_array (torch.Tensor, list[torch.Tensor]): A variable - or a tuple of variables to be fed. - """ - for writer in self._writers: - writer.add_graph(model, input_array, **kwargs) - - def add_image(self, - name: str, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True, - step: int = 0, - **kwargs) -> None: - """Record image. - - Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. 
Default to None. - gt_sample (:obj:`BaseDataElement`, optional): The ground truth data - structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataElement`, optional): The predicted - result data structure of OpenMMlab. Default to None. - draw_gt (bool): Whether to draw the ground truth. Default to True. - draw_pred (bool): Whether to draw the predicted result. - Default to True. - step (int): Global step value to record. Default to 0. - """ - for writer in self._writers: - writer.add_image(name, image, gt_sample, pred_sample, draw_gt, - draw_pred, step, **kwargs) - - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: - """Record scalar data. - - Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. - """ - for writer in self._writers: - writer.add_scalar(name, value, step, **kwargs) - - def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, - **kwargs) -> None: - """Record scalars' data. - - Args: - scalar_dict (dict): Key-value pair storing the tag and - corresponding values. - step (int): Global step value to record. Default to 0. - file_path (str, optional): The scalar's data will be - saved to the `file_path` file at the same time - if the `file_path` parameter is specified. - Default to None. - """ - for writer in self._writers: - writer.add_scalars(scalar_dict, step, file_path, **kwargs) - - def close(self) -> None: - """close an opened object.""" - for writer in self._writers: - writer.close() diff --git a/tests/test_data/test_instance_data.py b/tests/test_data/test_instance_data.py new file mode 100644 index 0000000000000000000000000000000000000000..17fc2bcaa9b7b391a97d75b406b4cecf3715b784 --- /dev/null +++ b/tests/test_data/test_instance_data.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import random +from unittest import TestCase + +import numpy as np +import pytest +import torch + +from mmengine.data import BaseDataElement, InstanceData + + +class TestInstanceData(TestCase): + + def setup_data(self): + metainfo = dict( + img_id=random.randint(0, 100), + img_shape=(random.randint(400, 600), random.randint(400, 600))) + instances_infos = [1] * 5 + bboxes = torch.rand((5, 4)) + labels = np.random.rand(5) + instance_data = InstanceData( + metainfo=metainfo, + bboxes=bboxes, + labels=labels, + instances_infos=instances_infos) + return instance_data + + def test_set_data(self): + instance_data = self.setup_data() + + # test set '_metainfo_fields' or '_data_fields' + with self.assertRaises(AttributeError): + instance_data._metainfo_fields = 1 + with self.assertRaises(AttributeError): + instance_data._data_fields = 1 + + # value only supports (torch.Tensor, np.ndarray, list) + with self.assertRaises(AssertionError): + instance_data.v = 'value' + + # The data length in InstanceData must be the same + with self.assertRaises(AssertionError): + instance_data.keypoints = torch.rand((17, 2)) + + instance_data.keypoints = torch.rand((5, 2)) + assert 'keypoints' in instance_data + + def test_getitem(self): + instance_data = InstanceData() + # length must be greater than 0 + with self.assertRaises(AssertionError): + instance_data[1] + + instance_data = self.setup_data() + assert len(instance_data) == 5 + slice_instance_data = instance_data[:2] + assert len(slice_instance_data) == 2 + + # assert the index should in 0 ~ len(instance_data) -1 + with pytest.raises(IndexError): + instance_data[5] + + # isinstance(str, slice, int, torch.LongTensor, torch.BoolTensor) + item = torch.Tensor([1, 2, 3, 4]) # float + with pytest.raises(AssertionError): + instance_data[item] + + # when input is a bool tensor, The shape of + # the input at index 0 should equal to + # the value length in instance_data_field + with pytest.raises(AssertionError): + instance_data[item.bool()] + + # test Longtensor + long_tensor = torch.randint(5, (2, )) + long_index_instance_data = instance_data[long_tensor] + assert len(long_index_instance_data) == len(long_tensor) + + # test bool tensor + bool_tensor = torch.rand(5) > 0.5 + bool_index_instance_data = instance_data[bool_tensor] + assert len(bool_index_instance_data) == bool_tensor.sum() + + def test_cat(self): + instance_data_1 = self.setup_data() + instance_data_2 = self.setup_data() + cat_instance_data = InstanceData.cat( + [instance_data_1, instance_data_2]) + assert len(cat_instance_data) == 10 + + # All inputs must be InstanceData + instance_data_2 = BaseDataElement( + bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))) + with self.assertRaises(AssertionError): + InstanceData.cat([instance_data_1, instance_data_2]) + + # Input List length must be greater than 0 + with self.assertRaises(AssertionError): + InstanceData.cat([]) + + def test_len(self): + instance_data = self.setup_data() + assert len(instance_data) == 5 + instance_data = InstanceData() + assert len(instance_data) == 0 diff --git a/tests/test_evaluator/test_evaluator.py b/tests/test_evaluator/test_evaluator.py index 61be034be2d5d3934f3d23eb60fd7eaa0eec143b..5364c0679997d8dfba7630056f3b8ab93e955cd4 100644 --- a/tests/test_evaluator/test_evaluator.py +++ b/tests/test_evaluator/test_evaluator.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import math -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence from unittest import TestCase import numpy as np @@ -40,7 +40,7 @@ class ToyMetric(BaseMetric): def process(self, data_batch, predictions): results = [{ 'pred': pred.get('pred'), - 'label': data[1].get('label') + 'label': data['data_sample'].get('label') } for pred, data in zip(predictions, data_batch)] self.results.extend(results) @@ -66,7 +66,7 @@ class NonPrefixedMetric(BaseMetric): """Evaluator with unassigned `default_prefix` to test the warning information.""" - def process(self, data_batch: Sequence[Tuple[Any, dict]], + def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]) -> None: pass @@ -79,8 +79,11 @@ def generate_test_results(size, batch_size, pred, label): bs_residual = size % batch_size for i in range(num_batch): bs = bs_residual if i == num_batch - 1 else batch_size - data_batch = [(np.zeros((3, 10, 10)), BaseDataElement(label=label)) - for _ in range(bs)] + data_batch = [ + dict( + inputs=np.zeros((3, 10, 10)), + data_sample=BaseDataElement(label=label)) for _ in range(bs) + ] predictions = [BaseDataElement(pred=pred) for _ in range(bs)] yield (data_batch, predictions) @@ -228,7 +231,10 @@ class TestEvaluator(TestCase): size = 10 - all_data = [(np.zeros((3, 10, 10)), BaseDataElement(label=1)) - for _ in range(size)] + all_data = [ + dict( + inputs=np.zeros((3, 10, 10)), + data_sample=BaseDataElement(label=1)) for _ in range(size) + ] all_predictions = [BaseDataElement(pred=0) for _ in range(size)] evaluator.offline_evaluate(all_data, all_predictions) diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py index db80ed4a6301b469128ec07168655d5ca83a7dc3..771c54f615079adfffccfd7ea058aa90c0d6ad74 100644 --- a/tests/test_hook/test_hook.py +++ b/tests/test_hook/test_hook.py @@ -157,18 +157,17 @@ class TestHook: def test_end_of_epoch(self): hook = Hook() - runner = Mock() # last inner iter batch_idx = 1 - runner.cur_dataloader.__len__ = Mock(return_value=2) - runner.cur_dataloader.__len__ = Mock(return_value=2) - return_val = hook.end_of_epoch(runner, batch_idx) + dataloader = Mock() + dataloader.__len__ = Mock(return_value=2) + return_val = hook.end_of_epoch(dataloader, batch_idx) assert return_val # not the last inner iter batch_idx = 0 - return_val = hook.end_of_epoch(runner, batch_idx) + return_val = hook.end_of_epoch(dataloader, batch_idx) assert not return_val def test_is_last_train_epoch(self): diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index cac2e45b170cf70e03de5acf0d2356136bdf1603..b2b617991288135ee89cd8f6d72336d67fab9825 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -61,7 +61,6 @@ class TestLoggerHook: assert logger_hook.json_log_path == osp.join('work_dir', 'timestamp.log.json') assert logger_hook.start_iter == runner.iter - runner.writer.add_params.assert_called() def test_after_run(self, tmp_path): out_dir = tmp_path / 'out_dir' @@ -111,7 +110,7 @@ class TestLoggerHook: # Test end of the epoch. 
logger_hook = LoggerHook(by_epoch=True, ignore_last=False) logger_hook._log_train = MagicMock() - runner.cur_dataloader = [0] * 5 + runner.train_loop.dataloader = [0] * 5 batch_idx = 4 logger_hook.after_train_iter(runner, batch_idx=batch_idx) logger_hook._log_train.assert_called() @@ -151,7 +150,7 @@ class TestLoggerHook: logger_hook._collect_info = MagicMock(return_value=train_infos) logger_hook._log_train(runner) # Verify that the correct variables have been written. - runner.writer.add_scalars.assert_called_with( + runner.visualizer.add_scalars.assert_called_with( train_infos, step=11, file_path='tmp.json') # Verify that the correct context have been logged. out, _ = capsys.readouterr() @@ -161,7 +160,7 @@ class TestLoggerHook: eta_str = str(datetime.timedelta(seconds=int(eta_second))) if by_epoch: if torch.cuda.is_available(): - log_str = 'Epoch [2][2/5]\t' \ + log_str = 'Epoch [2][2/5] ' \ f"lr: {train_infos['lr']:.3e} " \ f"momentum: {train_infos['momentum']:.3e}, " \ f'eta: {eta_str}, ' \ @@ -170,7 +169,7 @@ class TestLoggerHook: f'memory: 100, ' \ f"loss_cls: {train_infos['loss_cls']:.4f}\n" else: - log_str = 'Epoch [2][2/5]\t' \ + log_str = 'Epoch [2][2/5] ' \ f"lr: {train_infos['lr']:.3e} " \ f"momentum: {train_infos['momentum']:.3e}, " \ f'eta: {eta_str}, ' \ @@ -180,7 +179,7 @@ class TestLoggerHook: assert out == log_str else: if torch.cuda.is_available(): - log_str = 'Iter [11/50]\t' \ + log_str = 'Iter [11/50] ' \ f"lr: {train_infos['lr']:.3e} " \ f"momentum: {train_infos['momentum']:.3e}, " \ f'eta: {eta_str}, ' \ @@ -189,7 +188,7 @@ class TestLoggerHook: f'memory: 100, ' \ f"loss_cls: {train_infos['loss_cls']:.4f}\n" else: - log_str = 'Iter [11/50]\t' \ + log_str = 'Iter [11/50] ' \ f"lr: {train_infos['lr']:.3e} " \ f"momentum: {train_infos['momentum']:.3e}, " \ f'eta: {eta_str}, ' \ @@ -209,14 +208,14 @@ class TestLoggerHook: logger_hook._log_val(runner) # Verify that the correct context have been logged. 
out, _ = capsys.readouterr() - runner.writer.add_scalars.assert_called_with( + runner.visualizer.add_scalars.assert_called_with( metric, step=11, file_path='tmp.json') if by_epoch: - assert out == 'Epoch(val) [1][5]\taccuracy: 0.9000, ' \ + assert out == 'Epoch(val) [1][5] accuracy: 0.9000, ' \ 'data_time: 1.0000\n' else: - assert out == 'Iter(val) [5]\taccuracy: 0.9000, ' \ + assert out == 'Iter(val) [5] accuracy: 0.9000, ' \ 'data_time: 1.0000\n' def test_get_window_size(self): @@ -341,7 +340,9 @@ class TestLoggerHook: def _setup_runner(self): runner = MagicMock() runner.epoch = 1 - runner.cur_dataloader = [0] * 5 + runner.train_loop.dataloader = [0] * 5 + runner.val_loop.dataloader = [0] * 5 + runner.test_loop.dataloader = [0] * 5 runner.iter = 10 runner.train_loop.max_iters = 50 logger = logging.getLogger() diff --git a/tests/test_hook/test_naive_visualization_hook.py b/tests/test_hook/test_naive_visualization_hook.py index 4d75fedb1588b865071a88846ee7bbfae880ab88..e06dd281ba23e07c835d5ea620c1fc45b9218326 100644 --- a/tests/test_hook/test_naive_visualization_hook.py +++ b/tests/test_hook/test_naive_visualization_hook.py @@ -12,74 +12,60 @@ class TestNaiveVisualizationHook: def test_after_train_iter(self): naive_visualization_hook = NaiveVisualizationHook() runner = Mock(iter=1) - runner.writer.add_image = Mock() + runner.visualizer.add_image = Mock() inputs = torch.randn(1, 3, 15, 15) batch_idx = 10 # test with normalize, resize, pad - gt_datasamples = [ - BaseDataElement( - metainfo=dict( - img_norm_cfg=dict( - mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), - scale=(10, 10), - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')) - ] + gt_datasamples = BaseDataElement( + metainfo=dict( + img_norm_cfg=dict( + mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), + scale=(10, 10), + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] - data_batch = (inputs, gt_datasamples) + data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with resize, pad - gt_datasamples = [ - BaseDataElement( - metainfo=dict( - scale=(10, 10), - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')), - ] + gt_datasamples = BaseDataElement( + metainfo=dict( + scale=(10, 10), + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] - data_batch = (inputs, gt_datasamples) + data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only resize - gt_datasamples = [ - BaseDataElement( - metainfo=dict( - scale=(15, 15), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')), - ] + gt_datasamples = BaseDataElement( + metainfo=dict( + scale=(15, 15), ori_height=5, ori_width=5, img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] - data_batch = (inputs, gt_datasamples) + data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only pad - gt_datasamples = [ - BaseDataElement( - metainfo=dict( - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')), - ] + gt_datasamples = BaseDataElement( + metainfo=dict( + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = 
[BaseDataElement()] - data_batch = (inputs, gt_datasamples) + data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test no transform - gt_datasamples = [ - BaseDataElement( - metainfo=dict(ori_height=15, ori_width=15, - img_path='tmp.jpg')), - ] + gt_datasamples = BaseDataElement( + metainfo=dict(ori_height=15, ori_width=15, img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] - data_batch = (inputs, gt_datasamples) + data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) diff --git a/tests/test_hook/test_sampler_seed_hook.py b/tests/test_hook/test_sampler_seed_hook.py index 9d19edf713cafc665a8b75a8a800b3ecbd6a4f68..c1bf8b543aa1a7e5bf8216a32b55692b9b4d9bb8 100644 --- a/tests/test_hook/test_sampler_seed_hook.py +++ b/tests/test_hook/test_sampler_seed_hook.py @@ -12,17 +12,18 @@ class TestDistSamplerSeedHook: # Test dataset sampler runner = Mock() runner.epoch = 1 - runner.cur_dataloader = Mock() - runner.cur_dataloader.sampler = Mock() - runner.cur_dataloader.sampler.set_epoch = Mock() + runner.train_loop.dataloader = Mock() + runner.train_loop.dataloader.sampler = Mock() + runner.train_loop.dataloader.sampler.set_epoch = Mock() hook.before_train_epoch(runner) - runner.cur_dataloader.sampler.set_epoch.assert_called() + runner.train_loop.dataloader.sampler.set_epoch.assert_called() # Test batch sampler runner = Mock() - runner.cur_dataloader = Mock() - runner.cur_dataloader.sampler = Mock(spec_set=True) - runner.cur_dataloader.batch_sampler = Mock() - runner.cur_dataloader.batch_sampler.sampler = Mock() - runner.cur_dataloader.batch_sampler.sampler.set_epoch = Mock() + runner.train_loop.dataloader = Mock() + runner.train_loop.dataloader.sampler = Mock(spec_set=True) + runner.train_loop.dataloader.batch_sampler = Mock() + runner.train_loop.dataloader.batch_sampler.sampler = Mock() + runner.train_loop.dataloader.batch_sampler.sampler.set_epoch = Mock() hook.before_train_epoch(runner) - runner.cur_dataloader.batch_sampler.sampler.set_epoch.assert_called() + runner.train_loop.dataloader.\ + batch_sampler.sampler.set_epoch.assert_called() diff --git a/tests/test_model/test_wrappers/test_data_parallel.py b/tests/test_model/test_wrappers/test_data_parallel.py index fa3e9993436760f8d576e34989d95bb40b20487f..34b0cad254c57254016e74fab8507a30cf9c382f 100644 --- a/tests/test_model/test_wrappers/test_data_parallel.py +++ b/tests/test_model/test_wrappers/test_data_parallel.py @@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch import pytest import torch import torch.nn as nn +from torch.nn.parallel import DataParallel +from torch.nn.parallel.distributed import DistributedDataParallel from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel, is_model_wrapper) @@ -44,6 +46,12 @@ def test_is_model_wrapper(): mmddp = MMDistributedDataParallel(model, process_group=MagicMock()) assert is_model_wrapper(mmddp) + torch_dp = DataParallel(model) + assert is_model_wrapper(torch_dp) + + torch_ddp = DistributedDataParallel(model, process_group=MagicMock()) + assert is_model_wrapper(torch_ddp) + # test model wrapper registry @MODEL_WRAPPERS.register_module() class ModelWrapper(object): diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 76f1d7ce9b2b926f65c32b65eff471619f8a6c5f..cc7d41563ad95765e071d4e17df3447adf9bb105 100644 --- 
a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -5,6 +5,7 @@ import pytest from mmengine.config import Config, ConfigDict # type: ignore from mmengine.registry import DefaultScope, Registry, build_from_cfg +from mmengine.utils import ManagerMixin class TestRegistry: @@ -482,3 +483,17 @@ def test_build_from_cfg(cfg_type): "<class 'str'>")): cfg = cfg_type(dict(type='ResNet', depth=50)) model = build_from_cfg(cfg, 'BACKBONES') + + VISUALIZER = Registry('visualizer') + + @VISUALIZER.register_module() + class Visualizer(ManagerMixin): + + def __init__(self, name): + super().__init__(name) + + with pytest.raises(RuntimeError): + Visualizer.get_current_instance() + cfg = dict(type='Visualizer', name='visualizer') + build_from_cfg(cfg, VISUALIZER) + Visualizer.get_current_instance() diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 7b497c8e54eb6c503f9e5efa12adc8b3dc47c18a..29e7ee3658a7678489c9805eeac400204c6cbc62 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -14,8 +14,8 @@ from torch.utils.data import DataLoader, Dataset from mmengine.config import Config from mmengine.data import DefaultSampler from mmengine.evaluator import BaseMetric, Evaluator -from mmengine.hooks import (Hook, IterTimerHook, LoggerHook, OptimizerHook, - ParamSchedulerHook) +from mmengine.hooks import (DistSamplerSeedHook, Hook, IterTimerHook, + LoggerHook, OptimizerHook, ParamSchedulerHook) from mmengine.hooks.checkpoint_hook import CheckpointHook from mmengine.logging import MessageHub, MMLogger from mmengine.optim.scheduler import MultiStepLR, StepLR @@ -25,7 +25,7 @@ from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop, Runner, TestLoop, ValLoop) from mmengine.runner.priority import Priority, get_priority from mmengine.utils import is_list_of -from mmengine.visualization.writer import ComposedWriter +from mmengine.visualization import Visualizer @MODELS.register_module() @@ -36,7 +36,8 @@ class ToyModel(nn.Module): self.linear = nn.Linear(2, 1) def forward(self, data_batch, return_loss=False): - inputs, labels = zip(*data_batch) + inputs, labels = zip( + *map(lambda x: (x['inputs'], x['data_sample']), data_batch)) device = 'cuda:0' if torch.cuda.is_available() else 'cpu' inputs = torch.stack(inputs).to(device) labels = torch.stack(labels).to(device) @@ -67,7 +68,7 @@ class CustomModelWrapper(nn.Module): @DATASETS.register_module() class ToyDataset(Dataset): - META = dict() # type: ignore + METAINFO = dict() # type: ignore data = torch.randn(12, 2) label = torch.ones(12) @@ -75,7 +76,7 @@ class ToyDataset(Dataset): return self.data.size(0) def __getitem__(self, index): - return self.data[index], self.label[index] + return dict(inputs=self.data[index], data_sample=self.label[index]) @METRICS.register_module() @@ -307,24 +308,24 @@ class TestRunner(TestCase): self.assertFalse(runner.distributed) self.assertFalse(runner.deterministic) - # 1.5 message_hub, logger and writer + # 1.5 message_hub, logger and visualizer # they are all not specified cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_init12' runner = Runner(**cfg) self.assertIsInstance(runner.logger, MMLogger) self.assertIsInstance(runner.message_hub, MessageHub) - self.assertIsInstance(runner.writer, ComposedWriter) + self.assertIsInstance(runner.visualizer, Visualizer) # they are all specified cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_init13' cfg.log_level = 'INFO' - cfg.writer 
= dict(name='test_writer')
+        cfg.visualizer = None
         runner = Runner(**cfg)
         self.assertIsInstance(runner.logger, MMLogger)
         self.assertIsInstance(runner.message_hub, MessageHub)
-        self.assertIsInstance(runner.writer, ComposedWriter)
+        self.assertIsInstance(runner.visualizer, Visualizer)

         assert runner.distributed is False
         assert runner.seed is not None
@@ -445,32 +446,34 @@ class TestRunner(TestCase):
         with self.assertRaisesRegex(TypeError, 'message_hub should be'):
             runner.build_message_hub('invalid-type')

-    def test_build_writer(self):
-        self.epoch_based_cfg.experiment_name = 'test_build_writer1'
+    def test_build_visualizer(self):
+        self.epoch_based_cfg.experiment_name = 'test_build_visualizer1'
         runner = Runner.from_cfg(self.epoch_based_cfg)
-        self.assertIsInstance(runner.writer, ComposedWriter)
-        self.assertEqual(runner.experiment_name, runner.writer.instance_name)
+        self.assertIsInstance(runner.visualizer, Visualizer)
+        self.assertEqual(runner.experiment_name,
+                         runner.visualizer.instance_name)

-        # input is a ComposedWriter object
+        # input is a Visualizer object
         self.assertEqual(
-            id(runner.build_writer(runner.writer)), id(runner.writer))
+            id(runner.build_visualizer(runner.visualizer)),
+            id(runner.visualizer))

         # input is a dict
-        writer_cfg = dict(name='test_build_writer2')
-        writer = runner.build_writer(writer_cfg)
-        self.assertIsInstance(writer, ComposedWriter)
-        self.assertEqual(writer.instance_name, 'test_build_writer2')
+        visualizer_cfg = dict(type='Visualizer', name='test_build_visualizer2')
+        visualizer = runner.build_visualizer(visualizer_cfg)
+        self.assertIsInstance(visualizer, Visualizer)
+        self.assertEqual(visualizer.instance_name, 'test_build_visualizer2')

         # input is a dict but does not contain name key
-        runner._experiment_name = 'test_build_writer3'
-        writer_cfg = dict()
-        writer = runner.build_writer(writer_cfg)
-        self.assertIsInstance(writer, ComposedWriter)
-        self.assertEqual(writer.instance_name, 'test_build_writer3')
+        runner._experiment_name = 'test_build_visualizer3'
+        visualizer_cfg = None
+        visualizer = runner.build_visualizer(visualizer_cfg)
+        self.assertIsInstance(visualizer, Visualizer)
+        self.assertEqual(visualizer.instance_name, 'test_build_visualizer3')

         # input is not a valid type
-        with self.assertRaisesRegex(TypeError, 'writer should be'):
-            runner.build_writer('invalid-type')
+        with self.assertRaisesRegex(TypeError, 'visualizer should be'):
+            runner.build_visualizer('invalid-type')

     def test_default_scope(self):
         TOY_SCHEDULERS = Registry(
@@ -511,6 +514,25 @@ class TestRunner(TestCase):
         model = runner.build_model(dict(type='ToyModel1'))
         self.assertIsInstance(model, ToyModel1)

+        # test init weights
+        @MODELS.register_module()
+        class ToyModel2(ToyModel):
+
+            def __init__(self):
+                super().__init__()
+                self.initialized = False
+
+            def init_weights(self):
+                self.initialized = True
+
+        model = runner.build_model(dict(type='ToyModel2'))
+        self.assertTrue(model.initialized)
+
+        # test init weights with model object
+        _model = ToyModel2()
+        model = runner.build_model(_model)
+        self.assertFalse(model.initialized)
+
     def test_wrap_model(self):
         # TODO: test on distributed environment
         # custom model wrapper
@@ -897,33 +919,35 @@ class TestRunner(TestCase):

         # register five hooks by default
         runner.register_default_hooks()
-        self.assertEqual(len(runner._hooks), 5)
-        # the forth registered hook should be `ParamSchedulerHook`
-        self.assertTrue(isinstance(runner._hooks[3], ParamSchedulerHook))
+        self.assertEqual(len(runner._hooks), 6)
+        # the third registered hook should be 
`DistSamplerSeedHook` + self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook)) + # the fifth registered hook should be `ParamSchedulerHook` + self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook)) runner._hooks = [] # remove `ParamSchedulerHook` from default hooks runner.register_default_hooks(hooks=dict(timer=None)) - self.assertEqual(len(runner._hooks), 4) + self.assertEqual(len(runner._hooks), 5) # `ParamSchedulerHook` was popped so the forth is `CheckpointHook` - self.assertTrue(isinstance(runner._hooks[3], CheckpointHook)) + self.assertTrue(isinstance(runner._hooks[4], CheckpointHook)) # add a new default hook runner._hooks = [] runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook'))) - self.assertEqual(len(runner._hooks), 6) - self.assertTrue(isinstance(runner._hooks[5], ToyHook)) + self.assertEqual(len(runner._hooks), 7) + self.assertTrue(isinstance(runner._hooks[6], ToyHook)) def test_custom_hooks(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_custom_hooks' runner = Runner.from_cfg(cfg) - self.assertEqual(len(runner._hooks), 5) + self.assertEqual(len(runner._hooks), 6) custom_hooks = [dict(type='ToyHook')] runner.register_custom_hooks(custom_hooks) - self.assertEqual(len(runner._hooks), 6) - self.assertTrue(isinstance(runner._hooks[5], ToyHook)) + self.assertEqual(len(runner._hooks), 7) + self.assertTrue(isinstance(runner._hooks[6], ToyHook)) def test_register_hooks(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -933,9 +957,9 @@ class TestRunner(TestCase): runner._hooks = [] custom_hooks = [dict(type='ToyHook')] runner.register_hooks(custom_hooks=custom_hooks) - # five default hooks + custom hook (ToyHook) - self.assertEqual(len(runner._hooks), 6) - self.assertTrue(isinstance(runner._hooks[5], ToyHook)) + # six default hooks + custom hook (ToyHook) + self.assertEqual(len(runner._hooks), 7) + self.assertTrue(isinstance(runner._hooks[6], ToyHook)) def test_custom_loop(self): # test custom loop with additional hook @@ -1057,9 +1081,34 @@ class TestRunner(TestCase): self.assertIsInstance(runner.optimizer, SGD) self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) - # 2. test iter based + # 1.4 test auto resume cfg = copy.deepcopy(self.iter_based_cfg) cfg.experiment_name = 'test_checkpoint4' + cfg.resume = True + runner = Runner.from_cfg(cfg) + runner.load_or_resume() + self.assertEqual(runner.epoch, 3) + self.assertEqual(runner.iter, 12) + self.assertTrue(runner._has_loaded) + self.assertIsInstance(runner.optimizer, SGD) + self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) + + # 1.5 test resume from a specified checkpoint + cfg = copy.deepcopy(self.iter_based_cfg) + cfg.experiment_name = 'test_checkpoint5' + cfg.resume = True + cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth') + runner = Runner.from_cfg(cfg) + runner.load_or_resume() + self.assertEqual(runner.epoch, 1) + self.assertEqual(runner.iter, 4) + self.assertTrue(runner._has_loaded) + self.assertIsInstance(runner.optimizer, SGD) + self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) + + # 2. 
test iter based + cfg = copy.deepcopy(self.iter_based_cfg) + cfg.experiment_name = 'test_checkpoint6' runner = Runner.from_cfg(cfg) runner.train() @@ -1078,7 +1127,7 @@ class TestRunner(TestCase): # 2.2 test `load_checkpoint` cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_checkpoint5' + cfg.experiment_name = 'test_checkpoint7' runner = Runner.from_cfg(cfg) runner.load_checkpoint(path) self.assertEqual(runner.epoch, 0) @@ -1087,7 +1136,7 @@ class TestRunner(TestCase): # 2.3 test `resume` cfg = copy.deepcopy(self.iter_based_cfg) - cfg.experiment_name = 'test_checkpoint6' + cfg.experiment_name = 'test_checkpoint8' runner = Runner.from_cfg(cfg) runner.resume(path) self.assertEqual(runner.epoch, 0) @@ -1095,3 +1144,28 @@ class TestRunner(TestCase): self.assertTrue(runner._has_loaded) self.assertIsInstance(runner.optimizer, SGD) self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) + + # 2.4 test auto resume + cfg = copy.deepcopy(self.iter_based_cfg) + cfg.experiment_name = 'test_checkpoint9' + cfg.resume = True + runner = Runner.from_cfg(cfg) + runner.load_or_resume() + self.assertEqual(runner.epoch, 0) + self.assertEqual(runner.iter, 12) + self.assertTrue(runner._has_loaded) + self.assertIsInstance(runner.optimizer, SGD) + self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) + + # 2.5 test resume from a specified checkpoint + cfg = copy.deepcopy(self.iter_based_cfg) + cfg.experiment_name = 'test_checkpoint10' + cfg.resume = True + cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth') + runner = Runner.from_cfg(cfg) + runner.load_or_resume() + self.assertEqual(runner.epoch, 0) + self.assertEqual(runner.iter, 3) + self.assertTrue(runner._has_loaded) + self.assertIsInstance(runner.optimizer, SGD) + self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) diff --git a/tests/test_visualizer/test_vis_backend.py b/tests/test_visualizer/test_vis_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..da662a65d00e1c045110aed751534b420349749d --- /dev/null +++ b/tests/test_visualizer/test_vis_backend.py @@ -0,0 +1,200 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os +import shutil +import sys +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from mmengine.fileio import load +from mmengine.registry import VISBACKENDS +from mmengine.visualization import (LocalVisBackend, TensorboardVisBackend, + WandbVisBackend) + + +class TestLocalVisBackend: + + def test_init(self): + + # 'config_save_file' format must be py + with pytest.raises(AssertionError): + LocalVisBackend('temp_dir', config_save_file='a.txt') + + # 'scalar_save_file' format must be json + with pytest.raises(AssertionError): + LocalVisBackend('temp_dir', scalar_save_file='a.yaml') + + local_vis_backend = LocalVisBackend('temp_dir') + assert os.path.exists(local_vis_backend._save_dir) + shutil.rmtree('temp_dir') + + local_vis_backend = VISBACKENDS.build( + dict(type='LocalVisBackend', save_dir='temp_dir')) + assert os.path.exists(local_vis_backend._save_dir) + shutil.rmtree('temp_dir') + + def test_experiment(self): + local_vis_backend = LocalVisBackend('temp_dir') + assert local_vis_backend.experiment == local_vis_backend + shutil.rmtree('temp_dir') + + def test_add_config(self): + local_vis_backend = LocalVisBackend('temp_dir') + + # 'params_dict' must be dict + with pytest.raises(AssertionError): + local_vis_backend.add_config(['lr', 0]) + + # TODO + + shutil.rmtree('temp_dir') + + def test_add_image(self): + image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) + local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend.add_image('img', image) + assert os.path.exists( + os.path.join(local_vis_backend._img_save_dir, 'img_0.png')) + + local_vis_backend.add_image('img', image, step=2) + assert os.path.exists( + os.path.join(local_vis_backend._img_save_dir, 'img_2.png')) + + shutil.rmtree('temp_dir') + + def test_add_scalar(self): + local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend.add_scalar('map', 0.9) + out_dict = load(local_vis_backend._scalar_save_file, 'json') + assert out_dict == {'map': 0.9, 'step': 0} + shutil.rmtree('temp_dir') + + # test append mode + local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend.add_scalar('map', 0.9, step=0) + local_vis_backend.add_scalar('map', 0.95, step=1) + with open(local_vis_backend._scalar_save_file) as f: + out_dict = f.read() + assert out_dict == '{"map": 0.9, "step": 0}\n{"map": ' \ + '0.95, "step": 1}\n' + shutil.rmtree('temp_dir') + + def test_add_scalars(self): + local_vis_backend = LocalVisBackend('temp_dir') + input_dict = {'map': 0.7, 'acc': 0.9} + local_vis_backend.add_scalars(input_dict) + out_dict = load(local_vis_backend._scalar_save_file, 'json') + assert out_dict == {'map': 0.7, 'acc': 0.9, 'step': 0} + + # test append mode + local_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) + with open(local_vis_backend._scalar_save_file) as f: + out_dict = f.read() + assert out_dict == '{"map": 0.7, "acc": 0.9, ' \ + '"step": 0}\n{"map": 0.8, "acc": 0.8, "step": 1}\n' + + # test file_path + local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend.add_scalars(input_dict, file_path='temp.json') + assert os.path.exists(local_vis_backend._scalar_save_file) + assert os.path.exists( + os.path.join(local_vis_backend._save_dir, 'temp.json')) + + # file_path and scalar_save_file cannot be the same + with pytest.raises(AssertionError): + local_vis_backend.add_scalars(input_dict, file_path='scalars.json') + + shutil.rmtree('temp_dir') + + +class TestTensorboardVisBackend: + sys.modules['torch.utils.tensorboard'] = MagicMock() + 
sys.modules['tensorboardX'] = MagicMock() + + def test_init(self): + + TensorboardVisBackend('temp_dir') + VISBACKENDS.build( + dict(type='TensorboardVisBackend', save_dir='temp_dir')) + + def test_experiment(self): + tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + assert (tensorboard_vis_backend.experiment == + tensorboard_vis_backend._tensorboard) + + def test_add_graph(self): + # TODO + pass + + def test_add_config(self): + # TODO + pass + + def test_add_image(self): + image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) + + tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + tensorboard_vis_backend.add_image('img', image) + + tensorboard_vis_backend.add_image('img', image, step=2) + + def test_add_scalar(self): + tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + tensorboard_vis_backend.add_scalar('map', 0.9) + # test append mode + tensorboard_vis_backend.add_scalar('map', 0.9, step=0) + tensorboard_vis_backend.add_scalar('map', 0.95, step=1) + + def test_add_scalars(self): + tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + # The step value must be passed through the parameter + with pytest.raises(AssertionError): + tensorboard_vis_backend.add_scalars({ + 'map': 0.7, + 'acc': 0.9, + 'step': 1 + }) + + input_dict = {'map': 0.7, 'acc': 0.9} + tensorboard_vis_backend.add_scalars(input_dict) + # test append mode + tensorboard_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) + + +class TestWandbVisBackend: + sys.modules['wandb'] = MagicMock() + + def test_init(self): + WandbVisBackend() + VISBACKENDS.build(dict(type='WandbVisBackend', save_dir='temp_dir')) + + def test_experiment(self): + wandb_vis_backend = WandbVisBackend() + assert wandb_vis_backend.experiment == wandb_vis_backend._wandb + + def test_add_config(self): + # TODO + pass + + def test_add_image(self): + image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) + + wandb_vis_backend = WandbVisBackend() + wandb_vis_backend.add_image('img', image) + + wandb_vis_backend.add_image('img', image, step=2) + + def test_add_scalar(self): + wandb_vis_backend = WandbVisBackend() + wandb_vis_backend.add_scalar('map', 0.9) + # test append mode + wandb_vis_backend.add_scalar('map', 0.9, step=0) + wandb_vis_backend.add_scalar('map', 0.95, step=1) + + def test_add_scalars(self): + wandb_vis_backend = WandbVisBackend() + input_dict = {'map': 0.7, 'acc': 0.9} + wandb_vis_backend.add_scalars(input_dict) + # test append mode + wandb_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py index 5a7da41b78702417bd3ff38f36d8e77045f8ef8f..ce3de94d88104d7677aa84532337b5c84b0a6f79 100644 --- a/tests/test_visualizer/test_visualizer.py +++ b/tests/test_visualizer/test_visualizer.py @@ -1,16 +1,64 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Optional +import copy +from typing import Any, List, Optional, Union from unittest import TestCase import matplotlib.pyplot as plt import numpy as np import pytest import torch +import torch.nn as nn -from mmengine.data import BaseDataElement +from mmengine import VISBACKENDS from mmengine.visualization import Visualizer +@VISBACKENDS.register_module() +class MockVisBackend: + + def __init__(self, save_dir: Optional[str] = None): + self._save_dir = save_dir + self._close = False + + @property + def experiment(self) -> Any: + return self + + def add_config(self, params_dict: dict, **kwargs) -> None: + self._add_config = True + + def add_graph(self, model: torch.nn.Module, + input_tensor: Union[torch.Tensor, + List[torch.Tensor]], **kwargs) -> None: + + self._add_graph = True + + def add_image(self, + name: str, + image: np.ndarray, + step: int = 0, + **kwargs) -> None: + self._add_image = True + + def add_scalar(self, + name: str, + value: Union[int, float], + step: int = 0, + **kwargs) -> None: + self._add_scalar = True + + def add_scalars(self, + scalar_dict: dict, + step: int = 0, + file_path: Optional[str] = None, + **kwargs) -> None: + self._add_scalars = True + + def close(self) -> None: + """close an opened object.""" + self._close = True + + class TestVisualizer(TestCase): def setUp(self): @@ -21,11 +69,27 @@ class TestVisualizer(TestCase): """ self.image = np.random.randint( 0, 256, size=(10, 10, 3)).astype('uint8') + self.vis_backend_cfg = [ + dict(type='MockVisBackend', name='mock1', save_dir='tmp'), + dict(type='MockVisBackend', name='mock2', save_dir='tmp') + ] def test_init(self): visualizer = Visualizer(image=self.image) visualizer.get_image() + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) + assert isinstance(visualizer.get_backend('mock1'), MockVisBackend) + assert len(visualizer._vis_backends) == 2 + + # test global + visualizer = Visualizer.get_instance( + 'visualizer', vis_backends=copy.deepcopy(self.vis_backend_cfg)) + assert len(visualizer._vis_backends) == 2 + visualizer_any = Visualizer.get_instance('visualizer') + assert visualizer_any == visualizer + def test_set_image(self): visualizer = Visualizer() visualizer.set_image(self.image) @@ -45,7 +109,7 @@ class TestVisualizer(TestCase): visualizer.draw_bboxes(torch.tensor([1, 1, 1, 2])) bboxes = torch.tensor([[1, 1, 2, 2], [1, 2, 2, 2.5]]) visualizer.draw_bboxes( - bboxes, alpha=0.5, edgecolors='b', linestyles='-') + bboxes, alpha=0.5, edge_colors=(255, 0, 0), line_styles='-') bboxes = bboxes.numpy() visualizer.draw_bboxes(bboxes) @@ -66,19 +130,26 @@ class TestVisualizer(TestCase): visualizer.draw_bboxes([1, 1, 2, 2]) def test_close(self): - visualizer = Visualizer(image=self.image) - fig_num = visualizer.fig.number + visualizer = Visualizer( + image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg)) + fig_num = visualizer.fig_save_num assert fig_num in plt.get_fignums() + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._close is False visualizer.close() assert fig_num not in plt.get_fignums() + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._close is True def test_draw_texts(self): visualizer = Visualizer(image=self.image) # only support tensor and numpy - visualizer.draw_texts('text1', positions=torch.tensor([5, 5])) + visualizer.draw_texts( + 'text1', positions=torch.tensor([5, 5]), colors=(0, 255, 0)) visualizer.draw_texts(['text1', 'text2'], - positions=torch.tensor([[5, 5], [3, 3]])) + 
positions=torch.tensor([[5, 5], [3, 3]]), + colors=[(255, 0, 0), (255, 0, 0)]) visualizer.draw_texts('text1', positions=np.array([5, 5])) visualizer.draw_texts(['text1', 'text2'], positions=np.array([[5, 5], [3, 3]])) @@ -111,11 +182,11 @@ class TestVisualizer(TestCase): with pytest.raises(AssertionError): visualizer.draw_texts(['text1', 'test2'], positions=torch.tensor([[5, 5], [3, 3]]), - verticalalignments=['top']) + vertical_alignments=['top']) with pytest.raises(AssertionError): visualizer.draw_texts(['text1', 'test2'], positions=torch.tensor([[5, 5], [3, 3]]), - horizontalalignments=['left']) + horizontal_alignments=['left']) with pytest.raises(AssertionError): visualizer.draw_texts(['text1', 'test2'], positions=torch.tensor([[5, 5], [3, 3]]), @@ -140,8 +211,8 @@ class TestVisualizer(TestCase): x_datas=np.array([[1, 5], [2, 4]]), y_datas=np.array([[2, 6], [4, 7]]), colors='r', - linestyles=['-', '-.'], - linewidths=[1, 2]) + line_styles=['-', '-.'], + line_widths=[1, 2]) # test out of bounds with pytest.warns( UserWarning, @@ -171,19 +242,20 @@ class TestVisualizer(TestCase): visualizer.draw_circles( torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2])) - # test filling + # test face_colors visualizer.draw_circles( torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2]), - is_filling=True) + face_colors=(255, 0, 0), + edge_colors=(255, 0, 0)) # test config visualizer.draw_circles( torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2]), - edgecolors=['g', 'r'], - linestyles=['-', '-.'], - linewidths=[1, 2]) + edge_colors=['g', 'r'], + line_styles=['-', '-.'], + line_widths=[1, 2]) # test out of bounds with pytest.warns( @@ -220,15 +292,16 @@ class TestVisualizer(TestCase): np.array([[1, 1], [2, 2], [3, 4]]), torch.tensor([[1, 1], [2, 2], [3, 4]]) ], - is_filling=True) + face_colors=(255, 0, 0), + edge_colors=(255, 0, 0)) visualizer.draw_polygons( polygons=[ np.array([[1, 1], [2, 2], [3, 4]]), torch.tensor([[1, 1], [2, 2], [3, 4]]) ], - edgecolors=['r', 'g'], - linestyles='-', - linewidths=[2, 1]) + edge_colors=['r', 'g'], + line_styles='-', + line_widths=[2, 1]) # test out of bounds with pytest.warns( @@ -242,7 +315,10 @@ class TestVisualizer(TestCase): visualizer = Visualizer(image=self.image) visualizer.draw_binary_masks(binary_mask) visualizer.draw_binary_masks(torch.from_numpy(binary_mask)) - + # multi binary + binary_mask = np.random.randint(0, 2, size=(2, 10, 10)).astype(np.bool) + visualizer = Visualizer(image=self.image) + visualizer.draw_binary_masks(binary_mask, colors=['r', (0, 255, 0)]) # test the error that the size of mask and image are different. 
with pytest.raises(AssertionError): binary_mask = np.random.randint(0, 2, size=(8, 10)).astype(np.bool) @@ -269,7 +345,7 @@ class TestVisualizer(TestCase): visualizer.draw_featmap(torch.randn(1, 1, 3, 3)) # test mode parameter - # mode only supports 'mean' and 'max' and 'min + # mode only supports 'mean' and 'max' with pytest.raises(AssertionError): visualizer.draw_featmap(torch.randn(2, 3, 3), mode='xx') # test tensor_chw and img have difference height and width @@ -289,7 +365,6 @@ class TestVisualizer(TestCase): visualizer.draw_featmap(torch.randn(6, 3, 3), mode='mean') visualizer.draw_featmap(torch.randn(1, 3, 3), mode='mean') visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max') - visualizer.draw_featmap(torch.randn(6, 3, 3), mode='min') visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max', topk=10) visualizer.draw_featmap(torch.randn(1, 3, 3), mode=None, topk=-1) visualizer.draw_featmap( @@ -325,57 +400,76 @@ class TestVisualizer(TestCase): draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]])). \ draw_binary_masks(binary_mask) - def test_register_task(self): + def test_get_backend(self): + visualizer = Visualizer( + image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg)) + for name in ['mock1', 'mock2']: + assert isinstance(visualizer.get_backend(name), MockVisBackend) - class DetVisualizer(Visualizer): + def test_add_config(self): + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) - @Visualizer.register_task('instances') - def draw_instance(self, instances, data_type): - pass + params_dict = dict(lr=0.1, wd=0.2, mode='linear') + visualizer.add_config(params_dict) + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._add_config is True - assert len(Visualizer.task_dict) == 1 - assert 'instances' in Visualizer.task_dict + def test_add_graph(self): + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) - # test registration of the same names. 
- with pytest.raises( - KeyError, - match=('"instances" is already registered in task_dict, ' - 'add "force=True" if you want to override it')): - - class DetVisualizer1(Visualizer): - - @Visualizer.register_task('instances') - def draw_instance1(self, instances, data_type): - pass - - @Visualizer.register_task('instances') - def draw_instance2(self, instances, data_type): - pass - - Visualizer.task_dict = dict() - - class DetVisualizer2(Visualizer): - - @Visualizer.register_task('instances') - def draw_instance1(self, instances, data_type): - pass - - @Visualizer.register_task('instances', force=True) - def draw_instance2(self, instances, data_type): - pass - - def draw(self, - image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, - draw_gt: bool = True, - draw_pred: bool = True) -> None: - return super().draw(image, gt_sample, pred_sample, draw_gt, - draw_pred) - - det_visualizer = DetVisualizer2() - det_visualizer.draw(gt_sample={}, pred_sample={}) - assert len(det_visualizer.task_dict) == 1 - assert 'instances' in det_visualizer.task_dict - assert det_visualizer.task_dict[ - 'instances'].__name__ == 'draw_instance2' + class Model(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(1, 2, 1) + + def forward(self, x, y=None): + return self.conv(x) + + visualizer.add_graph(Model(), np.zeros([1, 1, 3, 3])) + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._add_graph is True + + def test_add_image(self): + image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) + + visualizer.add_image('img', image) + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._add_image is True + + def test_add_scalar(self): + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) + visualizer.add_scalar('map', 0.9, step=0) + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._add_scalar is True + + def test_add_scalars(self): + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg)) + input_dict = {'map': 0.7, 'acc': 0.9} + visualizer.add_scalars(input_dict) + for name in ['mock1', 'mock2']: + assert visualizer.get_backend(name)._add_scalars is True + + def test_get_instance(self): + + class DetLocalVisualizer(Visualizer): + + def __init__(self, name): + super().__init__(name) + + visualizer1 = DetLocalVisualizer.get_instance('name1') + visualizer2 = Visualizer.get_current_instance() + visualizer3 = DetLocalVisualizer.get_current_instance() + assert id(visualizer1) == id(visualizer2) == id(visualizer3) + + +if __name__ == '__main__': + t = TestVisualizer() + t.setUp() + t.test_init() diff --git a/tests/test_visualizer/test_writer.py b/tests/test_visualizer/test_writer.py deleted file mode 100644 index 447a246d318b7ceaab00012bbc6e32ddf2825d4a..0000000000000000000000000000000000000000 --- a/tests/test_visualizer/test_writer.py +++ /dev/null @@ -1,484 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import os -import shutil -import sys -from unittest.mock import MagicMock, Mock, patch - -import numpy as np -import pytest -import torch -import torch.nn as nn - -from mmengine.fileio import load -from mmengine.registry import VISUALIZERS, WRITERS -from mmengine.visualization import (ComposedWriter, LocalWriter, - TensorboardWriter, WandbWriter) - - -def draw(self, image, gt_sample, pred_sample, show_gt=True, show_pred=True): - self.set_image(image) - - -class TestLocalWriter: - - def test_init(self): - # visuailzer must be a dictionary or an instance - # of Visualizer and its subclasses - with pytest.raises(AssertionError): - LocalWriter('temp_dir', [dict(type='Visualizer')]) - - # 'params_save_file' format must be yaml - with pytest.raises(AssertionError): - LocalWriter('temp_dir', params_save_file='a.txt') - - # 'scalar_save_file' format must be json - with pytest.raises(AssertionError): - LocalWriter('temp_dir', scalar_save_file='a.yaml') - - local_writer = LocalWriter('temp_dir') - assert os.path.exists(local_writer._save_dir) - shutil.rmtree('temp_dir') - - local_writer = WRITERS.build( - dict( - type='LocalWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir')) - assert os.path.exists(local_writer._save_dir) - shutil.rmtree('temp_dir') - - def test_experiment(self): - local_writer = LocalWriter('temp_dir') - assert local_writer.experiment == local_writer - shutil.rmtree('temp_dir') - - def test_add_params(self): - local_writer = LocalWriter('temp_dir') - - # 'params_dict' must be dict - with pytest.raises(AssertionError): - local_writer.add_params(['lr', 0]) - - params_dict = dict(lr=0.1, wd=[1.0, 0.1, 0.001], mode='linear') - local_writer.add_params(params_dict) - out_dict = load(local_writer._params_save_file, 'yaml') - assert out_dict == params_dict - shutil.rmtree('temp_dir') - - @patch('mmengine.visualization.visualizer.Visualizer.draw', draw) - def test_add_image(self): - image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - - # The visuailzer parameter must be set when - # the local_writer object is instantiated and - # the `add_image` method is called. 
- with pytest.raises(AssertionError): - local_writer = LocalWriter('temp_dir') - local_writer.add_image('img', image) - - local_writer = LocalWriter('temp_dir', dict(type='Visualizer')) - local_writer.add_image('img', image) - assert os.path.exists( - os.path.join(local_writer._img_save_dir, 'img_0.png')) - - bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]]) - local_writer.visualizer.draw_bboxes(bboxes) - local_writer.add_image( - 'img', local_writer.visualizer.get_image(), step=2) - assert os.path.exists( - os.path.join(local_writer._img_save_dir, 'img_2.png')) - - visuailzer = VISUALIZERS.build(dict(type='Visualizer')) - local_writer = LocalWriter('temp_dir', visuailzer) - local_writer.add_image('img', image) - assert os.path.exists( - os.path.join(local_writer._img_save_dir, 'img_0.png')) - - shutil.rmtree('temp_dir') - - def test_add_scalar(self): - local_writer = LocalWriter('temp_dir') - local_writer.add_scalar('map', 0.9) - out_dict = load(local_writer._scalar_save_file, 'json') - assert out_dict == {'map': 0.9, 'step': 0} - shutil.rmtree('temp_dir') - - # test append mode - local_writer = LocalWriter('temp_dir') - local_writer.add_scalar('map', 0.9, step=0) - local_writer.add_scalar('map', 0.95, step=1) - with open(local_writer._scalar_save_file) as f: - out_dict = f.read() - assert out_dict == '{"map": 0.9, "step": 0}\n{"map": ' \ - '0.95, "step": 1}\n' - shutil.rmtree('temp_dir') - - def test_add_scalars(self): - local_writer = LocalWriter('temp_dir') - input_dict = {'map': 0.7, 'acc': 0.9} - local_writer.add_scalars(input_dict) - out_dict = load(local_writer._scalar_save_file, 'json') - assert out_dict == {'map': 0.7, 'acc': 0.9, 'step': 0} - - # test append mode - local_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) - with open(local_writer._scalar_save_file) as f: - out_dict = f.read() - assert out_dict == '{"map": 0.7, "acc": 0.9, ' \ - '"step": 0}\n{"map": 0.8, "acc": 0.8, "step": 1}\n' - - # test file_path - local_writer = LocalWriter('temp_dir') - local_writer.add_scalars(input_dict, file_path='temp.json') - assert os.path.exists(local_writer._scalar_save_file) - assert os.path.exists( - os.path.join(local_writer._save_dir, 'temp.json')) - - # file_path and scalar_save_file cannot be the same - with pytest.raises(AssertionError): - local_writer.add_scalars(input_dict, file_path='scalars.json') - - shutil.rmtree('temp_dir') - - -class TestTensorboardWriter: - sys.modules['torch.utils.tensorboard'] = MagicMock() - sys.modules['tensorboardX'] = MagicMock() - - def test_init(self): - # visuailzer must be a dictionary or an instance - # of Visualizer and its subclasses - with pytest.raises(AssertionError): - LocalWriter('temp_dir', [dict(type='Visualizer')]) - - TensorboardWriter('temp_dir') - WRITERS.build( - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir')) - - def test_experiment(self): - tensorboard_writer = TensorboardWriter('temp_dir') - assert tensorboard_writer.experiment == tensorboard_writer._tensorboard - - def test_add_graph(self): - - class Model(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 2, 1) - - def forward(self, x, y=None): - return self.conv(x) - - tensorboard_writer = TensorboardWriter('temp_dir') - - # input must be tensor - with pytest.raises(AssertionError): - tensorboard_writer.add_graph(Model(), np.zeros([1, 1, 3, 3])) - - # input must be 4d tensor - with pytest.raises(AssertionError): - tensorboard_writer.add_graph(Model(), torch.zeros([1, 3, 3])) - - # If the 
input is a list, the inner element must be a 4d tensor - with pytest.raises(AssertionError): - tensorboard_writer.add_graph( - Model(), [torch.zeros([1, 1, 3, 3]), - torch.zeros([1, 3, 3])]) - - tensorboard_writer.add_graph(Model(), torch.zeros([1, 1, 3, 3])) - tensorboard_writer.add_graph( - Model(), [torch.zeros([1, 1, 3, 3]), - torch.zeros([1, 1, 3, 3])]) - - def test_add_params(self): - tensorboard_writer = TensorboardWriter('temp_dir') - - # 'params_dict' must be dict - with pytest.raises(AssertionError): - tensorboard_writer.add_params(['lr', 0]) - - params_dict = dict(lr=0.1, wd=0.2, mode='linear') - tensorboard_writer.add_params(params_dict) - - @patch('mmengine.visualization.visualizer.Visualizer.draw', draw) - def test_add_image(self): - image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - - # The visuailzer parameter must be set when - # the local_writer object is instantiated and - # the `add_image` method is called. - with pytest.raises(AssertionError): - tensorboard_writer = TensorboardWriter('temp_dir') - tensorboard_writer.add_image('img', image) - - tensorboard_writer = TensorboardWriter('temp_dir', - dict(type='Visualizer')) - tensorboard_writer.add_image('img', image) - - bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]]) - tensorboard_writer.visualizer.draw_bboxes(bboxes) - tensorboard_writer.add_image( - 'img', tensorboard_writer.visualizer.get_image(), step=2) - - visuailzer = VISUALIZERS.build(dict(type='Visualizer')) - tensorboard_writer = TensorboardWriter('temp_dir', visuailzer) - tensorboard_writer.add_image('img', image) - - def test_add_scalar(self): - tensorboard_writer = TensorboardWriter('temp_dir') - tensorboard_writer.add_scalar('map', 0.9) - # test append mode - tensorboard_writer.add_scalar('map', 0.9, step=0) - tensorboard_writer.add_scalar('map', 0.95, step=1) - - def test_add_scalars(self): - tensorboard_writer = TensorboardWriter('temp_dir') - # The step value must be passed through the parameter - with pytest.raises(AssertionError): - tensorboard_writer.add_scalars({'map': 0.7, 'acc': 0.9, 'step': 1}) - - input_dict = {'map': 0.7, 'acc': 0.9} - tensorboard_writer.add_scalars(input_dict) - # test append mode - tensorboard_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) - - -class TestWandbWriter: - sys.modules['wandb'] = MagicMock() - - def test_init(self): - WandbWriter() - WRITERS.build( - dict( - type='WandbWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir')) - - def test_experiment(self): - wandb_writer = WandbWriter() - assert wandb_writer.experiment == wandb_writer._wandb - - def test_add_params(self): - wandb_writer = WandbWriter() - - # 'params_dict' must be dict - with pytest.raises(AssertionError): - wandb_writer.add_params(['lr', 0]) - - params_dict = dict(lr=0.1, wd=0.2, mode='linear') - wandb_writer.add_params(params_dict) - - @patch('mmengine.visualization.visualizer.Visualizer.draw', draw) - @patch('mmengine.visualization.writer.WandbWriter.add_image_to_wandb', - Mock) - def test_add_image(self): - image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - - wandb_writer = WandbWriter() - wandb_writer.add_image('img', image) - - wandb_writer = WandbWriter(visualizer=dict(type='Visualizer')) - bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]]) - wandb_writer.visualizer.set_image(image) - wandb_writer.visualizer.draw_bboxes(bboxes) - wandb_writer.add_image( - 'img', wandb_writer.visualizer.get_image(), step=2) - - visuailzer = VISUALIZERS.build(dict(type='Visualizer')) - wandb_writer 
= WandbWriter(visualizer=visuailzer) - wandb_writer.add_image('img', image) - - def test_add_scalar(self): - wandb_writer = WandbWriter() - wandb_writer.add_scalar('map', 0.9) - # test append mode - wandb_writer.add_scalar('map', 0.9, step=0) - wandb_writer.add_scalar('map', 0.95, step=1) - - def test_add_scalars(self): - wandb_writer = WandbWriter() - input_dict = {'map': 0.7, 'acc': 0.9} - wandb_writer.add_scalars(input_dict) - # test append mode - wandb_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) - - -class TestComposedWriter: - sys.modules['torch.utils.tensorboard'] = MagicMock() - sys.modules['tensorboardX'] = MagicMock() - sys.modules['wandb'] = MagicMock() - - def test_init(self): - - class A: - pass - - # The writers inner element must be a dictionary or a - # subclass of Writer. - with pytest.raises(AssertionError): - ComposedWriter(writers=[A()]) - - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - assert len(composed_writer._writers) == 2 - - # test global - composed_writer = ComposedWriter.get_instance( - 'composed_writer', - writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - assert len(composed_writer._writers) == 2 - composed_writer_any = ComposedWriter.get_instance('composed_writer') - assert composed_writer_any == composed_writer - - def test_get_writer(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - assert isinstance(composed_writer.get_writer(0), WandbWriter) - assert isinstance(composed_writer.get_writer(1), TensorboardWriter) - - def test_get_experiment(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - assert composed_writer.get_experiment( - 0) == composed_writer._writers[0].experiment - assert composed_writer.get_experiment( - 1) == composed_writer._writers[1].experiment - - def test_get_visualizer(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - assert composed_writer.get_visualizer( - 0) == composed_writer._writers[0].visualizer - assert composed_writer.get_visualizer( - 1) == composed_writer._writers[1].visualizer - - def test_add_params(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - - # 'params_dict' must be dict - with pytest.raises(AssertionError): - composed_writer.add_params(['lr', 0]) - - params_dict = dict(lr=0.1, wd=0.2, mode='linear') - composed_writer.add_params(params_dict) - - def test_add_graph(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - - class Model(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 2, 1) - - def forward(self, x, y=None): - return self.conv(x) - - # input must be tensor - with pytest.raises(AssertionError): - composed_writer.add_graph(Model(), np.zeros([1, 1, 3, 3])) - - # input must be 4d tensor - with pytest.raises(AssertionError): - composed_writer.add_graph(Model(), torch.zeros([1, 
3, 3])) - - # If the input is a list, the inner element must be a 4d tensor - with pytest.raises(AssertionError): - composed_writer.add_graph( - Model(), [torch.zeros([1, 1, 3, 3]), - torch.zeros([1, 3, 3])]) - - composed_writer.add_graph(Model(), torch.zeros([1, 1, 3, 3])) - composed_writer.add_graph( - Model(), [torch.zeros([1, 1, 3, 3]), - torch.zeros([1, 1, 3, 3])]) - - @patch('mmengine.visualization.visualizer.Visualizer.draw', draw) - @patch('mmengine.visualization.writer.WandbWriter.add_image_to_wandb', - Mock) - def test_add_image(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - - image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - composed_writer.add_image('img', image) - - bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]]) - composed_writer.get_writer(1).visualizer.draw_bboxes(bboxes) - composed_writer.get_writer(1).add_image( - 'img', - composed_writer.get_writer(1).visualizer.get_image(), - step=2) - - def test_add_scalar(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - composed_writer.add_scalar('map', 0.9) - # test append mode - composed_writer.add_scalar('map', 0.9, step=0) - composed_writer.add_scalar('map', 0.95, step=1) - - def test_add_scalars(self): - composed_writer = ComposedWriter(writers=[ - WandbWriter(), - dict( - type='TensorboardWriter', - visualizer=dict(type='Visualizer'), - save_dir='temp_dir') - ]) - input_dict = {'map': 0.7, 'acc': 0.9} - composed_writer.add_scalars(input_dict) - # test append mode - composed_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)