Unverified Commit 4432e54c authored by RangiLyu, committed by GitHub

[Fix] Fix GPU tensors in the results list not being on the same device. (#385)

* [Fix] Fix GPU tensors in the results list not being on the same device.

* Cast all tensors to CPU.
parent cfee85ff
@@ -106,6 +106,8 @@ class BaseMetric(metaclass=ABCMeta):
         results = collect_results(self.results, size, self.collect_device)
         if is_main_process():
+            # cast all tensors in results list to cpu
+            results = _to_cpu(results)
             _metrics = self.compute_metrics(results)  # type: ignore
             # Add prefix to metric names
             if self.prefix:
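The hunk above calls a `_to_cpu` helper that is not shown in this diff. As a rough sketch (assuming it simply walks the results recursively and moves `torch.Tensor` and `BaseDataElement` values to CPU, leaving everything else untouched), it could look like this:

# Sketch only: the actual helper in mmengine.evaluator may differ.
from typing import Any

import torch

from mmengine.data import BaseDataElement


def _to_cpu(data: Any) -> Any:
    """Recursively move tensors and BaseDataElement objects to CPU."""
    if isinstance(data, (torch.Tensor, BaseDataElement)):
        return data.to('cpu')
    elif isinstance(data, list):
        return [_to_cpu(d) for d in data]
    elif isinstance(data, tuple):
        return tuple(_to_cpu(d) for d in data)
    elif isinstance(data, dict):
        return {key: _to_cpu(value) for key, value in data.items()}
    return data

With such a helper in place, `compute_metrics` only ever sees CPU tensors, regardless of which GPU each rank collected its results on.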
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+import unittest
 from typing import Dict, List, Optional, Sequence
 from unittest import TestCase
 
 import numpy as np
+import torch
 
 from mmengine.data import BaseDataElement
 from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
@@ -238,3 +240,34 @@ class TestEvaluator(TestCase):
         ]
         all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
         evaluator.offline_evaluate(all_data, all_predictions)
+
+    @unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')
+    def test_evaluate_cast_cpu(self):
+        cfg = dict(type='ToyMetric')
+        evaluator = Evaluator(cfg)
+
+        size = 10
+
+        all_data = [
+            dict(
+                inputs=torch.zeros((3, 10, 10), device='cuda'),
+                data_sample=BaseDataElement(
+                    label=torch.ones((1, ), device='cuda')))
+            for _ in range(size)
+        ]
+        all_predictions = [
+            BaseDataElement(pred=torch.zeros((1, ), device='cuda'))
+            for _ in range(size)
+        ]
+        for data, pred in zip(all_data, all_predictions):
+            evaluator.process([data], [pred])
+
+        def test_results_device(results: List):
+            for result in results:
+                self.assertEqual(result['pred'].device, torch.device('cpu'))
+                self.assertEqual(result['label'].device, torch.device('cpu'))
+            return {}
+
+        # replace `compute_metrics` with the test function to check devices
+        evaluator.metrics[0].compute_metrics = test_results_device
+        evaluator.evaluate(size)