Skip to content
Snippets Groups Projects
Unverified Commit 4cd91ffe authored by RangiLyu's avatar RangiLyu Committed by GitHub
Browse files

[Feature] Dump predictions to a pickle file for offline evaluation. (#293)

* [Feature] Dump predictions to pickle file for offline evaluation.

* print_log
parent b7866021
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluator import Evaluator
from .metric import BaseMetric
from .metric import BaseMetric, DumpResults
from .utils import get_metric_value
__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value']
__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value', 'DumpResults']
......@@ -3,8 +3,14 @@ import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence, Union
from torch import Tensor
from mmengine.data import BaseDataElement
from mmengine.dist import (broadcast_object_list, collect_results,
is_main_process)
from mmengine.fileio import dump
from mmengine.logging import print_log
from mmengine.registry import METRICS
class BaseMetric(metaclass=ABCMeta):
......@@ -116,3 +122,51 @@ class BaseMetric(metaclass=ABCMeta):
# reset the results list
self.results.clear()
return metrics[0]
@METRICS.register_module()
class DumpResults(BaseMetric):
"""Dump model predictions to a pickle file for offline evaluation.
Args:
out_file_path (str): Path of the dumped file. Must end with '.pkl'
or '.pickle'.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
"""
def __init__(self,
out_file_path: str,
collect_device: str = 'cpu') -> None:
super().__init__(collect_device=collect_device)
if not out_file_path.endswith(('.pkl', '.pickle')):
raise ValueError('The output file must be a pkl file.')
self.out_file_path = out_file_path
def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
"""transfer tensors in predictions to CPU."""
self.results.extend(_to_cpu(predictions))
def compute_metrics(self, results: list) -> dict:
"""dump the prediction results to a pickle file."""
dump(results, self.out_file_path)
print_log(
f'Results has been saved to {self.out_file_path}.',
logger='current')
return {}
def _to_cpu(data: Any) -> Any:
"""transfer all tensors and BaseDataElement to cpu."""
if isinstance(data, (Tensor, BaseDataElement)):
return data.to('cpu')
elif isinstance(data, list):
return [_to_cpu(d) for d in data]
elif isinstance(data, tuple):
return tuple(_to_cpu(d) for d in data)
elif isinstance(data, dict):
return {k: _to_cpu(v) for k, v in data.items()}
else:
return data
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import tempfile
from unittest import TestCase
import torch
from torch import Tensor
from mmengine.evaluator import DumpResults
from mmengine.fileio import load
class TestDumpResults(TestCase):
def test_init(self):
with self.assertRaisesRegex(ValueError,
'The output file must be a pkl file.'):
DumpResults(out_file_path='./results.json')
def test_process(self):
metric = DumpResults(out_file_path='./results.pkl')
predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
metric.process(None, predictions)
self.assertEqual(len(metric.results), 1)
self.assertEqual(metric.results[0]['data'][0].device,
torch.device('cpu'))
def test_compute_metrics(self):
temp_dir = tempfile.TemporaryDirectory()
path = osp.join(temp_dir.name, 'results.pkl')
metric = DumpResults(out_file_path=path)
predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
metric.process(None, predictions)
metric.compute_metrics(metric.results)
self.assertTrue(osp.isfile(path))
results = load(path)
self.assertEqual(len(results), 1)
self.assertEqual(results[0]['data'][0].device, torch.device('cpu'))
temp_dir.cleanup()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment