From 4432e54c97e25b7ca1fb6f33f2212b006008f813 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Mon, 1 Aug 2022 10:20:57 +0800
Subject: [PATCH] [Fix] Fix GPU tensors in the results list not being on the
 same device (#385)

* [Fix] Fix GPU tensors in the results list not being on the same device.

* cast all tensors to CPU
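
Note: the new code calls a helper `_to_cpu` whose definition is not shown in
this patch; it is assumed to exist at module level in
mmengine/evaluator/metric.py. A minimal sketch of such a helper (an
illustrative assumption, not necessarily the actual implementation) that
recursively moves tensors nested in lists/tuples/dicts to CPU:

    # assumes `import torch` at module level
    def _to_cpu(data):
        """Recursively move every torch.Tensor found in ``data`` to CPU."""
        if isinstance(data, torch.Tensor):
            return data.to('cpu')
        if isinstance(data, list):
            return [_to_cpu(d) for d in data]
        if isinstance(data, tuple):
            return tuple(_to_cpu(d) for d in data)
        if isinstance(data, dict):
            return {k: _to_cpu(v) for k, v in data.items()}
        return data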
---
 mmengine/evaluator/metric.py           |  2 ++
 tests/test_evaluator/test_evaluator.py | 33 ++++++++++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py
index 2f3b0c01..0f837ed0 100644
--- a/mmengine/evaluator/metric.py
+++ b/mmengine/evaluator/metric.py
@@ -106,6 +106,8 @@ class BaseMetric(metaclass=ABCMeta):
         results = collect_results(self.results, size, self.collect_device)
 
         if is_main_process():
+            # cast all tensors in the results list to CPU
+            results = _to_cpu(results)
             _metrics = self.compute_metrics(results)  # type: ignore
             # Add prefix to metric names
             if self.prefix:
diff --git a/tests/test_evaluator/test_evaluator.py b/tests/test_evaluator/test_evaluator.py
index 5364c067..a7a58166 100644
--- a/tests/test_evaluator/test_evaluator.py
+++ b/tests/test_evaluator/test_evaluator.py
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+import unittest
 from typing import Dict, List, Optional, Sequence
 from unittest import TestCase
 
 import numpy as np
+import torch
 
 from mmengine.data import BaseDataElement
 from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
@@ -238,3 +240,34 @@ class TestEvaluator(TestCase):
         ]
         all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
         evaluator.offline_evaluate(all_data, all_predictions)
+
+    @unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')
+    def test_evaluate_cast_cpu(self):
+        cfg = dict(type='ToyMetric')
+        evaluator = Evaluator(cfg)
+
+        size = 10
+
+        all_data = [
+            dict(
+                inputs=torch.zeros((3, 10, 10), device='cuda'),
+                data_sample=BaseDataElement(
+                    label=torch.ones((1, ), device='cuda')))
+            for _ in range(size)
+        ]
+        all_predictions = [
+            BaseDataElement(pred=torch.zeros((1, ), device='cuda'))
+            for _ in range(size)
+        ]
+        for data, pred in zip(all_data, all_predictions):
+            evaluator.process([data], [pred])
+
+        def test_results_device(results: List):
+            for result in results:
+                self.assertEqual(result['pred'].device, torch.device('cpu'))
+                self.assertEqual(result['label'].device, torch.device('cpu'))
+            return {}
+
+        # replace `compute_metrics` with the test function
+        evaluator.metrics[0].compute_metrics = test_results_device
+        evaluator.evaluate(size)
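
Note: on a machine where torch.cuda.is_available() returns True, the added
test can be selected directly, for example:

    pytest tests/test_evaluator/test_evaluator.py::TestEvaluator::test_evaluate_cast_cpu

Without a GPU, the `unittest.skipUnless` decorator marks it as skipped.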
-- 
GitLab