From c90b95a44b1d05340425dbd73bb1ea5ec9765326 Mon Sep 17 00:00:00 2001
From: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Date: Fri, 10 Jun 2022 15:12:41 +0800
Subject: [PATCH] [Fix]: fix label data and support empty tensor in
 label_to_onehot (#291)

---
 mmengine/data/label_data.py        |  4 ++--
 tests/test_data/test_label_data.py | 17 ++++++++++++++++-
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/mmengine/data/label_data.py b/mmengine/data/label_data.py
index 388313a6..455e0243 100644
--- a/mmengine/data/label_data.py
+++ b/mmengine/data/label_data.py
@@ -40,7 +40,7 @@ class LabelData(BaseDataElement):
             torch.Tensor: The converted results.
         """
         assert isinstance(label, torch.Tensor)
-        onehot = torch.zeros((num_classes, ), dtype=torch.int64)
-        assert label.max().item() < num_classes
+        onehot = label.new_zeros((num_classes, ))
+        assert max(label, default=torch.tensor(0)).item() < num_classes
         onehot[label] = 1
         return onehot
diff --git a/tests/test_data/test_label_data.py b/tests/test_data/test_label_data.py
index 6868e88e..048f4e06 100644
--- a/tests/test_data/test_label_data.py
+++ b/tests/test_data/test_label_data.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from unittest import TestCase
 
+import pytest
 import torch
 
 from mmengine.data import LabelData
@@ -13,7 +14,7 @@ class TestLabelData(TestCase):
         num_classes = 10
         onehot = LabelData.label_to_onehot(label=item, num_classes=num_classes)
         assert tuple(onehot.shape) == (num_classes, )
-
+        assert onehot.device == item.device
         # item is not onehot
         with self.assertRaises(AssertionError):
             LabelData.label_to_onehot(label='item', num_classes=num_classes)
@@ -22,6 +23,10 @@ class TestLabelData(TestCase):
         with self.assertRaises(AssertionError):
             LabelData.label_to_onehot(
                 torch.tensor([11], dtype=torch.int64), num_classes)
+        onehot = LabelData.label_to_onehot(
+            label=torch.tensor([], dtype=torch.int64), num_classes=num_classes)
+        assert (onehot == torch.zeros((num_classes, ),
+                                      dtype=torch.int64)).all()
 
     def test_onehot_to_label(self):
         # item is not onehot
@@ -38,3 +43,13 @@ class TestLabelData(TestCase):
         onehot = LabelData.label_to_onehot(item, num_classes=10)
         label = LabelData.onehot_to_label(onehot)
         assert (label == item).all()
+        assert label.device == item.device
+
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason='GPU is required!')
+    def test_cuda(self):
+        item = torch.arange(0, 9).cuda()
+        onehot = LabelData.label_to_onehot(item, num_classes=10)
+        assert item.device == onehot.device
+        label = LabelData.onehot_to_label(onehot)
+        assert label.device == onehot.device
-- 
GitLab