Unverified commit be6f1898 authored by RangiLyu, committed by GitHub

Update type hint and unit tests of evaluator. (#110)

* update type hint and ut for evaluator

* update doc

* fix
parent cfccabc6
@@ -85,7 +85,7 @@ validation_cfg=dict(
 First, the custom evaluator class should inherit from `BaseEvaluator` and be added to the registry `EVALUATORS` (for an explanation of registries, see the [relevant documentation](docs\zh_cn\tutorials\registry.md)).
-The `process()` method takes 2 input parameters: the test data samples `data_samples` and the model predictions `predictions`. We extract the sample category labels and the classification predictions from them and store them in `self.results`.
+The `process()` method takes 2 input parameters: a batch of test data samples `data_batch` and the model predictions `predictions`. We extract the sample category labels and the classification predictions from them and store them in `self.results`.
 The `compute_metrics()` method takes 1 input parameter, `results`, which holds the results of all batches of test data after they have been processed by `process()`. Extracting the sample category labels and the classification predictions from it yields the classification accuracy `acc`. Finally, the computed evaluation metrics are returned as a dictionary.
@@ -111,14 +111,17 @@ class Accuracy(BaseEvaluator):
     default_prefix = 'ACC'

-    def process(self, data_samples: Dict, predictions: Dict):
+    def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]],
+                predictions: Sequence[BaseDataSample]):
         """Process one batch of data and predictions. The processed
         results should be stored in `self.results`, which will be used
         to compute the metrics when all batches have been processed.

         Args:
-            data_samples (dict): The data samples from the dataset.
-            predictions (dict): The output of the model.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]]): A batch of data
+                from the dataloader.
+            predictions (Sequence[BaseDataSample]): A batch of outputs from
+                the model.
         """
         # Extract the classification predictions and category labels
...
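The tutorial's code block is truncated at this point. For orientation, here is a minimal sketch of how the `Accuracy` example could continue under the new interface; the method bodies and the `pred`/`label` field names are assumptions modeled on the unit-test changes further down, not the tutorial's actual text.

```python
from typing import Any, List, Sequence, Tuple

import numpy as np

from mmengine.data import BaseDataSample
from mmengine.evaluator import BaseEvaluator
from mmengine.registry import EVALUATORS


@EVALUATORS.register_module()
class Accuracy(BaseEvaluator):
    """Sketch of the tutorial's accuracy evaluator (bodies assumed)."""

    default_prefix = 'ACC'

    def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]],
                predictions: Sequence[BaseDataSample]):
        # One result dict per sample: the predicted class plus the
        # ground-truth label taken from the (input, data_sample) tuple.
        self.results.extend(
            {'pred': pred.pred, 'label': data[1].label}
            for pred, data in zip(predictions, data_batch))

    def compute_metrics(self, results: List) -> dict:
        preds = np.array([res['pred'] for res in results])
        labels = np.array([res['label'] for res in results])
        # Fraction of samples whose prediction matches the label.
        return {'accuracy': (preds == labels).sum() / preds.size}
```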
@@ -5,7 +5,7 @@ import shutil
 import tempfile
 import warnings
 from abc import ABCMeta, abstractmethod
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -62,14 +62,17 @@ class BaseEvaluator(metaclass=ABCMeta):
         self._dataset_meta = dataset_meta

     @abstractmethod
-    def process(self, data_samples: BaseDataSample, predictions: dict) -> None:
+    def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]],
+                predictions: Sequence[BaseDataSample]) -> None:
         """Process one batch of data samples and predictions. The processed
         results should be stored in ``self.results``, which will be used to
         compute the metrics when all batches have been processed.

         Args:
-            data_samples (BaseDataSample): The data samples from the dataset.
-            predictions (dict): The output of the model.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]]): A batch of data
+                from the dataloader.
+            predictions (Sequence[BaseDataSample]): A batch of outputs from
+                the model.
         """

     @abstractmethod
...
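The tightened abstract signature pins down the caller's contract: `process()` is fed once per dataloader batch together with the matching model outputs, and the metrics are computed in a single `evaluate()` call at the end. A minimal sketch of that loop, with `model` and `dataloader` as illustrative stand-ins (not part of this commit):

```python
def run_evaluation(evaluator, model, dataloader, dataset_size: int) -> dict:
    # Each batch is a sequence of (input, BaseDataSample) tuples; the model
    # is assumed to return one BaseDataSample per input.
    for data_batch in dataloader:
        predictions = model(data_batch)
        evaluator.process(data_batch, predictions)
    # evaluate() consumes the accumulated self.results and returns the
    # metrics as a dict.
    return evaluator.evaluate(size=dataset_size)
```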
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Sequence, Union
+from typing import Any, Optional, Sequence, Tuple, Union

 from mmengine.data import BaseDataSample

 from .base import BaseEvaluator
@@ -32,16 +32,19 @@ class ComposedEvaluator:
         for evaluator in self.evaluators:
             evaluator.dataset_meta = dataset_meta

-    def process(self, data_samples: BaseDataSample, predictions: dict):
+    def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]],
+                predictions: Sequence[BaseDataSample]):
         """Invoke the process method of each wrapped evaluator.

         Args:
-            data_samples (BaseDataSample): The data samples from the dataset.
-            predictions (dict): The output of the model.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]]): A batch of data
+                from the dataloader.
+            predictions (Sequence[BaseDataSample]): A batch of outputs from
+                the model.
         """
         for evaluator in self.evaluators:
-            evaluator.process(data_samples, predictions)
+            evaluator.process(data_batch, predictions)

     def evaluate(self, size: int) -> dict:
         """Invoke the evaluate method of each wrapped evaluator and collect the
...
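`ComposedEvaluator` fans the same `data_batch` and `predictions` out to every wrapped evaluator, so the renamed arguments only had to change in this one loop. A hedged usage sketch, assuming the constructor takes the wrapped evaluators as a list (as the `self.evaluators` attribute suggests); `Recall` is a hypothetical second evaluator:

```python
composed = ComposedEvaluator(evaluators=[Accuracy(), Recall()])
for data_batch, predictions in batches_with_predictions:  # illustrative source
    composed.process(data_batch, predictions)  # fans out to both evaluators
# Merged metrics from all wrapped evaluators, keyed by each evaluator's prefix.
metrics = composed.evaluate(size=dataset_size)
```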
@@ -5,6 +5,7 @@ from unittest import TestCase
 import numpy as np

+from mmengine.data import BaseDataSample
 from mmengine.evaluator import BaseEvaluator, build_evaluator
 from mmengine.registry import EVALUATORS
@@ -36,17 +37,20 @@ class ToyEvaluator(BaseEvaluator):
         super().__init__(collect_device=collect_device, prefix=prefix)
         self.dummy_metrics = dummy_metrics

-    def process(self, data_samples, predictions):
-        result = {'pred': predictions['pred'], 'label': data_samples['label']}
-        self.results.append(result)
+    def process(self, data_batch, predictions):
+        results = [{
+            'pred': pred.pred,
+            'label': data[1].label
+        } for pred, data in zip(predictions, data_batch)]
+        self.results.extend(results)

     def compute_metrics(self, results: List):
         if self.dummy_metrics is not None:
             assert isinstance(self.dummy_metrics, dict)
             return self.dummy_metrics.copy()

-        pred = np.concatenate([result['pred'] for result in results])
-        label = np.concatenate([result['label'] for result in results])
+        pred = np.array([result['pred'] for result in results])
+        label = np.array([result['label'] for result in results])
         acc = (pred == label).sum() / pred.size

         metrics = {
@@ -74,9 +78,11 @@ def generate_test_results(size, batch_size, pred, label):
     bs_residual = size % batch_size
     for i in range(num_batch):
         bs = bs_residual if i == num_batch - 1 else batch_size
-        data_samples = {'label': np.full(bs, label)}
-        predictions = {'pred': np.full(bs, pred)}
-        yield (data_samples, predictions)
+        data_batch = [(np.zeros((3, 10, 10)),
+                       BaseDataSample(data={'label': label}))
+                      for _ in range(bs)]
+        predictions = [BaseDataSample(data={'pred': pred}) for _ in range(bs)]
+        yield (data_batch, predictions)

 class TestBaseEvaluator(TestCase):
...
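Putting the pieces together, a test method in `TestBaseEvaluator` would drive the toy evaluator through the generator above roughly as follows; the registered type name and the metric key are assumptions based on the surrounding test code:

```python
def test_accuracy_over_generated_results(self):
    # Assumed registered name for ToyEvaluator (see EVALUATORS above).
    evaluator = build_evaluator(dict(type='ToyEvaluator'))
    size, batch_size = 10, 4
    for data_batch, predictions in generate_test_results(
            size, batch_size, pred=1, label=1):
        evaluator.process(data_batch, predictions)
    metrics = evaluator.evaluate(size=size)
    # All predictions equal their labels, so accuracy should be 1.0.
    # The exact key depends on ToyEvaluator's prefix (assumed here).
    self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0)
```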