Skip to content
Snippets Groups Projects
Unverified Commit 438e8e74 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

BaseModel support recursively set the device of data_preprocessor (#387)

parent f98ba606
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ from mmengine.optim import OptimWrapper ...@@ -11,6 +11,7 @@ from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS from mmengine.registry import MODELS
from mmengine.utils import is_list_of from mmengine.utils import is_list_of
from ..base_module import BaseModule from ..base_module import BaseModule
from .data_preprocessor import BaseDataPreprocessor
ForwardResults = Union[Dict[str, torch.Tensor], List[BaseDataElement], ForwardResults = Union[Dict[str, torch.Tensor], List[BaseDataElement],
Tuple[torch.Tensor], torch.Tensor] Tuple[torch.Tensor], torch.Tensor]
...@@ -177,30 +178,38 @@ class BaseModel(BaseModule): ...@@ -177,30 +178,38 @@ class BaseModel(BaseModule):
return loss, log_vars return loss, log_vars
def to(self,
       device: Optional[Union[int, str, torch.device]] = None,
       *args,
       **kwargs) -> nn.Module:
    """Overrides this method to additionally record the target device on
    every nested :class:`BaseDataPreprocessor`.

    Args:
        device (int, str or torch.device, optional): the desired device
            of the parameters and buffers in this module. Any other value
            accepted by ``nn.Module.to`` in this position (for example a
            dtype) is forwarded untouched and does not change the device
            recorded on the data preprocessors.
        *args: Extra positional arguments forwarded to ``nn.Module.to``.
        **kwargs: Extra keyword arguments forwarded to ``nn.Module.to``.

    Returns:
        nn.Module: The model itself.
    """
    # Only genuine device specifiers can be converted by ``torch.device``;
    # calling it on a dtype or tensor would raise before the parent class
    # gets a chance to handle the argument.
    if isinstance(device, (int, str, torch.device)):
        self._set_device(torch.device(device))
    # Forward every argument so options such as ``dtype`` or
    # ``non_blocking`` are honoured instead of being silently dropped.
    return super().to(device, *args, **kwargs)
def cuda(
    self,
    device: Optional[Union[int, str, torch.device]] = None,
) -> nn.Module:
    """Overrides this method to additionally record the CUDA device on
    every nested :class:`BaseDataPreprocessor`.

    Args:
        device (int, str or torch.device, optional): the target device.
            ``None`` or an integer is first turned into a ``'cuda'``
            ``torch.device`` with that value as its index.

    Returns:
        nn.Module: The model itself.
    """
    wants_indexed_cuda = device is None or isinstance(device, int)
    if wants_indexed_cuda:
        device = torch.device('cuda', index=device)
    target = torch.device(device)
    self._set_device(target)
    return super().cuda(device)
def cpu(self, *args, **kwargs) -> nn.Module: def cpu(self, *args, **kwargs) -> nn.Module:
"""Overrides this method to call :meth:`BaseDataPreprocessor.cpu` """Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
...@@ -209,9 +218,25 @@ class BaseModel(BaseModule): ...@@ -209,9 +218,25 @@ class BaseModel(BaseModule):
Returns: Returns:
nn.Module: The model itself. nn.Module: The model itself.
""" """
self.data_preprocessor.cpu() self._set_device(torch.device('cpu'))
return super().cpu() return super().cpu()
def _set_device(self, device: torch.device) -> None:
    """Store ``device`` on every `BaseDataPreprocessor` reached through
    ``self.apply``.

    Args:
        device (torch.device): the desired device of the parameters and
            buffers in this module.
    """

    def _record_device(submodule):
        # Modules other than data preprocessors are left untouched, and
        # a ``None`` device never overwrites the stored value.
        is_preprocessor = isinstance(submodule, BaseDataPreprocessor)
        if is_preprocessor and device is not None:
            submodule._device = device

    self.apply(_record_device)
@abstractmethod @abstractmethod
def forward(self, def forward(self,
batch_inputs: torch.Tensor, batch_inputs: torch.Tensor,
......
...@@ -40,6 +40,16 @@ class ToyModel(BaseModel): ...@@ -40,6 +40,16 @@ class ToyModel(BaseModel):
return out return out
class NestedModel(BaseModel):
    """Test fixture: a ``BaseModel`` that holds a ``ToyModel`` as an
    attribute, giving the tests a model nested inside another model."""

    def __init__(self):
        super().__init__()
        self.toy_model = ToyModel()

    def forward(self):
        """Do nothing; defined so the class provides a ``forward``."""
        return None
class TestBaseModel(TestCase): class TestBaseModel(TestCase):
def test_init(self): def test_init(self):
...@@ -118,6 +128,15 @@ class TestBaseModel(TestCase): ...@@ -118,6 +128,15 @@ class TestBaseModel(TestCase):
out = model.val_step([data]) out = model.val_step([data])
self.assertEqual(out.device.type, 'cuda') self.assertEqual(out.device.type, 'cuda')
model = NestedModel()
self.assertEqual(model.data_preprocessor._device, torch.device('cpu'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cpu'))
model.cuda()
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cuda'))
@unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available') @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
def test_to(self): def test_to(self):
inputs = torch.randn(3, 1, 1).to('cuda:0') inputs = torch.randn(3, 1, 1).to('cuda:0')
...@@ -125,3 +144,17 @@ class TestBaseModel(TestCase): ...@@ -125,3 +144,17 @@ class TestBaseModel(TestCase):
model = ToyModel().to(torch.cuda.current_device()) model = ToyModel().to(torch.cuda.current_device())
out = model.val_step([data]) out = model.val_step([data])
self.assertEqual(out.device.type, 'cuda') self.assertEqual(out.device.type, 'cuda')
model = NestedModel()
self.assertEqual(model.data_preprocessor._device, torch.device('cpu'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cpu'))
model.to('cuda')
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cuda'))
model.to()
self.assertEqual(model.data_preprocessor._device, torch.device('cuda'))
self.assertEqual(model.toy_model.data_preprocessor._device,
torch.device('cuda'))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment