From 4432e54c97e25b7ca1fb6f33f2212b006008f813 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Mon, 1 Aug 2022 10:20:57 +0800
Subject: [PATCH] [Fix] Fix gpu tensors in results list are not on the same
 device. (#385)

* [Fix] Fix gpu tensors in results list are not on the same device.

* cast all tensor to cpu
---
 mmengine/evaluator/metric.py           |  2 ++
 tests/test_evaluator/test_evaluator.py | 33 ++++++++++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py
index 2f3b0c01..0f837ed0 100644
--- a/mmengine/evaluator/metric.py
+++ b/mmengine/evaluator/metric.py
@@ -106,6 +106,8 @@ class BaseMetric(metaclass=ABCMeta):
         results = collect_results(self.results, size, self.collect_device)
 
         if is_main_process():
+            # cast all tensors in results list to cpu
+            results = _to_cpu(results)
             _metrics = self.compute_metrics(results)  # type: ignore
             # Add prefix to metric names
             if self.prefix:
diff --git a/tests/test_evaluator/test_evaluator.py b/tests/test_evaluator/test_evaluator.py
index 5364c067..a7a58166 100644
--- a/tests/test_evaluator/test_evaluator.py
+++ b/tests/test_evaluator/test_evaluator.py
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+import unittest
 from typing import Dict, List, Optional, Sequence
 from unittest import TestCase
 
 import numpy as np
+import torch
 
 from mmengine.data import BaseDataElement
 from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
@@ -238,3 +240,34 @@ class TestEvaluator(TestCase):
         ]
         all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
         evaluator.offline_evaluate(all_data, all_predictions)
+
+    @unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')
+    def test_evaluate_cast_cpu(self):
+        cfg = dict(type='ToyMetric')
+        evaluator = Evaluator(cfg)
+
+        size = 10
+
+        all_data = [
+            dict(
+                inputs=torch.zeros((3, 10, 10), device='cuda'),
+                data_sample=BaseDataElement(
+                    label=torch.ones((1, ), device='cuda')))
+            for _ in range(size)
+        ]
+        all_predictions = [
+            BaseDataElement(pred=torch.zeros((1, ), device='cuda'))
+            for _ in range(size)
+        ]
+        for data, pred in zip(all_data, all_predictions):
+            evaluator.process([data], [pred])
+
+        def test_results_device(results: List):
+            for result in results:
+                self.assertEqual(result['pred'].device, torch.device('cpu'))
+                self.assertEqual(result['label'].device, torch.device('cpu'))
+            return {}
+
+        # replace the `compute_metrics` to the test function
+        evaluator.metrics[0].compute_metrics = test_results_device
+        evaluator.evaluate(size)
-- 
GitLab
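
Note: the patched `BaseMetric.evaluate()` path calls a helper named `_to_cpu` whose body is not part of this diff. The snippet below is a minimal, hypothetical sketch of such a helper, assuming it only needs to walk nested dicts, lists, and tuples and move torch tensors to the CPU; it is an illustration, not mmengine's actual implementation.

import torch


def _to_cpu(data):
    # Hypothetical sketch: recursively move every torch.Tensor to the CPU.
    # The real mmengine helper may also handle extra types such as
    # BaseDataElement; only plain tensors and standard containers are
    # covered here.
    if isinstance(data, torch.Tensor):
        return data.detach().cpu()
    if isinstance(data, dict):
        return {key: _to_cpu(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(_to_cpu(value) for value in data)
    return data


# Usage mirroring the patched line in BaseMetric.evaluate():
#     results = _to_cpu(results)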