Unverified commit 61fecabe, authored by Yining Li and committed by GitHub

[Feature] Update evaluator prefix (#114)

* update evaluator prefix

* update docstring and comments

* update doc
parent 3e0c064f
@@ -56,7 +56,7 @@ validation_cfg=dict(
     ],
     # Use the AP metric with the 'COCO' prefix as the main evaluation metric
     # When there is no ambiguity from duplicate metric names, the prefix can be omitted and only the metric name given
-    main_metric='COCO.AP',
+    main_metric='COCO/AP',
     interval=10,
     by_epoch=True,
 )
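For context, a minimal sketch of how a prefixed main_metric such as 'COCO/AP' resolves against the results an evaluator reports, where each key follows the '<evaluator prefix>/<metric name>' pattern; the metric values below are made up for illustration.

# Hypothetical evaluation results; keys follow '<evaluator prefix>/<metric name>'.
results = {
    'COCO/AP': 0.672,
    'COCO/AP50': 0.881,
    'PCK/PCK': 0.912,
}

main_metric = 'COCO/AP'
print(f'{main_metric} = {results[main_metric]}')  # prints: COCO/AP = 0.672

When only one evaluator reports a metric named 'AP', the bare name is unambiguous, which is why the config comment above allows omitting the prefix.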
@@ -106,7 +106,7 @@ class Accuracy(BaseEvaluator):
     Default prefix: ACC

     Metrics:
-        - accuracy: classification accuracy
+        - accuracy (float): classification accuracy
     """

     default_prefix = 'ACC'
@@ -2,5 +2,8 @@
 from .base import BaseEvaluator
 from .builder import build_evaluator
 from .composed_evaluator import ComposedEvaluator
+from .utils import get_metric_value

-__all__ = ['BaseEvaluator', 'ComposedEvaluator', 'build_evaluator']
+__all__ = [
+    'BaseEvaluator', 'ComposedEvaluator', 'build_evaluator', 'get_metric_value'
+]
@@ -120,7 +120,7 @@ class BaseEvaluator(metaclass=ABCMeta):
         # Add prefix to metric names
         if self.prefix:
             metrics = {
-                '.'.join((self.prefix, k)): v
+                '/'.join((self.prefix, k)): v
                 for k, v in metrics.items()
             }
         metrics = [metrics]  # type: ignore
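The following is a self-contained sketch, not the mmengine implementation, of the prefixing behavior this hunk touches: an explicit prefix falls back to the class-level default_prefix, a missing prefix triggers the 'The prefix is not set' warning exercised by the tests below, and metric names are joined with the prefix. Class name and message wording beyond that are illustrative assumptions.

import warnings
from typing import Dict, Optional


class PrefixDemo:
    """Illustrative stand-in for an evaluator's prefixing logic (not mmengine code)."""

    # Subclasses would set this, e.g. 'ACC' in the Accuracy example above.
    default_prefix: Optional[str] = None

    def __init__(self, prefix: Optional[str] = None) -> None:
        # An explicit prefix wins; otherwise fall back to the class default.
        self.prefix = prefix or self.default_prefix
        if self.prefix is None:
            # Wording beyond 'The prefix is not set' is an assumption.
            warnings.warn(f'The prefix is not set in {self.__class__.__name__}.')

    def add_prefix(self, metrics: Dict[str, float]) -> Dict[str, float]:
        # Join prefix and metric name with '/', e.g. 'Toy/accuracy'.
        if self.prefix:
            metrics = {'/'.join((self.prefix, k)): v for k, v in metrics.items()}
        return metrics


print(PrefixDemo(prefix='Toy').add_prefix({'accuracy': 1.0}))  # {'Toy/accuracy': 1.0}

In the actual hunk above, the only behavioral change is the join separator switching from '.' to '/'.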
from typing import Any, Dict


def get_metric_value(indicator: str, metrics: Dict) -> Any:
    """Get the metric value specified by an indicator, which can be either a
    metric name or a full name with evaluator prefix.

    Args:
        indicator (str): The metric indicator, which can be the metric name
            (e.g. 'AP') or the full name with prefix (e.g. 'COCO/AP')
        metrics (dict): The evaluation results output by the evaluator

    Returns:
        Any: The specified metric value
    """

    if '/' in indicator:
        # The indicator is a full metric name with an evaluator prefix
        if indicator in metrics:
            return metrics[indicator]
        else:
            raise ValueError(
                f'The indicator "{indicator}" can not match any metric in '
                f'{list(metrics.keys())}')
    else:
        # The indicator is a metric name without prefix
        matched = [k for k in metrics.keys() if k.split('/')[-1] == indicator]

        if not matched:
            raise ValueError(
                f'The indicator {indicator} can not match any metric in '
                f'{list(metrics.keys())}')
        elif len(matched) > 1:
            raise ValueError(f'The indicator "{indicator}" matches multiple '
                             f'metrics {matched}')
        else:
            return metrics[matched[0]]
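A short usage sketch of get_metric_value; the metric keys and values here are made up for illustration.

metrics = {'COCO/AP': 0.672, 'COCO/AP50': 0.881, 'PCK/PCK': 0.912}

# A full name with prefix is looked up directly.
print(get_metric_value('COCO/AP', metrics))  # 0.672

# A bare metric name is matched against the part after the last '/'.
print(get_metric_value('PCK', metrics))      # 0.912

# A bare name that matches no metric (or more than one) raises ValueError.
try:
    get_metric_value('mAP', metrics)
except ValueError as err:
    print(err)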
@@ -6,7 +6,7 @@ from unittest import TestCase

 import numpy as np
 from mmengine.data import BaseDataSample
-from mmengine.evaluator import BaseEvaluator, build_evaluator
+from mmengine.evaluator import BaseEvaluator, build_evaluator, get_metric_value
 from mmengine.registry import EVALUATORS
@@ -62,7 +62,7 @@ class ToyEvaluator(BaseEvaluator):

 @EVALUATORS.register_module()
-class UnprefixedEvaluator(BaseEvaluator):
+class NonPrefixedEvaluator(BaseEvaluator):
     """Evaluator with unassigned `default_prefix` to test the warning
     information."""
@@ -100,8 +100,8 @@ class TestBaseEvaluator(TestCase):
         evaluator.process(data_samples, predictions)
         metrics = evaluator.evaluate(size=size)

-        self.assertAlmostEqual(metrics['Toy.accuracy'], 1.0)
-        self.assertEqual(metrics['Toy.size'], size)
+        self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0)
+        self.assertEqual(metrics['Toy/size'], size)

         # Test empty results
         cfg = dict(type='ToyEvaluator', dummy_metrics=dict(accuracy=1.0))
@@ -126,9 +126,9 @@ class TestBaseEvaluator(TestCase):
         metrics = evaluator.evaluate(size=size)

-        self.assertAlmostEqual(metrics['Toy.accuracy'], 1.0)
-        self.assertAlmostEqual(metrics['Toy.mAP'], 0.0)
-        self.assertEqual(metrics['Toy.size'], size)
+        self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0)
+        self.assertAlmostEqual(metrics['Toy/mAP'], 0.0)
+        self.assertEqual(metrics['Toy/size'], size)

     def test_ambiguate_metric(self):
@@ -167,6 +167,45 @@ class TestBaseEvaluator(TestCase):
         self.assertDictEqual(_evaluator.dataset_meta, dataset_meta)

     def test_prefix(self):
-        cfg = dict(type='UnprefixedEvaluator')
+        cfg = dict(type='NonPrefixedEvaluator')
         with self.assertWarnsRegex(UserWarning, 'The prefix is not set'):
             _ = build_evaluator(cfg)
+
+    def test_get_metric_value(self):
+        metrics = {
+            'prefix_0/metric_0': 0,
+            'prefix_1/metric_0': 1,
+            'prefix_1/metric_1': 2,
+            'nonprefixed': 3,
+        }
+
+        # Test indicator with prefix
+        indicator = 'prefix_0/metric_0'  # correct indicator
+        self.assertEqual(get_metric_value(indicator, metrics), 0)
+
+        indicator = 'prefix_1/metric_0'  # correct indicator
+        self.assertEqual(get_metric_value(indicator, metrics), 1)
+
+        indicator = 'prefix_0/metric_1'  # unmatched indicator (wrong metric)
+        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
+            _ = get_metric_value(indicator, metrics)
+
+        indicator = 'prefix_2/metric'  # unmatched indicator (wrong prefix)
+        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
+            _ = get_metric_value(indicator, metrics)
+
+        # Test indicator without prefix
+        indicator = 'metric_1'  # correct indicator (prefixed metric)
+        self.assertEqual(get_metric_value(indicator, metrics), 2)
+
+        indicator = 'nonprefixed'  # correct indicator (non-prefixed metric)
+        self.assertEqual(get_metric_value(indicator, metrics), 3)
+
+        indicator = 'metric_0'  # ambiguous indicator
+        with self.assertRaisesRegex(ValueError, 'matches multiple metrics'):
+            _ = get_metric_value(indicator, metrics)
+
+        indicator = 'metric_2'  # unmatched indicator
+        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
+            _ = get_metric_value(indicator, metrics)