Skip to content
Snippets Groups Projects
Unverified commit 438e8e74, authored by Mashiro, committed by GitHub
Browse files

BaseModel support recursively set the device of data_preprocessor (#387)

parent f98ba606
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.utils import is_list_of
from ..base_module import BaseModule
from .data_preprocessor import BaseDataPreprocessor
ForwardResults = Union[Dict[str, torch.Tensor], List[BaseDataElement],
Tuple[torch.Tensor], torch.Tensor]
......@@ -177,30 +178,38 @@ class BaseModel(BaseModule):
return loss, log_vars
def to(self,
       device: Optional[Union[int, str, torch.device]] = None,
       *args,
       **kwargs) -> nn.Module:
    """Overrides this method to additionally update the device recorded
    by every :class:`BaseDataPreprocessor` inside this model.

    Args:
        device (int, str or torch.device, optional): the desired device
            of the parameters and buffers in this module. Because
            ``nn.Module.to`` also accepts a ``torch.dtype`` or a
            ``torch.Tensor`` as its first positional argument, those are
            tolerated here as well.
        *args: Remaining positional arguments, forwarded unchanged to
            ``nn.Module.to``.
        **kwargs: Keyword arguments (e.g. ``dtype``, ``non_blocking``),
            forwarded unchanged to ``nn.Module.to``.

    Returns:
        nn.Module: The model itself.
    """
    # Work out which device (if any) this call moves the model to. Only a
    # real device specification may be handed to ``torch.device``; a dtype
    # or tensor in the first slot would make it raise.
    if isinstance(device, torch.Tensor):
        target = device.device
    elif device is None or isinstance(device, torch.dtype):
        # Dtype-only (or empty) call: the device stays where it is.
        target = None
    else:
        target = torch.device(device)

    if target is not None:
        self._set_device(target)

    # Forward everything, so that dtype / non_blocking / memory_format
    # arguments are not silently dropped.
    return super().to(device, *args, **kwargs)
def cuda(
    self,
    device: Optional[Union[int, str, torch.device]] = None,
) -> nn.Module:
    """Move the model to a CUDA device and record that device on every
    nested :class:`BaseDataPreprocessor`.

    Args:
        device (int, str or torch.device, optional): the target CUDA
            device. ``None`` or an integer index is first wrapped into a
            ``torch.device('cuda', index=...)``.

    Returns:
        nn.Module: The model itself.
    """
    wrap_as_cuda = device is None or isinstance(device, int)
    if wrap_as_cuda:
        device = torch.device('cuda', index=device)
    # Normalise to ``torch.device`` before handing it to the preprocessors.
    resolved = torch.device(device)
    self._set_device(resolved)
    return super().cuda(device)
def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
......@@ -209,9 +218,25 @@ class BaseModel(BaseModule):
Returns:
nn.Module: The model itself.
"""
self.data_preprocessor.cpu()
self._set_device(torch.device('cpu'))
return super().cpu()
def _set_device(self, device: torch.device) -> None:
    """Record ``device`` on every :class:`BaseDataPreprocessor` found in
    this module tree.

    Args:
        device (torch.device): the desired device of the parameters and
            buffers in this module.
    """

    def _assign(submodule):
        # Only data preprocessors keep a ``_device`` attribute to update;
        # a ``None`` device leaves them untouched.
        if isinstance(submodule, BaseDataPreprocessor) and device is not None:
            submodule._device = device

    # ``apply`` visits this module and all of its descendants.
    self.apply(_assign)
@abstractmethod
def forward(self,
batch_inputs: torch.Tensor,
......
......@@ -40,6 +40,16 @@ class ToyModel(BaseModel):
return out
class NestedModel(BaseModel):
    """Test fixture: a ``BaseModel`` that contains a ``ToyModel`` as a
    submodule."""

    def __init__(self):
        super().__init__()
        # Nested model used by the device tests.
        self.toy_model = ToyModel()

    def forward(self):
        """No-op forward; this fixture is never run on data."""
class TestBaseModel(TestCase):
def test_init(self):
......@@ -118,6 +128,15 @@ class TestBaseModel(TestCase):
out = model.val_step([data])
self.assertEqual(out.device.type, 'cuda')
model = NestedModel()
self.assertEqual(model.data_preprocessor._device, torch.device('cpu'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cpu'))
model.cuda()
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cuda'))
@unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
def test_to(self):
inputs = torch.randn(3, 1, 1).to('cuda:0')
......@@ -125,3 +144,17 @@ class TestBaseModel(TestCase):
model = ToyModel().to(torch.cuda.current_device())
out = model.val_step([data])
self.assertEqual(out.device.type, 'cuda')
model = NestedModel()
self.assertEqual(model.data_preprocessor._device, torch.device('cpu'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cpu'))
model.to('cuda')
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cuda'))
model.to()
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cuda'))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment