From 819e10c24c1cd810aad2ea5e771df9b3535b1818 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Mon, 13 Jun 2022 21:21:19 +0800
Subject: [PATCH] [Fix] Fix image dtype when enable_normalize=False. (#301)

* [Fix] Fix image dtype when enable_normalize=False.

* update ut

* move to collate

* update ut
---
 mmengine/model/base_model/data_preprocessor.py             | 2 +-
 tests/test_model/test_base_model/test_data_preprocessor.py | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
index 7242fde9..2640ce55 100644
--- a/mmengine/model/base_model/data_preprocessor.py
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -53,7 +53,7 @@ class BaseDataPreprocessor(nn.Module):
             Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
             tensor and list of labels at target device.
         """
-        inputs = [_data['inputs'].to(self._device) for _data in data]
+        inputs = [_data['inputs'].to(self._device).float() for _data in data]
         batch_data_samples: List[BaseDataElement] = []
         # Model can get predictions without any data samples.
         for _data in data:
diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py
index b82b09fc..6af60752 100644
--- a/tests/test_model/test_base_model/test_data_preprocessor.py
+++ b/tests/test_model/test_base_model/test_data_preprocessor.py
@@ -28,6 +28,7 @@ class TestBaseDataPreprocessor(TestCase):
         ]
 
         batch_inputs, batch_labels = base_data_preprocessor(data)
+        self.assertTrue(torch.is_floating_point(batch_inputs))
         self.assertEqual(batch_inputs.shape, (2, 1, 3, 5))
 
         assert_allclose(input1, batch_inputs[0])
@@ -38,14 +39,17 @@ class TestBaseDataPreprocessor(TestCase):
         if torch.cuda.is_available():
             base_data_preprocessor = base_data_preprocessor.cuda()
             batch_inputs, batch_labels = base_data_preprocessor(data)
+            self.assertTrue(torch.is_floating_point(batch_inputs))
             self.assertEqual(batch_inputs.device.type, 'cuda')
 
             base_data_preprocessor = base_data_preprocessor.cpu()
             batch_inputs, batch_labels = base_data_preprocessor(data)
+            self.assertTrue(torch.is_floating_point(batch_inputs))
             self.assertEqual(batch_inputs.device.type, 'cpu')
 
             base_data_preprocessor = base_data_preprocessor.to('cuda:0')
             batch_inputs, batch_labels = base_data_preprocessor(data)
+            self.assertTrue(torch.is_floating_point(batch_inputs))
             self.assertEqual(batch_inputs.device.type, 'cuda')
 
 
@@ -122,6 +126,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
 
         target_inputs = [target_inputs1, target_inputs2]
         inputs, data_samples = data_preprocessor(data, True)
+        self.assertTrue(torch.is_floating_point(inputs))
 
         target_data_samples = [data_sample1, data_sample2]
         for input_, data_sample, target_input, target_data_sample in zip(
@@ -142,6 +147,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
 
         target_inputs = [target_inputs1, target_inputs2]
         inputs, data_samples = data_preprocessor(data, True)
+        self.assertTrue(torch.is_floating_point(inputs))
 
         target_data_samples = [data_sample1, data_sample2]
         for input_, data_sample, target_input, target_data_sample in zip(
-- 
GitLab