diff --git a/mmengine/data/label_data.py b/mmengine/data/label_data.py
index 388313a69a75a221a3373eb93bea069200ee295e..455e024310e19bc0befabc2117dd7692dec8b63d 100644
--- a/mmengine/data/label_data.py
+++ b/mmengine/data/label_data.py
@@ -40,7 +40,7 @@ class LabelData(BaseDataElement):
             torch.Tensor: The converted results.
         """
         assert isinstance(label, torch.Tensor)
-        onehot = torch.zeros((num_classes, ), dtype=torch.int64)
-        assert label.max().item() < num_classes
+        onehot = label.new_zeros((num_classes, ))
+        assert max(label, default=torch.tensor(0)).item() < num_classes
         onehot[label] = 1
         return onehot
diff --git a/tests/test_data/test_label_data.py b/tests/test_data/test_label_data.py
index 6868e88e82335723fc5112f838cd5dd0406df967..048f4e067fc89d6c2606936a0619f366d02e2a85 100644
--- a/tests/test_data/test_label_data.py
+++ b/tests/test_data/test_label_data.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from unittest import TestCase
 
+import pytest
 import torch
 
 from mmengine.data import LabelData
@@ -13,7 +14,7 @@ class TestLabelData(TestCase):
         num_classes = 10
         onehot = LabelData.label_to_onehot(label=item, num_classes=num_classes)
         assert tuple(onehot.shape) == (num_classes, )
-
+        assert onehot.device == item.device
         # item is not onehot
         with self.assertRaises(AssertionError):
             LabelData.label_to_onehot(label='item', num_classes=num_classes)
@@ -22,6 +23,10 @@ class TestLabelData(TestCase):
         with self.assertRaises(AssertionError):
             LabelData.label_to_onehot(
                 torch.tensor([11], dtype=torch.int64), num_classes)
+        onehot = LabelData.label_to_onehot(
+            label=torch.tensor([], dtype=torch.int64), num_classes=num_classes)
+        assert (onehot == torch.zeros((num_classes, ),
+                                      dtype=torch.int64)).all()
 
     def test_onehot_to_label(self):
         # item is not onehot
@@ -38,3 +43,13 @@ class TestLabelData(TestCase):
         onehot = LabelData.label_to_onehot(item, num_classes=10)
         label = LabelData.onehot_to_label(onehot)
         assert (label == item).all()
+        assert label.device == item.device
+
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason='GPU is required!')
+    def test_cuda(self):
+        item = torch.arange(0, 9).cuda()
+        onehot = LabelData.label_to_onehot(item, num_classes=10)
+        assert item.device == onehot.device
+        label = LabelData.onehot_to_label(onehot)
+        assert label.device == onehot.device