Skip to content
Snippets Groups Projects
base.py 8.59 KiB
Newer Older
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import pickle
import shutil
import tempfile
import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence, Tuple, Union

import torch
import torch.distributed as dist

from mmengine.data import BaseDataSample
from mmengine.utils import mkdir_or_exist


class BaseEvaluator(metaclass=ABCMeta):
    """Base class for an evaluator.

    The evaluator first processes each batch of data_samples and
    predictions, and appends the processed results into the results list.
    Then it collects all results together from all ranks if distributed
    training is used. Finally, it computes the metrics of the entire dataset.

    A subclass of :class:`BaseEvaluator` should assign a meaningful value to
    the class attribute `default_prefix`. See the argument `prefix` for
    details.

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Default: None
    """
    # Fallback metric-name prefix used when ``prefix`` is not passed to
    # ``__init__``; subclasses should override it.
    default_prefix: Optional[str] = None

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        self._dataset_meta: Optional[dict] = None
        self.collect_device = collect_device
        # Processed per-batch results appended by ``process``; cleared at the
        # end of ``evaluate``.
        self.results: List[Any] = []

        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        # An explicit prefix wins; otherwise fall back to the class default.
        # Because ``or`` is used, an empty-string prefix also falls back.
        self.prefix = prefix or self.default_prefix
        if self.prefix is None:
            warnings.warn('The prefix is not set in evaluator class '
                          f'{self.__class__.__name__}.')

    @property
    def dataset_meta(self) -> Optional[dict]:
        """Optional[dict]: Meta information of the evaluated dataset, or
        ``None`` if it has not been set."""
        return self._dataset_meta

    @dataset_meta.setter
    def dataset_meta(self, dataset_meta: dict) -> None:
        self._dataset_meta = dataset_meta

    @abstractmethod
    def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]],
                predictions: Sequence[BaseDataSample]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[Tuple[Any, BaseDataSample]]): A batch of data
                from the dataloader.
            predictions (Sequence[BaseDataSample]): A batch of outputs from
                the model.
        """

    @abstractmethod
    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """

    def evaluate(self, size: int) -> dict:
        """Evaluate the model performance of the whole dataset after processing
        all batches.

        Metrics are computed on rank 0 only and, in distributed mode,
        broadcast to every other rank so that all ranks return the same dict.
        ``self.results`` is cleared afterwards.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are the
            names of the metrics, and the values are corresponding results.
        """
        if not self.results:
            warnings.warn(
                f'{self.__class__.__name__} got empty `self.results`. Please '
                'ensure that the processed results are properly added into '
                '`self.results` in `process` method.')

        if self.world_size == 1:
            # non-distributed: the local results are already complete.
            results = self.results
        else:
            # Only rank 0 receives the merged list; other ranks get None.
            results = collect_results(self.results, size, self.collect_device)

        # One-element container for ``dist.broadcast_object_list``, which
        # fills the list in place from the source rank.
        metrics_box: List[Optional[dict]] = [None]
        if self.rank == 0:
            # TODO: replace with mmengine.dist.master_only
            metrics = self.compute_metrics(results)
            # Add prefix to metric names
            if self.prefix:
                metrics = {
                    '.'.join((self.prefix, k)): v
                    for k, v in metrics.items()
                }
            metrics_box[0] = metrics
        # TODO: replace with mmengine.dist.broadcast
        if self.world_size > 1:
            # broadcast_object_list mutates ``metrics_box`` in place and
            # returns None, so its return value must not be assigned.
            dist.broadcast_object_list(metrics_box)

        # reset the results list
        self.results.clear()
        return metrics_box[0]  # type: ignore[return-value]


