Skip to content
Snippets Groups Projects
Unverified Commit bd679138 authored by Qian Zhao's avatar Qian Zhao Committed by GitHub
Browse files

[Fix] BaseModel & BaseDataPreprocessor `to` method to be consistent with torch.nn.Module (#783)

* fix BaseModel `to` method to be consistent with torch.nn.Module

* fix data_preprocessor as well

* fix docstring alignment

* fix docstring alignment
parent 0dd0a22e
No related branches found
No related tags found
No related merge requests found
......@@ -155,9 +155,9 @@ class BaseModel(BaseModule):
Returns:
tuple[Tensor, dict]: There are two elements. The first is the
loss tensor passed to optim_wrapper which may be a weighted sum of
all losses, and the second is log_vars which will be sent to the
logger.
loss tensor passed to optim_wrapper which may be a weighted sum
of all losses, and the second is log_vars which will be sent to
the logger.
"""
log_vars = []
for loss_name, loss_value in losses.items():
......@@ -177,23 +177,17 @@ class BaseModel(BaseModule):
return loss, log_vars # type: ignore
def to(self,
device: Optional[Union[int, str, torch.device]] = None,
*args,
**kwargs) -> nn.Module:
def to(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.to`
additionally.
Args:
device (int, str or torch.device, optional): the desired device
of the parameters and buffers in this module.
Returns:
nn.Module: The model itself.
"""
device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None:
self._set_device(torch.device(device))
return super().to(device)
return super().to(*args, **kwargs)
def cuda(
self,
......@@ -244,7 +238,7 @@ class BaseModel(BaseModule):
Args:
device (torch.device): the desired device of the parameters and
buffers in this module.
buffers in this module.
"""
def apply_fn(module):
......
......@@ -84,19 +84,16 @@ class BaseDataPreprocessor(nn.Module):
def device(self):
return self._device
def to(self, device: Optional[Union[int, torch.device]], *args,
**kwargs) -> nn.Module:
def to(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
Args:
device (int or torch.device, optional): The desired device of the
parameters and buffers in this module.
Returns:
nn.Module: The model itself.
"""
self._device = torch.device(device)
return super().to(device)
device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None:
self._device = torch.device(device)
return super().to(*args, **kwargs)
def cuda(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to set the :attr:`device`
......
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import unittest
from unittest import TestCase
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.optim import SGD
from mmengine.model import BaseDataPreprocessor, BaseModel
......@@ -11,6 +13,18 @@ from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.testing import assert_allclose
dtypes_to_test = [torch.float16, torch.float32, torch.float64, torch.half]
cpu_devices = ['cpu', torch.device('cpu')]
cuda_devices = ['cuda', 0, torch.device('cuda')]
devices_to_test = cpu_devices
if torch.cuda.is_available():
devices_to_test += cuda_devices
def list_product(*args):
return list(itertools.product(*args))
@MODELS.register_module()
class CustomDataPreprocessor(BaseDataPreprocessor):
......@@ -158,3 +172,32 @@ class TestBaseModel(TestCase):
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cuda'))
@parameterized.expand(list_product(devices_to_test))
def test_to_device(self, device):
model = ToyModel().to(device)
self.assertTrue(
all(p.device.type == torch.device(device).type
for p in model.parameters())
and model.data_preprocessor._device == torch.device(device))
@parameterized.expand(list_product(dtypes_to_test))
def test_to_dtype(self, dtype):
model = ToyModel().to(dtype)
self.assertTrue(all(p.dtype == dtype for p in model.parameters()))
@parameterized.expand(
list_product(devices_to_test, dtypes_to_test,
['args', 'kwargs', 'hybrid']))
def test_to_device_and_dtype(self, device, dtype, mode):
if mode == 'args':
model = ToyModel().to(device, dtype)
elif mode == 'kwargs':
model = ToyModel().to(device=device, dtype=dtype)
elif mode == 'hybrid':
model = ToyModel().to(device, dtype=dtype)
self.assertTrue(
all(p.dtype == dtype for p in model.parameters())
and model.data_preprocessor._device == torch.device(device)
and all(p.device.type == torch.device(device).type
for p in model.parameters()))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment