diff --git a/mmengine/evaluator/__init__.py b/mmengine/evaluator/__init__.py
index ac91828d98ba56345aa58214f8fcdcabca861ecc..e6bc78425e2a3194bdaa4da29e6b3e238237fafa 100644
--- a/mmengine/evaluator/__init__.py
+++ b/mmengine/evaluator/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .evaluator import Evaluator
-from .metric import BaseMetric
+from .metric import BaseMetric, DumpResults
 from .utils import get_metric_value
 
-__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value']
+__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults']
diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py
index 7d8fa73676e11a6a3a8fba8d030984f6cf7e6d25..2f3b0c01d4db7c52d8e4f1ccbac4183530dccd3e 100644
--- a/mmengine/evaluator/metric.py
+++ b/mmengine/evaluator/metric.py
@@ -3,8 +3,14 @@ import warnings
 from abc import ABCMeta, abstractmethod
 from typing import Any, List, Optional, Sequence, Union
 
+from torch import Tensor
+
+from mmengine.data import BaseDataElement
 from mmengine.dist import (broadcast_object_list, collect_results,
                            is_main_process)
+from mmengine.fileio import dump
+from mmengine.logging import print_log
+from mmengine.registry import METRICS
 
 
 class BaseMetric(metaclass=ABCMeta):
@@ -116,3 +122,51 @@
         # reset the results list
         self.results.clear()
         return metrics[0]
+
+
+@METRICS.register_module()
+class DumpResults(BaseMetric):
+    """Dump model predictions to a pickle file for offline evaluation.
+
+    Args:
+        out_file_path (str): Path of the dumped file. Must end with '.pkl'
+            or '.pickle'.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+    """
+
+    def __init__(self,
+                 out_file_path: str,
+                 collect_device: str = 'cpu') -> None:
+        super().__init__(collect_device=collect_device)
+        if not out_file_path.endswith(('.pkl', '.pickle')):
+            raise ValueError('The output file must be a pkl file.')
+        self.out_file_path = out_file_path
+
+    def process(self, data_batch: Sequence[dict],
+                predictions: Sequence[dict]) -> None:
+        """Transfer tensors in the predictions to CPU."""
+        self.results.extend(_to_cpu(predictions))
+
+    def compute_metrics(self, results: list) -> dict:
+        """Dump the prediction results to a pickle file."""
+        dump(results, self.out_file_path)
+        print_log(
+            f'Results have been saved to {self.out_file_path}.',
+            logger='current')
+        return {}
+
+
+def _to_cpu(data: Any) -> Any:
+    """Recursively transfer all tensors and BaseDataElement instances to CPU."""
+    if isinstance(data, (Tensor, BaseDataElement)):
+        return data.to('cpu')
+    elif isinstance(data, list):
+        return [_to_cpu(d) for d in data]
+    elif isinstance(data, tuple):
+        return tuple(_to_cpu(d) for d in data)
+    elif isinstance(data, dict):
+        return {k: _to_cpu(v) for k, v in data.items()}
+    else:
+        return data
diff --git a/tests/test_evaluator/test_metric.py b/tests/test_evaluator/test_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..91d66ea15a303cc1e411794dc87c89eb91c12c80
--- /dev/null
+++ b/tests/test_evaluator/test_metric.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+from unittest import TestCase
+
+import torch
+from torch import Tensor
+
+from mmengine.evaluator import DumpResults
+from mmengine.fileio import load
+
+
+class TestDumpResults(TestCase):
+
+    def test_init(self):
+        with self.assertRaisesRegex(ValueError,
+                                    'The output file must be a pkl file.'):
+            DumpResults(out_file_path='./results.json')
+
+    def test_process(self):
+        metric = DumpResults(out_file_path='./results.pkl')
+        predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
+        metric.process(None, predictions)
+        self.assertEqual(len(metric.results), 1)
+        self.assertEqual(metric.results[0]['data'][0].device,
+                         torch.device('cpu'))
+
+    def test_compute_metrics(self):
+        temp_dir = tempfile.TemporaryDirectory()
+        path = osp.join(temp_dir.name, 'results.pkl')
+        metric = DumpResults(out_file_path=path)
+        predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
+        metric.process(None, predictions)
+        metric.compute_metrics(metric.results)
+        self.assertTrue(osp.isfile(path))
+
+        results = load(path)
+        self.assertEqual(len(results), 1)
+        self.assertEqual(results[0]['data'][0].device, torch.device('cpu'))
+
+        temp_dir.cleanup()
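
For reference, a minimal usage sketch of the new metric, mirroring the tests above. It is not part of the diff; the toy predictions and the './results.pkl' path are illustrative, and in a real run the test loop or an Evaluator calls process()/compute_metrics() rather than user code.

# Dump toy predictions to a pickle file, then read them back for offline analysis.
from torch import Tensor

from mmengine.evaluator import DumpResults
from mmengine.fileio import load

metric = DumpResults(out_file_path='./results.pkl')
predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
metric.process(None, predictions)        # tensors are moved to CPU here
metric.compute_metrics(metric.results)   # writes ./results.pkl and returns {}

results = load('./results.pkl')          # list of per-sample prediction dicts
# Because DumpResults is registered in METRICS, it can also be selected from a
# config, e.g. dict(type='DumpResults', out_file_path='./results.pkl').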