From d0bcb83e4188ca0957e69f038162a04786f468e8 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Thu, 24 Feb 2022 23:41:42 +0800
Subject: [PATCH] [Feature]: Add evaluator base class. (#41)

* [Feature]: Add evaluator base class.

* solve comments

* update

* fix
---
 docs/zh_cn/tutorials/registry.md         |   2 +
 mmengine/evaluator/__init__.py           |   6 +
 mmengine/evaluator/base.py               | 210 +++++++++++++++++++++++
 mmengine/evaluator/builder.py            |  16 ++
 mmengine/evaluator/composed_evaluator.py |  73 ++++++++
 mmengine/registry/__init__.py            |   5 +-
 mmengine/registry/root.py                |   3 +
 7 files changed, 313 insertions(+), 2 deletions(-)
 create mode 100644 mmengine/evaluator/__init__.py
 create mode 100644 mmengine/evaluator/base.py
 create mode 100644 mmengine/evaluator/builder.py
 create mode 100644 mmengine/evaluator/composed_evaluator.py

diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md
index 3c209c62..164cb0ea 100644
--- a/docs/zh_cn/tutorials/registry.md
+++ b/docs/zh_cn/tutorials/registry.md
@@ -222,7 +222,9 @@ MMEngine 的注册器支持跨项目调用,即可以在一个项目中使用
 - WEIGHT_INITIALIZERS: 权重初始化的工具
 - OPTIMIZERS: 注册了 PyTorch 中所有的 `optimizer` 以及自定义的 `optimizer`
 - OPTIMIZER_CONSTRUCTORS: optimizer 的构造器
+- PARAM_SCHEDULERS: 各种参数调度器, 如 `MultiStepLR`
 - TASK_UTILS: 任务强相关的一些组件,如 `AnchorGenerator`, `BboxCoder`
+- EVALUATORS: 用于验证模型精度的评估器
 
 下面我们以 OpenMMLab 开源项目为例介绍如何跨项目调用模块。
 
diff --git a/mmengine/evaluator/__init__.py b/mmengine/evaluator/__init__.py
new file mode 100644
index 00000000..c2a8d5dd
--- /dev/null
+++ b/mmengine/evaluator/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseEvaluator
+from .builder import build_evaluator
+from .composed_evaluator import ComposedEvaluator
+
+__all__ = ['BaseEvaluator', 'ComposedEvaluator', 'build_evaluator']
diff --git a/mmengine/evaluator/base.py b/mmengine/evaluator/base.py
new file mode 100644
index 00000000..287c2fe2
--- /dev/null
+++ b/mmengine/evaluator/base.py
@@ -0,0 +1,210 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import warnings
+from abc import ABCMeta, abstractmethod
+from typing import Any, List, Optional, Union
+
+import torch
+import torch.distributed as dist
+
+from mmengine.utils import mkdir_or_exist
+
+
+class BaseEvaluator(metaclass=ABCMeta):
+    """Base class for an evaluator.
+
+    The evaluator first processes each batch of data_samples and
+    predictions, and appends the processed results in to the results list.
+    Then it collects all results together from all ranks if distributed
+    training is used. Finally, it computes the metrics of the entire dataset.
+
+    Args:
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+    """
+
+    def __init__(self, collect_device: str = 'cpu') -> None:
+        self._dataset_meta: Union[None, dict] = None
+        self.collect_device = collect_device
+        self.results: List[Any] = []
+
+        rank, world_size = get_dist_info()
+        self.rank = rank
+        self.world_size = world_size
+
+    @property
+    def dataset_meta(self) -> Optional[dict]:
+        return self._dataset_meta
+
+    @dataset_meta.setter
+    def dataset_meta(self, dataset_meta: dict) -> None:
+        self._dataset_meta = dataset_meta
+
+    @abstractmethod
+    def process(self, data_samples: dict, predictions: dict) -> None:
+        """Process one batch of data samples and predictions. The processed
+        results should be stored in ``self.results``, which will be used to
+        compute the metrics when all batches have been processed.
+
+        Args:
+            data_samples (dict): The data samples from the dataset.
+            predictions (dict): The output of the model.
+        """
+
+    @abstractmethod
+    def compute_metrics(self, results: list) -> dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+        Returns:
+            dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+
+    def evaluate(self, size: int) -> dict:
+        """Evaluate the model performance of the whole dataset after processing
+        all batches.
+
+        Args:
+            size (int): Length of the entire validation dataset. When batch
+                size > 1, the dataloader may pad some data samples to make
+                sure all ranks have the same length of dataset slice. The
+                ``collect_results`` function will drop the padded data base on
+                this size.
+
+        Returns:
+            metrics (dict): Evaluation metrics dict on the val dataset. The
+            keys are the names of the metrics, and the values are
+            corresponding results.
+        """
+        if len(self.results) == 0:
+            warnings.warn(
+                f'{self.__class__.__name__} got empty `self._results`. Please '
+                'ensure that the processed results are properly added into '
+                '`self._results` in `process` method.')
+
+        if self.world_size == 1:
+            # non-distributed
+            results = self.results
+        else:
+            results = collect_results(self.results, size, self.collect_device)
+
+        if self.rank == 0:
+            # TODO: replace with mmengine.dist.master_only
+            metrics = [self.compute_metrics(results)]
+        else:
+            metrics = [None]  # type: ignore
+        # TODO: replace with mmengine.dist.broadcast
+        if self.world_size > 1:
+            metrics = dist.broadcast_object_list(metrics)
+
+        # reset the results list
+        self.results.clear()
+        return metrics[0]
+
+
+# TODO: replace with mmengine.dist.get_dist_info
+def get_dist_info():
+    if dist.is_available() and dist.is_initialized():
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+# TODO: replace with mmengine.dist.collect_results
+def collect_results(results, size, device='cpu'):
+    """Collected results in distributed environments."""
+    # TODO: replace with mmengine.dist.collect_results
+    if device == 'gpu':
+        return collect_results_gpu(results, size)
+    elif device == 'cpu':
+        return collect_results_cpu(results, size)
+    else:
+        NotImplementedError(f"device must be 'cpu' or 'gpu', but got {device}")
+
+
+# TODO: replace with mmengine.dist.collect_results
+def collect_results_cpu(result_part, size, tmpdir=None):
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is whitespace
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
+        if rank == 0:
+            mkdir_or_exist('.dist_test')
+            tmpdir = tempfile.mkdtemp(dir='.dist_test')
+            tmpdir = torch.tensor(
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+            dir_tensor[:len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mkdir_or_exist(tmpdir)
+    # dump the part result to the dir
+    with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:
+        pickle.dump(result_part, f, protocol=2)
+    dist.barrier()
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load results of all parts from tmp dir
+        part_list = []
+        for i in range(world_size):
+            with open(osp.join(tmpdir, f'part_{i}.pkl'), 'wb') as f:
+                part_list.append(pickle.load(f))
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        # remove tmp dir
+        shutil.rmtree(tmpdir)
+        return ordered_results
+
+
+# TODO: replace with mmengine.dist.collect_results
+def collect_results_gpu(result_part, size):
+    rank, world_size = get_dist_info()
+    # dump result part to tensor with pickle
+    part_tensor = torch.tensor(
+        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+    # gather all result part tensor shape
+    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_list = [shape_tensor.clone() for _ in range(world_size)]
+    dist.all_gather(shape_list, shape_tensor)
+    # padding result part tensor to max length
+    shape_max = torch.tensor(shape_list).max()
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send[:shape_tensor[0]] = part_tensor
+    part_recv_list = [
+        part_tensor.new_zeros(shape_max) for _ in range(world_size)
+    ]
+    # gather all result part
+    dist.all_gather(part_recv_list, part_send)
+
+    if rank == 0:
+        part_list = []
+        for recv, shape in zip(part_recv_list, shape_list):
+            part_list.append(
+                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        return ordered_results
diff --git a/mmengine/evaluator/builder.py b/mmengine/evaluator/builder.py
new file mode 100644
index 00000000..710c6554
--- /dev/null
+++ b/mmengine/evaluator/builder.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..registry import EVALUATORS
+from .composed_evaluator import ComposedEvaluator
+
+
+def build_evaluator(cfg: dict) -> object:
+    """Build function of evaluator.
+
+    When the evaluator config is a list, it will automatically build composed
+    evaluators.
+    """
+    if isinstance(cfg, list):
+        evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
+        return ComposedEvaluator(evaluators=evaluators)
+    else:
+        return EVALUATORS.build(cfg)
diff --git a/mmengine/evaluator/composed_evaluator.py b/mmengine/evaluator/composed_evaluator.py
new file mode 100644
index 00000000..225284e7
--- /dev/null
+++ b/mmengine/evaluator/composed_evaluator.py
@@ -0,0 +1,73 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Sequence, Union
+
+from .base import BaseEvaluator
+
+
+class ComposedEvaluator:
+    """Wrapper class to compose multiple :class:`DatasetEvaluator` instances.
+
+    Args:
+        evaluators (Sequence[BaseEvaluator]): The evaluators to compose.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+    """
+
+    def __init__(self,
+                 evaluators: Sequence[BaseEvaluator],
+                 collect_device='cpu'):
+        self._dataset_meta: Union[None, dict] = None
+        self.collect_device = collect_device
+        self.evaluators = evaluators
+
+    @property
+    def dataset_meta(self) -> Optional[dict]:
+        return self._dataset_meta
+
+    @dataset_meta.setter
+    def dataset_meta(self, dataset_meta: dict) -> None:
+        self._dataset_meta = dataset_meta
+        for evaluator in self.evaluators:
+            evaluator.dataset_meta = dataset_meta
+
+    def process(self, data_samples: dict, predictions: dict):
+        """Invoke process method of each wrapped evaluator.
+
+        Args:
+            data_samples (dict): The data samples from the dataset.
+            predictions (dict): The output of the model.
+        """
+
+        for evalutor in self.evaluators:
+            evalutor.process(data_samples, predictions)
+
+    def evaluate(self, size: int) -> dict:
+        """Invoke evaluate method of each wrapped evaluator and collect the
+        metrics dict.
+
+        Args:
+            size (int): Length of the entire validation dataset. When batch
+                size > 1, the dataloader may pad some data samples to make
+                sure all ranks have the same length of dataset slice. The
+                ``collect_results`` function will drop the padded data base on
+                this size.
+
+        Returns:
+            metrics (dict): Evaluation metrics of all wrapped evaluators. The
+            keys are the names of the metrics, and the values are
+            corresponding results.
+        """
+        metrics = {}
+        for evaluator in self.evaluators:
+            _metrics = evaluator.evaluate(size)
+
+            # Check metric name conflicts
+            for name in _metrics.keys():
+                if name in metrics:
+                    raise ValueError(
+                        'There are multiple evaluators with the same metric '
+                        f'name {name}')
+
+            metrics.update(_metrics)
+        return metrics
diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py
index d8602d4f..24ebcb6e 100644
--- a/mmengine/registry/__init__.py
+++ b/mmengine/registry/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .registry import Registry, build_from_cfg
-from .root import (DATA_SAMPLERS, DATASETS, HOOKS, MODELS,
+from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, MODELS,
                    OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, PARAM_SCHEDULERS,
                    RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS,
                    WEIGHT_INITIALIZERS)
@@ -8,5 +8,6 @@ from .root import (DATA_SAMPLERS, DATASETS, HOOKS, MODELS,
 __all__ = [
     'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
     'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
-    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS'
+    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
+    'EVALUATORS'
 ]
diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py
index c67a8d5b..9636fae2 100644
--- a/mmengine/registry/root.py
+++ b/mmengine/registry/root.py
@@ -34,3 +34,6 @@ PARAM_SCHEDULERS = Registry('parameter scheduler')
 
 # manage task-specific modules like anchor generators and box coders
 TASK_UTILS = Registry('task util')
+
+# manage all kinds of evaluators for computing metrics
+EVALUATORS = Registry('evaluator')
-- 
GitLab