diff --git a/mmengine/data/label_data.py b/mmengine/data/label_data.py index 388313a69a75a221a3373eb93bea069200ee295e..455e024310e19bc0befabc2117dd7692dec8b63d 100644 --- a/mmengine/data/label_data.py +++ b/mmengine/data/label_data.py @@ -40,7 +40,7 @@ class LabelData(BaseDataElement): torch.Tensor: The converted results. """ assert isinstance(label, torch.Tensor) - onehot = torch.zeros((num_classes, ), dtype=torch.int64) - assert label.max().item() < num_classes + onehot = label.new_zeros((num_classes, )) + assert max(label, default=torch.tensor(0)).item() < num_classes onehot[label] = 1 return onehot diff --git a/tests/test_data/test_label_data.py b/tests/test_data/test_label_data.py index 6868e88e82335723fc5112f838cd5dd0406df967..048f4e067fc89d6c2606936a0619f366d02e2a85 100644 --- a/tests/test_data/test_label_data.py +++ b/tests/test_data/test_label_data.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest import TestCase +import pytest import torch from mmengine.data import LabelData @@ -13,7 +14,7 @@ class TestLabelData(TestCase): num_classes = 10 onehot = LabelData.label_to_onehot(label=item, num_classes=num_classes) assert tuple(onehot.shape) == (num_classes, ) - + assert onehot.device == item.device # item is not onehot with self.assertRaises(AssertionError): LabelData.label_to_onehot(label='item', num_classes=num_classes) @@ -22,6 +23,10 @@ class TestLabelData(TestCase): with self.assertRaises(AssertionError): LabelData.label_to_onehot( torch.tensor([11], dtype=torch.int64), num_classes) + onehot = LabelData.label_to_onehot( + label=torch.tensor([], dtype=torch.int64), num_classes=num_classes) + assert (onehot == torch.zeros((num_classes, ), + dtype=torch.int64)).all() def test_onehot_to_label(self): # item is not onehot @@ -38,3 +43,13 @@ class TestLabelData(TestCase): onehot = LabelData.label_to_onehot(item, num_classes=10) label = LabelData.onehot_to_label(onehot) assert (label == item).all() + assert label.device == item.device + + @pytest.mark.skipif( + not torch.cuda.is_available(), reason='GPU is required!') + def test_cuda(self): + item = torch.arange(0, 9).cuda() + onehot = LabelData.label_to_onehot(item, num_classes=10) + assert item.device == onehot.device + label = LabelData.onehot_to_label(onehot) + assert label.device == onehot.device