From 61fecabea6c8cdcd305279b65140fce9e4e9a7ba Mon Sep 17 00:00:00 2001 From: Yining Li <liyining0712@gmail.com> Date: Thu, 10 Mar 2022 17:25:20 +0800 Subject: [PATCH] [Feature] Update evaluator prefix (#114) * update evaluator prefix * update docstring and comments * update doc --- docs/zh_cn/tutorials/evaluator.md | 4 +- mmengine/evaluator/__init__.py | 5 +- mmengine/evaluator/base.py | 2 +- mmengine/evaluator/utils.py | 37 ++++++++++++++ tests/test_evaluator/test_base_evaluator.py | 55 ++++++++++++++++++--- 5 files changed, 91 insertions(+), 12 deletions(-) create mode 100644 mmengine/evaluator/utils.py diff --git a/docs/zh_cn/tutorials/evaluator.md b/docs/zh_cn/tutorials/evaluator.md index 9122a287..695073d4 100644 --- a/docs/zh_cn/tutorials/evaluator.md +++ b/docs/zh_cn/tutorials/evaluator.md @@ -56,7 +56,7 @@ validation_cfg=dict( ], # 指定使用å‰ç¼€ä¸º COCO çš„ AP 为主è¦è¯„æµ‹æŒ‡æ ‡ # 在没有é‡åæŒ‡æ ‡æ§ä¹‰çš„情况下,æ¤å¤„å¯ä»¥ä¸å†™å‰ç¼€ï¼Œåªå†™è¯„æµ‹æŒ‡æ ‡å - main_metric='COCO.AP', + main_metric='COCO/AP', interval=10, by_epoch=True, ) @@ -106,7 +106,7 @@ class Accuracy(BaseEvaluator): Default prefix: ACC Metrics: - - accuracy: classification accuracy + - accuracy (float): classification accuracy """ default_prefix = 'ACC' diff --git a/mmengine/evaluator/__init__.py b/mmengine/evaluator/__init__.py index c2a8d5dd..38668510 100644 --- a/mmengine/evaluator/__init__.py +++ b/mmengine/evaluator/__init__.py @@ -2,5 +2,8 @@ from .base import BaseEvaluator from .builder import build_evaluator from .composed_evaluator import ComposedEvaluator +from .utils import get_metric_value -__all__ = ['BaseEvaluator', 'ComposedEvaluator', 'build_evaluator'] +__all__ = [ + 'BaseEvaluator', 'ComposedEvaluator', 'build_evaluator', 'get_metric_value' +] diff --git a/mmengine/evaluator/base.py b/mmengine/evaluator/base.py index 309e288d..dcf447be 100644 --- a/mmengine/evaluator/base.py +++ b/mmengine/evaluator/base.py @@ -120,7 +120,7 @@ class BaseEvaluator(metaclass=ABCMeta): # Add prefix to metric names if self.prefix: metrics = { - '.'.join((self.prefix, k)): v + '/'.join((self.prefix, k)): v for k, v in metrics.items() } metrics = [metrics] # type: ignore diff --git a/mmengine/evaluator/utils.py b/mmengine/evaluator/utils.py new file mode 100644 index 00000000..17c3229b --- /dev/null +++ b/mmengine/evaluator/utils.py @@ -0,0 +1,37 @@ +from typing import Any, Dict + + +def get_metric_value(indicator: str, metrics: Dict) -> Any: + """Get the metric value specified by an indicator, which can be either a + metric name or a full name with evaluator prefix. + + Args: + indicator (str): The metric indicator, which can be the metric name + (e.g. 'AP') or the full name with prefix (e.g. 'COCO/AP') + metrics (dict): The evaluation results output by the evaluator + + Returns: + Any: The specified metric value + """ + + if '/' in indicator: + # The indicator is a full name + if indicator in metrics: + return metrics[indicator] + else: + raise ValueError( + f'The indicator "{indicator}" can not match any metric in ' + f'{list(metrics.keys())}') + else: + # The indicator is metric name without prefix + matched = [k for k in metrics.keys() if k.split('/')[-1] == indicator] + + if not matched: + raise ValueError( + f'The indicator {indicator} can not match any metric in ' + f'{list(metrics.keys())}') + elif len(matched) > 1: + raise ValueError(f'The indicator "{indicator}" matches multiple ' + f'metrics {matched}') + else: + return metrics[matched[0]] diff --git a/tests/test_evaluator/test_base_evaluator.py b/tests/test_evaluator/test_base_evaluator.py index c880e26e..7040abff 100644 --- a/tests/test_evaluator/test_base_evaluator.py +++ b/tests/test_evaluator/test_base_evaluator.py @@ -6,7 +6,7 @@ from unittest import TestCase import numpy as np from mmengine.data import BaseDataSample -from mmengine.evaluator import BaseEvaluator, build_evaluator +from mmengine.evaluator import BaseEvaluator, build_evaluator, get_metric_value from mmengine.registry import EVALUATORS @@ -62,7 +62,7 @@ class ToyEvaluator(BaseEvaluator): @EVALUATORS.register_module() -class UnprefixedEvaluator(BaseEvaluator): +class NonPrefixedEvaluator(BaseEvaluator): """Evaluator with unassigned `default_prefix` to test the warning information.""" @@ -100,8 +100,8 @@ class TestBaseEvaluator(TestCase): evaluator.process(data_samples, predictions) metrics = evaluator.evaluate(size=size) - self.assertAlmostEqual(metrics['Toy.accuracy'], 1.0) - self.assertEqual(metrics['Toy.size'], size) + self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0) + self.assertEqual(metrics['Toy/size'], size) # Test empty results cfg = dict(type='ToyEvaluator', dummy_metrics=dict(accuracy=1.0)) @@ -126,9 +126,9 @@ class TestBaseEvaluator(TestCase): metrics = evaluator.evaluate(size=size) - self.assertAlmostEqual(metrics['Toy.accuracy'], 1.0) - self.assertAlmostEqual(metrics['Toy.mAP'], 0.0) - self.assertEqual(metrics['Toy.size'], size) + self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0) + self.assertAlmostEqual(metrics['Toy/mAP'], 0.0) + self.assertEqual(metrics['Toy/size'], size) def test_ambiguate_metric(self): @@ -167,6 +167,45 @@ class TestBaseEvaluator(TestCase): self.assertDictEqual(_evaluator.dataset_meta, dataset_meta) def test_prefix(self): - cfg = dict(type='UnprefixedEvaluator') + cfg = dict(type='NonPrefixedEvaluator') with self.assertWarnsRegex(UserWarning, 'The prefix is not set'): _ = build_evaluator(cfg) + + def test_get_metric_value(self): + + metrics = { + 'prefix_0/metric_0': 0, + 'prefix_1/metric_0': 1, + 'prefix_1/metric_1': 2, + 'nonprefixed': 3, + } + + # Test indicator with prefix + indicator = 'prefix_0/metric_0' # correct indicator + self.assertEqual(get_metric_value(indicator, metrics), 0) + + indicator = 'prefix_1/metric_0' # correct indicator + self.assertEqual(get_metric_value(indicator, metrics), 1) + + indicator = 'prefix_0/metric_1' # unmatched indicator (wrong metric) + with self.assertRaisesRegex(ValueError, 'can not match any metric'): + _ = get_metric_value(indicator, metrics) + + indicator = 'prefix_2/metric' # unmatched indicator (wrong prefix) + with self.assertRaisesRegex(ValueError, 'can not match any metric'): + _ = get_metric_value(indicator, metrics) + + # Test indicator without prefix + indicator = 'metric_1' # correct indicator (prefixed metric) + self.assertEqual(get_metric_value(indicator, metrics), 2) + + indicator = 'nonprefixed' # correct indicator (non-prefixed metric) + self.assertEqual(get_metric_value(indicator, metrics), 3) + + indicator = 'metric_0' # ambiguous indicator + with self.assertRaisesRegex(ValueError, 'matches multiple metrics'): + _ = get_metric_value(indicator, metrics) + + indicator = 'metric_2' # unmatched indicator + with self.assertRaisesRegex(ValueError, 'can not match any metric'): + _ = get_metric_value(indicator, metrics) -- GitLab