diff --git a/docs/zh_cn/tutorials/evaluator.md b/docs/zh_cn/tutorials/evaluator.md
index d2ab35b66168753846a9e122f3de9b431b546259..00649d582ff2ad95727acdc1b777ec989fd606bc 100644
--- a/docs/zh_cn/tutorials/evaluator.md
+++ b/docs/zh_cn/tutorials/evaluator.md
@@ -37,8 +37,8 @@ validation_cfg=dict(
 ```python
 validation_cfg=dict(
     evaluator=[
-        dict(type='accuracy', top_k=1),  # Use the classification accuracy evaluator
-        dict(type='f1_score')  # Use the F1 score evaluator
+        dict(type='Accuracy', top_k=1),  # Use the classification accuracy evaluator
+        dict(type='F1Score')  # Use the F1 score evaluator
     ],
     main_metric='accuracy'
     interval=10,
@@ -51,8 +51,8 @@ validation_cfg=dict(
 ```python
 validation_cfg=dict(
     evaluator=[
-        dict(type='accuracy', top_k=1, prefix='top1'),
-        dict(type='accuracy', top_k=5, prefix='top5')
+        dict(type='Accuracy', top_k=1, prefix='top1'),
+        dict(type='Accuracy', top_k=5, prefix='top5')
     ],
     main_metric='top1_accuracy',  # The prefix 'top1' is automatically prepended to the metric name to distinguish metrics with the same name
     interval=10,
@@ -62,7 +62,7 @@ validation_cfg=dict(
 
 ## Adding a custom evaluator
 
-The OpenMMLab algorithm libraries already implement the evaluators commonly used in their respective fields. For example, MMDetection provides a COCO evaluator, and MMClassification provides evaluators such as accuracy and f1_score.
+The OpenMMLab algorithm libraries already implement the evaluators commonly used in their respective fields. For example, MMDetection provides a COCO evaluator, and MMClassification provides evaluators such as Accuracy and F1Score.
 
 Users can also add custom evaluators according to their own needs. To implement a custom evaluator, users need to inherit from the evaluator base class [BaseEvaluator](Todo:baseevaluator-doc-link) provided by MMEngine and implement the corresponding abstract methods.
 
@@ -96,8 +96,8 @@ from mmengine.registry import EVALUATORS
 import numpy as np
 
 @EVALUATORS.register_module()
-class AccuracyEvaluator(BaseEvaluator):
-
+class Accuracy(BaseEvaluator):
+
     def process(self, data_samples: Dict, predictions: Dict):
         """Process one batch of data and predictions. The processed
         Results should be stored in `self.results`, which will be used
diff --git a/tests/test_evaluator/test_base_evaluator.py b/tests/test_evaluator/test_base_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..382c8b4ebe352d8ef98ede630cd690a7b68f89f6
--- /dev/null
+++ b/tests/test_evaluator/test_base_evaluator.py
@@ -0,0 +1,136 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Dict, List, Optional
+from unittest import TestCase
+
+import numpy as np
+
+from mmengine.evaluator import BaseEvaluator, ComposedEvaluator
+from mmengine.registry import EVALUATORS
+
+
+@EVALUATORS.register_module()
+class ToyEvaluator(BaseEvaluator):
+
+    def __init__(self,
+                 collect_device: str = 'cpu',
+                 dummy_metrics: Optional[Dict] = None):
+        super().__init__(collect_device=collect_device)
+        self.dummy_metrics = dummy_metrics
+
+    def process(self, data_samples, predictions):
+        result = {'pred': predictions['pred'], 'label': data_samples['label']}
+        self.results.append(result)
+
+    def compute_metrics(self, results: List):
+        if self.dummy_metrics is not None:
+            assert isinstance(self.dummy_metrics, dict)
+            return self.dummy_metrics.copy()
+
+        pred = np.concatenate([result['pred'] for result in results])
+        label = np.concatenate([result['label'] for result in results])
+        acc = (pred == label).sum() / pred.size
+
+        metrics = {
+            'accuracy': acc,
+            'size': pred.size,  # To check the number of testing samples
+        }
+
+        return metrics
+
+
+def generate_test_results(size, batch_size, pred, label):
+    num_batch = math.ceil(size / batch_size)
+    bs_residual = size % batch_size
+    for i in range(num_batch):
+        bs = bs_residual if i == num_batch - 1 else batch_size
+        data_samples = {'label': np.full(bs, label)}
+        predictions = {'pred': np.full(bs, pred)}
+        yield (data_samples, predictions)
+
+
+class TestBaseEvaluator(TestCase):
+
+    def build_evaluator(self, cfg):
+        if isinstance(cfg, (list, tuple)):
+            evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
+            return ComposedEvaluator(evaluators=evaluators)
+        else:
+            return EVALUATORS.build(cfg)
+
+    def test_single_evaluator(self):
+        cfg = dict(type='ToyEvaluator')
+        evaluator = self.build_evaluator(cfg)
+
+        size = 10
+        batch_size = 4
+
+        for data_samples, predictions in generate_test_results(
+                size, batch_size, pred=1, label=1):
+            evaluator.process(data_samples, predictions)
+
+        metrics = evaluator.evaluate(size=size)
+        self.assertAlmostEqual(metrics['accuracy'], 1.0)
+        self.assertEqual(metrics['size'], size)
+
+        # Test empty results
+        cfg = dict(type='ToyEvaluator', dummy_metrics=dict(accuracy=1.0))
+        evaluator = self.build_evaluator(cfg)
+        with self.assertWarnsRegex(UserWarning, 'got empty `self._results`.'):
+            evaluator.evaluate(0)
+
+    def test_composed_evaluator(self):
+        cfg = [
+            dict(type='ToyEvaluator'),
+            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
+        ]
+
+        evaluator = self.build_evaluator(cfg)
+
+        size = 10
+        batch_size = 4
+
+        for data_samples, predictions in generate_test_results(
+                size, batch_size, pred=1, label=1):
+            evaluator.process(data_samples, predictions)
+
+        metrics = evaluator.evaluate(size=size)
+
+        self.assertAlmostEqual(metrics['accuracy'], 1.0)
+        self.assertAlmostEqual(metrics['mAP'], 0.0)
+        self.assertEqual(metrics['size'], size)
+
+    def test_ambiguate_metric(self):
+
+        cfg = [
+            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0)),
+            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
+        ]
+
+        evaluator = self.build_evaluator(cfg)
+
+        size = 10
+        batch_size = 4
+
+        for data_samples, predictions in generate_test_results(
+                size, batch_size, pred=1, label=1):
+            evaluator.process(data_samples, predictions)
+
+        with self.assertRaisesRegex(
+                ValueError,
+                'There are multiple evaluators with the same metric name'):
+            _ = evaluator.evaluate(size=size)
+
+    def test_dataset_meta(self):
+        dataset_meta = dict(classes=('cat', 'dog'))
+
+        cfg = [
+            dict(type='ToyEvaluator'),
+            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
+        ]
+
+        evaluator = self.build_evaluator(cfg)
+        evaluator.dataset_meta = dataset_meta
+
+        for _evaluator in evaluator.evaluators:
+            self.assertDictEqual(_evaluator.dataset_meta, dataset_meta)
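
For reference, the sketch below shows how an evaluator following the renamed convention above (`Accuracy`-style class names registered in `EVALUATORS`) could be defined and exercised, mirroring the `BaseEvaluator` subclassing pattern used by `ToyEvaluator` in the new test. The class name `TopKAccuracy`, its `top_k` argument, and the `'scores'` prediction key are illustrative assumptions and are not part of this change.

```python
# A minimal sketch, not part of this diff: a custom evaluator built on the
# same BaseEvaluator / EVALUATORS API exercised by the test above. The class
# name `TopKAccuracy`, the `top_k` argument and the `'scores'` key are
# illustrative assumptions.
from typing import Dict, List

import numpy as np

from mmengine.evaluator import BaseEvaluator
from mmengine.registry import EVALUATORS


@EVALUATORS.register_module()
class TopKAccuracy(BaseEvaluator):
    """Illustrative top-k accuracy evaluator (hypothetical)."""

    def __init__(self, top_k: int = 1, collect_device: str = 'cpu'):
        super().__init__(collect_device=collect_device)
        self.top_k = top_k

    def process(self, data_samples: Dict, predictions: Dict):
        # Store per-batch results in `self.results`, as ToyEvaluator does;
        # they are gathered before `compute_metrics` is called.
        scores = predictions['scores']  # (N, num_classes), assumed key
        labels = data_samples['label']  # (N,), same key as in the test
        topk = np.argsort(scores, axis=1)[:, -self.top_k:]
        correct = (topk == labels[:, None]).any(axis=1)
        self.results.append({'correct': correct})

    def compute_metrics(self, results: List) -> Dict:
        correct = np.concatenate([res['correct'] for res in results])
        return {f'top{self.top_k}_accuracy': float(correct.mean())}


# Built and driven the same way the test drives ToyEvaluator.
evaluator = EVALUATORS.build(dict(type='TopKAccuracy', top_k=1))
evaluator.process(
    data_samples={'label': np.array([0, 1])},
    predictions={'scores': np.array([[0.9, 0.1], [0.3, 0.7]])})
print(evaluator.evaluate(size=2))  # expected: {'top1_accuracy': 1.0}
```

Under these assumptions, such an evaluator could then be referenced from `validation_cfg` as `dict(type='TopKAccuracy', top_k=1)`, in line with the `Accuracy`/`F1Score` entries updated in the documentation above.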