Skip to content
Snippets Groups Projects
Unverified Commit c90b95a4 authored by liukuikun's avatar liukuikun Committed by GitHub
Browse files

[Fix]: fix label data and support empty tensor in label_to_onehot (#291)

parent 2f16ec69
No related branches found
No related tags found
No related merge requests found
...@@ -40,7 +40,7 @@ class LabelData(BaseDataElement): ...@@ -40,7 +40,7 @@ class LabelData(BaseDataElement):
torch.Tensor: The converted results. torch.Tensor: The converted results.
""" """
assert isinstance(label, torch.Tensor) assert isinstance(label, torch.Tensor)
onehot = torch.zeros((num_classes, ), dtype=torch.int64) onehot = label.new_zeros((num_classes, ))
assert label.max().item() < num_classes assert max(label, default=torch.tensor(0)).item() < num_classes
onehot[label] = 1 onehot[label] = 1
return onehot return onehot
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase from unittest import TestCase
import pytest
import torch import torch
from mmengine.data import LabelData from mmengine.data import LabelData
...@@ -13,7 +14,7 @@ class TestLabelData(TestCase): ...@@ -13,7 +14,7 @@ class TestLabelData(TestCase):
num_classes = 10 num_classes = 10
onehot = LabelData.label_to_onehot(label=item, num_classes=num_classes) onehot = LabelData.label_to_onehot(label=item, num_classes=num_classes)
assert tuple(onehot.shape) == (num_classes, ) assert tuple(onehot.shape) == (num_classes, )
assert onehot.device == item.device
# item is not onehot # item is not onehot
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
LabelData.label_to_onehot(label='item', num_classes=num_classes) LabelData.label_to_onehot(label='item', num_classes=num_classes)
...@@ -22,6 +23,10 @@ class TestLabelData(TestCase): ...@@ -22,6 +23,10 @@ class TestLabelData(TestCase):
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
LabelData.label_to_onehot( LabelData.label_to_onehot(
torch.tensor([11], dtype=torch.int64), num_classes) torch.tensor([11], dtype=torch.int64), num_classes)
onehot = LabelData.label_to_onehot(
label=torch.tensor([], dtype=torch.int64), num_classes=num_classes)
assert (onehot == torch.zeros((num_classes, ),
dtype=torch.int64)).all()
def test_onehot_to_label(self): def test_onehot_to_label(self):
# item is not onehot # item is not onehot
...@@ -38,3 +43,13 @@ class TestLabelData(TestCase): ...@@ -38,3 +43,13 @@ class TestLabelData(TestCase):
onehot = LabelData.label_to_onehot(item, num_classes=10) onehot = LabelData.label_to_onehot(item, num_classes=10)
label = LabelData.onehot_to_label(onehot) label = LabelData.onehot_to_label(onehot)
assert (label == item).all() assert (label == item).all()
assert label.device == item.device
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='GPU is required!')
def test_cuda(self):
item = torch.arange(0, 9).cuda()
onehot = LabelData.label_to_onehot(item, num_classes=10)
assert item.device == onehot.device
label = LabelData.onehot_to_label(onehot)
assert label.device == onehot.device
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment