From c90b95a44b1d05340425dbd73bb1ea5ec9765326 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Fri, 10 Jun 2022 15:12:41 +0800 Subject: [PATCH] [Fix]: fix label data and support empty tensor in label_to_onehot (#291) --- mmengine/data/label_data.py | 4 ++-- tests/test_data/test_label_data.py | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/mmengine/data/label_data.py b/mmengine/data/label_data.py index 388313a6..455e0243 100644 --- a/mmengine/data/label_data.py +++ b/mmengine/data/label_data.py @@ -40,7 +40,7 @@ class LabelData(BaseDataElement): torch.Tensor: The converted results. """ assert isinstance(label, torch.Tensor) - onehot = torch.zeros((num_classes, ), dtype=torch.int64) - assert label.max().item() < num_classes + onehot = label.new_zeros((num_classes, )) + assert max(label, default=torch.tensor(0)).item() < num_classes onehot[label] = 1 return onehot diff --git a/tests/test_data/test_label_data.py b/tests/test_data/test_label_data.py index 6868e88e..048f4e06 100644 --- a/tests/test_data/test_label_data.py +++ b/tests/test_data/test_label_data.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest import TestCase +import pytest import torch from mmengine.data import LabelData @@ -13,7 +14,7 @@ class TestLabelData(TestCase): num_classes = 10 onehot = LabelData.label_to_onehot(label=item, num_classes=num_classes) assert tuple(onehot.shape) == (num_classes, ) - + assert onehot.device == item.device # item is not onehot with self.assertRaises(AssertionError): LabelData.label_to_onehot(label='item', num_classes=num_classes) @@ -22,6 +23,10 @@ class TestLabelData(TestCase): with self.assertRaises(AssertionError): LabelData.label_to_onehot( torch.tensor([11], dtype=torch.int64), num_classes) + onehot = LabelData.label_to_onehot( + label=torch.tensor([], dtype=torch.int64), num_classes=num_classes) + assert (onehot == torch.zeros((num_classes, ), + dtype=torch.int64)).all() def test_onehot_to_label(self): # item is not onehot @@ -38,3 +43,13 @@ class TestLabelData(TestCase): onehot = LabelData.label_to_onehot(item, num_classes=10) label = LabelData.onehot_to_label(onehot) assert (label == item).all() + assert label.device == item.device + + @pytest.mark.skipif( + not torch.cuda.is_available(), reason='GPU is required!') + def test_cuda(self): + item = torch.arange(0, 9).cuda() + onehot = LabelData.label_to_onehot(item, num_classes=10) + assert item.device == onehot.device + label = LabelData.onehot_to_label(onehot) + assert label.device == onehot.device -- GitLab