diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py
index 2f3b0c01d4db7c52d8e4f1ccbac4183530dccd3e..0f837ed0b942f2a64a30edcdbf0a2f29592f1a0b 100644
--- a/mmengine/evaluator/metric.py
+++ b/mmengine/evaluator/metric.py
@@ -106,6 +106,8 @@ class BaseMetric(metaclass=ABCMeta):
         results = collect_results(self.results, size, self.collect_device)
 
         if is_main_process():
+            # cast all tensors in the results list to cpu
+            results = _to_cpu(results)
             _metrics = self.compute_metrics(results)  # type: ignore
             # Add prefix to metric names
             if self.prefix:
diff --git a/tests/test_evaluator/test_evaluator.py b/tests/test_evaluator/test_evaluator.py
index 5364c0679997d8dfba7630056f3b8ab93e955cd4..a7a581661545bdbc6a6ab38015d12f2a7625e0ba 100644
--- a/tests/test_evaluator/test_evaluator.py
+++ b/tests/test_evaluator/test_evaluator.py
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+import unittest
 from typing import Dict, List, Optional, Sequence
 from unittest import TestCase
 
 import numpy as np
+import torch
 
 from mmengine.data import BaseDataElement
 from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
@@ -238,3 +240,34 @@ class TestEvaluator(TestCase):
         ]
         all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
         evaluator.offline_evaluate(all_data, all_predictions)
+
+    @unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')
+    def test_evaluate_cast_cpu(self):
+        cfg = dict(type='ToyMetric')
+        evaluator = Evaluator(cfg)
+
+        size = 10
+
+        all_data = [
+            dict(
+                inputs=torch.zeros((3, 10, 10), device='cuda'),
+                data_sample=BaseDataElement(
+                    label=torch.ones((1, ), device='cuda')))
+            for _ in range(size)
+        ]
+        all_predictions = [
+            BaseDataElement(pred=torch.zeros((1, ), device='cuda'))
+            for _ in range(size)
+        ]
+        for data, pred in zip(all_data, all_predictions):
+            evaluator.process([data], [pred])
+
+        def test_results_device(results: List):
+            for result in results:
+                self.assertEqual(result['pred'].device, torch.device('cpu'))
+                self.assertEqual(result['label'].device, torch.device('cpu'))
+            return {}
+
+        # replace `compute_metrics` with the test function
+        evaluator.metrics[0].compute_metrics = test_results_device
+        evaluator.evaluate(size)
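
Note: the added line calls a `_to_cpu` helper whose definition is not part of this hunk. A minimal sketch of what such a helper could look like in mmengine/evaluator/metric.py, assuming it only needs to recursively move `torch.Tensor` and `BaseDataElement` values (including ones nested in lists, tuples, and dicts) to CPU:

# Hypothetical sketch of a `_to_cpu` helper; the real implementation is not
# shown in this diff.
from typing import Any

from torch import Tensor

from mmengine.data import BaseDataElement


def _to_cpu(data: Any) -> Any:
    """Recursively move tensors and data elements to CPU."""
    if isinstance(data, (Tensor, BaseDataElement)):
        return data.to('cpu')
    elif isinstance(data, list):
        return [_to_cpu(d) for d in data]
    elif isinstance(data, tuple):
        return tuple(_to_cpu(d) for d in data)
    elif isinstance(data, dict):
        return {k: _to_cpu(v) for k, v in data.items()}
    else:
        return data

With a helper of this shape, `compute_metrics` always receives CPU data regardless of the `collect_device` used for gathering, which is what the new `test_evaluate_cast_cpu` test asserts.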