From 4cd91ffe155a31ff4108231857df016e2d91f3ca Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Tue, 14 Jun 2022 14:48:21 +0800
Subject: [PATCH] [Feature] Dump predictions to a pickle file for offline
 evaluation. (#293)

* [Feature] Dump predictions to pickle file for offline evaluation.

* print_log
---
 mmengine/evaluator/__init__.py      |  4 +--
 mmengine/evaluator/metric.py        | 54 +++++++++++++++++++++++++++++
 tests/test_evaluator/test_metric.py | 41 ++++++++++++++++++++++
 3 files changed, 97 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_evaluator/test_metric.py

diff --git a/mmengine/evaluator/__init__.py b/mmengine/evaluator/__init__.py
index ac91828d..e6bc7842 100644
--- a/mmengine/evaluator/__init__.py
+++ b/mmengine/evaluator/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .evaluator import Evaluator
-from .metric import BaseMetric
+from .metric import BaseMetric, DumpResults
 from .utils import get_metric_value
 
-__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value']
+__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults']
diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py
index 7d8fa736..2f3b0c01 100644
--- a/mmengine/evaluator/metric.py
+++ b/mmengine/evaluator/metric.py
@@ -3,8 +3,14 @@ import warnings
 from abc import ABCMeta, abstractmethod
 from typing import Any, List, Optional, Sequence, Union
 
+from torch import Tensor
+
+from mmengine.data import BaseDataElement
 from mmengine.dist import (broadcast_object_list, collect_results,
                            is_main_process)
+from mmengine.fileio import dump
+from mmengine.logging import print_log
+from mmengine.registry import METRICS
 
 
 class BaseMetric(metaclass=ABCMeta):
@@ -116,3 +122,51 @@ class BaseMetric(metaclass=ABCMeta):
         # reset the results list
         self.results.clear()
         return metrics[0]
+
+
+@METRICS.register_module()
+class DumpResults(BaseMetric):
+    """Dump model predictions to a pickle file for offline evaluation.
+
+    Args:
+        out_file_path (str): Path of the dumped file. Must end with '.pkl'
+            or '.pickle'.
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+    """
+
+    def __init__(self,
+                 out_file_path: str,
+                 collect_device: str = 'cpu') -> None:
+        super().__init__(collect_device=collect_device)
+        if not out_file_path.endswith(('.pkl', '.pickle')):
+            raise ValueError('The output file must be a pkl file.')
+        self.out_file_path = out_file_path
+
+    def process(self, data_batch: Sequence[dict],
+                predictions: Sequence[dict]) -> None:
+        """Transfer tensors in predictions to CPU and store them."""
+        self.results.extend(_to_cpu(predictions))
+
+    def compute_metrics(self, results: list) -> dict:
+        """Dump the prediction results to a pickle file."""
+        dump(results, self.out_file_path)
+        print_log(
+            f'Results has been saved to {self.out_file_path}.',
+            logger='current')
+        return {}
+
+
+def _to_cpu(data: Any) -> Any:
+    """Recursively transfer all tensors and BaseDataElement to CPU."""
+    if isinstance(data, (Tensor, BaseDataElement)):
+        return data.to('cpu')
+    elif isinstance(data, list):
+        return [_to_cpu(d) for d in data]
+    elif isinstance(data, tuple):
+        return tuple(_to_cpu(d) for d in data)
+    elif isinstance(data, dict):
+        return {k: _to_cpu(v) for k, v in data.items()}
+    else:
+        return data
diff --git a/tests/test_evaluator/test_metric.py b/tests/test_evaluator/test_metric.py
new file mode 100644
index 00000000..91d66ea1
--- /dev/null
+++ b/tests/test_evaluator/test_metric.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import tempfile
+from unittest import TestCase
+
+import torch
+from torch import Tensor
+
+from mmengine.evaluator import DumpResults
+from mmengine.fileio import load
+
+
+class TestDumpResults(TestCase):
+
+    def test_init(self):
+        with self.assertRaisesRegex(ValueError,
+                                    'The output file must be a pkl file.'):
+            DumpResults(out_file_path='./results.json')
+
+    def test_process(self):
+        metric = DumpResults(out_file_path='./results.pkl')
+        predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
+        metric.process(None, predictions)
+        self.assertEqual(len(metric.results), 1)
+        self.assertEqual(metric.results[0]['data'][0].device,
+                         torch.device('cpu'))
+
+    def test_compute_metrics(self):
+        temp_dir = tempfile.TemporaryDirectory()
+        path = osp.join(temp_dir.name, 'results.pkl')
+        metric = DumpResults(out_file_path=path)
+        predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
+        metric.process(None, predictions)
+        metric.compute_metrics(metric.results)
+        self.assertTrue(osp.isfile(path))
+
+        results = load(path)
+        self.assertEqual(len(results), 1)
+        self.assertEqual(results[0]['data'][0].device, torch.device('cpu'))
+
+        temp_dir.cleanup()
-- 
GitLab