From 819e10c24c1cd810aad2ea5e771df9b3535b1818 Mon Sep 17 00:00:00 2001 From: RangiLyu <lyuchqi@gmail.com> Date: Mon, 13 Jun 2022 21:21:19 +0800 Subject: [PATCH] [Fix] Fix image dtype when enable_normalize=False. (#301) * [Fix] Fix image dtype when enable_normalize=False. * update ut * move to collate * update ut --- mmengine/model/base_model/data_preprocessor.py | 2 +- tests/test_model/test_base_model/test_data_preprocessor.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 7242fde9..2640ce55 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -53,7 +53,7 @@ class BaseDataPreprocessor(nn.Module): Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input tensor and list of labels at target device. """ - inputs = [_data['inputs'].to(self._device) for _data in data] + inputs = [_data['inputs'].to(self._device).float() for _data in data] batch_data_samples: List[BaseDataElement] = [] # Model can get predictions without any data samples. for _data in data: diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index b82b09fc..6af60752 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -28,6 +28,7 @@ class TestBaseDataPreprocessor(TestCase): ] batch_inputs, batch_labels = base_data_preprocessor(data) + self.assertTrue(torch.is_floating_point(batch_inputs)) self.assertEqual(batch_inputs.shape, (2, 1, 3, 5)) assert_allclose(input1, batch_inputs[0]) @@ -38,14 +39,17 @@ class TestBaseDataPreprocessor(TestCase): if torch.cuda.is_available(): base_data_preprocessor = base_data_preprocessor.cuda() batch_inputs, batch_labels = base_data_preprocessor(data) + self.assertTrue(torch.is_floating_point(batch_inputs)) self.assertEqual(batch_inputs.device.type, 'cuda') base_data_preprocessor = base_data_preprocessor.cpu() batch_inputs, batch_labels = base_data_preprocessor(data) + self.assertTrue(torch.is_floating_point(batch_inputs)) self.assertEqual(batch_inputs.device.type, 'cpu') base_data_preprocessor = base_data_preprocessor.to('cuda:0') batch_inputs, batch_labels = base_data_preprocessor(data) + self.assertTrue(torch.is_floating_point(batch_inputs)) self.assertEqual(batch_inputs.device.type, 'cuda') @@ -122,6 +126,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor): target_inputs = [target_inputs1, target_inputs2] inputs, data_samples = data_preprocessor(data, True) + self.assertTrue(torch.is_floating_point(inputs)) target_data_samples = [data_sample1, data_sample2] for input_, data_sample, target_input, target_data_sample in zip( @@ -142,6 +147,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor): target_inputs = [target_inputs1, target_inputs2] inputs, data_samples = data_preprocessor(data, True) + self.assertTrue(torch.is_floating_point(inputs)) target_data_samples = [data_sample1, data_sample2] for input_, data_sample, target_input, target_data_sample in zip( -- GitLab