# TODO: replace with mmengine.dist.get_dist_info
def get_dist_info():
    """Return ``(rank, world_size)`` of the current process.

    Falls back to ``(0, 1)`` when ``torch.distributed`` is not available or
    the default process group has not been initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0, 1
    return dist.get_rank(), dist.get_world_size()


# TODO: replace with mmengine.dist.collect_results
def collect_results(results, size, device='cpu'):
    """Collect results from all ranks in a distributed environment.

    Args:
        results (list): Results produced on the current rank.
        size (int): Length of the whole dataset; samples beyond this count
            (dataloader padding) are dropped.
        device (str): How results are exchanged: ``'cpu'`` (pickled files in
            a shared directory) or ``'gpu'`` (GPU collectives).
            Defaults to ``'cpu'``.

    Returns:
        list | None: The merged results on rank 0 and ``None`` on the other
        ranks.

    Raises:
        NotImplementedError: If ``device`` is neither ``'cpu'`` nor ``'gpu'``.
    """
    # TODO: replace with mmengine.dist.collect_results
    if device == 'gpu':
        return collect_results_gpu(results, size)
    if device == 'cpu':
        return collect_results_cpu(results, size)
    # The exception must actually be raised; merely constructing it would
    # silently return None to the caller.
    raise NotImplementedError(
        f"device must be 'cpu' or 'gpu', but got {device}")


# TODO: replace with mmengine.dist.collect_results
def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect results from all ranks through a shared directory.

    Every rank pickles its partial results to ``<tmpdir>/part_<rank>.pkl``.
    After a barrier, rank 0 loads all parts, merges them and removes
    ``tmpdir``. The directory must therefore be visible to all ranks
    (presumably a shared filesystem in multi-node setups — confirm).

    Args:
        result_part (list): Results produced on the current rank.
        size (int): Length of the whole dataset; samples beyond this count
            (dataloader padding) are dropped.
        tmpdir (str, optional): Directory used to exchange the parts. When
            ``None``, rank 0 creates a fresh temporary directory under
            ``.dist_test`` and broadcasts its path to the other ranks.
            Rank 0 deletes the directory once the parts are loaded.

    Returns:
        list | None: On rank 0, the results interleaved rank by rank (item i
        of rank 0, item i of rank 1, ...) and truncated to ``size``; ``None``
        on every other rank.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # The path is broadcast as a fixed-length uint8 tensor padded with
        # 32 (ASCII space); the padding is removed by ``rstrip`` below.
        # The tensor is allocated on CUDA (presumably for NCCL backend
        # compatibility — confirm), so a GPU is needed even in 'cpu' mode.
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            encoded_dir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(encoded_dir)] = encoded_dir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:
        pickle.dump(result_part, f, protocol=2)
    # Ensure every rank has finished writing before rank 0 reads.
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    # Load the results of all parts. The files must be opened in binary
    # *read* mode: 'wb' would truncate them and make pickle.load fail.
    part_list = []
    for i in range(world_size):
        with open(osp.join(tmpdir, f'part_{i}.pkl'), 'rb') as f:
            part_list.append(pickle.load(f))
    # sort the results: interleave the parts sample by sample
    ordered_results = []
    for res in zip(*part_list):
        ordered_results.extend(list(res))
    # the dataloader may pad some samples
    ordered_results = ordered_results[:size]
    # remove tmp dir
    shutil.rmtree(tmpdir)
    return ordered_results


# TODO: replace with mmengine.dist.collect_results
def collect_results_gpu(result_part, size):
    """Collect results from all ranks through GPU collectives.

    Each rank pickles its partial results into a ``uint8`` CUDA tensor.
    ``all_gather`` needs equally sized tensors, so the byte lengths are
    exchanged first. Every payload is then zero-padded to the longest one,
    and the padded payloads are gathered.

    Args:
        result_part (list): Results produced on the current rank.
        size (int): Length of the whole dataset; samples beyond this count
            (dataloader padding) are dropped.

    Returns:
        list | None: On rank 0, the results interleaved rank by rank (item i
        of rank 0, item i of rank 1, ...) and truncated to ``size``; ``None``
        on every other rank.
    """
    rank, world_size = get_dist_info()

    # Serialize the local results into a byte tensor on the GPU.
    payload = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')

    # Exchange payload lengths so every rank knows how far to pad.
    local_len = torch.tensor(payload.shape, device='cuda')
    all_lens = [local_len.clone() for _ in range(world_size)]
    dist.all_gather(all_lens, local_len)

    # Pad to the longest payload, then gather every padded payload.
    longest = torch.tensor(all_lens).max()
    padded = torch.zeros(longest, dtype=torch.uint8, device='cuda')
    padded[:local_len[0]] = payload
    gathered = [payload.new_zeros(longest) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    if rank != 0:
        return None

    # Strip each rank's padding before unpickling its part.
    parts = [
        pickle.loads(buf[:length[0]].cpu().numpy().tobytes())
        for buf, length in zip(gathered, all_lens)
    ]
    # Interleave the parts sample by sample, then drop dataloader padding.
    merged = [item for group in zip(*parts) for item in group]
    return merged[:size]