diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py
index 3b05f06a71e295c3f3f5a8dabb9074f6774a92d5..b989c45a7ca7aa5c07c92d54a3e1be71d6e10efc 100644
--- a/mmengine/dist/dist.py
+++ b/mmengine/dist/dist.py
@@ -414,9 +414,13 @@ def _broadcast_object_list(object_list: List[Any],
     is_nccl_backend = group_backend == torch_dist.Backend.NCCL
     current_device = torch.device('cpu')
     is_hccl_backend = group_backend == 'hccl'
+    is_cncl_backend = group_backend == 'cncl'
     if is_hccl_backend:
         current_device = torch.npu.current_device()
         object_sizes_tensor = object_sizes_tensor.to(current_device)
+    elif is_cncl_backend:
+        current_device = torch.device('mlu', torch.mlu.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
     elif is_nccl_backend:
         # See note about using torch.cuda.current_device() here in
         # docstring. We cannot simply use my_rank since rank == device is
@@ -436,7 +440,7 @@ def _broadcast_object_list(object_list: List[Any],
             dtype=torch.uint8,
         )
 
-    if is_nccl_backend or is_hccl_backend:
+    if is_nccl_backend or is_hccl_backend or is_cncl_backend:
         object_tensor = object_tensor.to(current_device)
     torch_dist.broadcast(object_tensor, src=src, group=group)
     # Deserialize objects using their stored sizes.
diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py
index 3e2d72da0dd58bcdbb0e50a07a299fda584e5424..c682ff48ea7b5510e0258a4f711c32b5b951847c 100644
--- a/mmengine/model/base_model/base_model.py
+++ b/mmengine/model/base_model/base_model.py
@@ -216,6 +216,20 @@ class BaseModel(BaseModule):
         self._set_device(torch.device(device))
         return super().cuda(device)
 
+    def mlu(
+        self,
+        device: Union[int, str, torch.device, None] = None,
+    ) -> nn.Module:
+        """Overrides this method to call :meth:`BaseDataPreprocessor.mlu`
+        additionally.
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        device = torch.device('mlu', torch.mlu.current_device())
+        self._set_device(device)
+        return super().mlu()
+
     def npu(
         self,
         device: Union[int, str, torch.device, None] = None,
diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
index 8f02a6c5a4c597cb1d41e0f6550f85a95139c757..7f8a54c1780b963a3f479feac3712ef3124f653b 100644
--- a/mmengine/model/base_model/data_preprocessor.py
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -122,6 +122,15 @@ class BaseDataPreprocessor(nn.Module):
         self._device = torch.device(torch.npu.current_device())
         return super().npu()
 
+    def mlu(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+ """ + self._device = torch.device(torch.mlu.current_device()) + return super().mlu() + def cpu(self, *args, **kwargs) -> nn.Module: """Overrides this method to set the :attr:`device` diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index fcc3b4c50fb82e92be87ca293cc4f95cfde6de57..861361369fbfa057e7fa9d06f18eee75a96ab79a 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -5,7 +5,8 @@ from typing import Union import torch import torch.nn as nn -from mmengine.device import is_cuda_available, is_npu_available +from mmengine.device import (is_cuda_available, is_mlu_available, + is_npu_available) from mmengine.registry import OPTIM_WRAPPERS from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION @@ -13,6 +14,8 @@ from .optimizer_wrapper import OptimWrapper if is_npu_available(): from torch.npu.amp import GradScaler +elif is_mlu_available(): + from torch.mlu.amp import GradScaler else: from torch.cuda.amp import GradScaler @@ -65,8 +68,9 @@ class AmpOptimWrapper(OptimWrapper): **kwargs): assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), ( '`torch.cuda.amp` is only available when pytorch version >= 1.6') - assert is_cuda_available() or is_npu_available(), ( - '``AmpOptimizerWrapper`` is only available training on gpu or npu') + assert is_cuda_available() or is_npu_available() or is_mlu_available( + ), ('``AmpOptimizerWrapper`` is only available training ' + 'on gpu, npu or mlu') super().__init__(**kwargs) self._scale_update_param = None if loss_scale == 'dynamic': diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py index 33ab6bd25d3801e2133e455b57b214af25330256..964518fc90ca9f949f74e1f38421cf9fefb5cfcd 100644 --- a/mmengine/runner/amp.py +++ b/mmengine/runner/amp.py @@ -5,7 +5,8 @@ from typing import Optional import torch -from mmengine.device import get_device, is_cuda_available, is_npu_available +from mmengine.device import (get_device, is_cuda_available, is_mlu_available, + is_npu_available) from mmengine.logging import print_log from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION @@ -75,9 +76,11 @@ def autocast(device_type: Optional[str] = None, digit_version('1.10.0')): # If pytorch version is between 1.5.0 and 1.10.0, the default value of # dtype for `torch.cuda.amp.autocast` is torch.float16. 
-        assert device_type == 'cuda' or device_type is None, (
-            'Pytorch version under 1.10.0 only supports running automatic '
-            'mixed training with cuda')
+        assert (
+            device_type == 'cuda' or device_type == 'mlu'
+            or device_type is None), (
+                'Pytorch version under 1.10.0 only supports running automatic '
+                'mixed training with cuda or mlu')
         if dtype is not None or cache_enabled is not None:
             print_log(
                 f'{dtype} and {device_type} will not work for '
@@ -89,6 +92,9 @@ def autocast(device_type: Optional[str] = None,
         if is_npu_available():
             with torch.npu.amp.autocast(enabled=enabled):
                 yield
+        elif is_mlu_available():
+            with torch.mlu.amp.autocast(enabled=enabled):
+                yield
         elif is_cuda_available():
             with torch.cuda.amp.autocast(enabled=enabled):
                 yield
diff --git a/mmengine/structures/base_data_element.py b/mmengine/structures/base_data_element.py
index 46c4c886e6633c69104327ea71278f8cc7ff650f..454a2243718b429bb2b5192c6f130c006dec7c7a 100644
--- a/mmengine/structures/base_data_element.py
+++ b/mmengine/structures/base_data_element.py
@@ -521,6 +521,16 @@ class BaseDataElement:
         new_data.set_data(data)
         return new_data
 
+    def mlu(self) -> 'BaseDataElement':
+        """Convert all tensors to MLU in data."""
+        new_data = self.new()
+        for k, v in self.items():
+            if isinstance(v, (torch.Tensor, BaseDataElement)):
+                v = v.mlu()
+                data = {k: v}
+                new_data.set_data(data)
+        return new_data
+
     # Tensor-like methods
     def detach(self) -> 'BaseDataElement':
         """Detach all tensors in data."""
diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py
index 1ceac9ad24d79300aa4f7d40e52d6791f2af43bf..8df9727a00027a2c949e8bac37e3beaa49af68ba 100644
--- a/mmengine/structures/instance_data.py
+++ b/mmengine/structures/instance_data.py
@@ -15,6 +15,9 @@ LongTypeTensor: Union[Any]
 if get_device() == 'npu':
     BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
     LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
+elif get_device() == 'mlu':
+    BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor]
+    LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor]
 else:
     BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
     LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
diff --git a/tests/test_runner/test_amp.py b/tests/test_runner/test_amp.py
index 7ef605637832c361563ebf3da313abb75cef1e5e..89794f34144152c4ea8441a9359357b9743fc13d 100644
--- a/tests/test_runner/test_amp.py
+++ b/tests/test_runner/test_amp.py
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 
 import mmengine
-from mmengine.device import get_device
+from mmengine.device import get_device, is_mlu_available
 from mmengine.runner import autocast
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -14,7 +14,22 @@ from mmengine.utils.dl_utils import TORCH_VERSION
 class TestAmp(unittest.TestCase):
 
     def test_autocast(self):
-        if not torch.cuda.is_available():
+        if is_mlu_available():
+            device = 'mlu'
+            with autocast(device_type=device):
+                # torch.autocast supports mlu mode.
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
+                with autocast(enabled=False, device_type=device):
+                    res = layer(torch.randn(1, 1, 1, 1).to(device))
+                    self.assertEqual(res.dtype, torch.float32)
+            # Test with fp32_enabled
+            with autocast(enabled=False, device_type=device):
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertEqual(res.dtype, torch.float32)
+        elif not torch.cuda.is_available():
             if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
                 # `torch.cuda.amp.autocast` is only support in gpu mode, if
                 # cuda is not available, it will return an empty context and
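
Taken together, the patch routes distributed broadcasting (cncl backend), model and data-preprocessor device movement, AMP, and the data structures through the MLU backend. The following is a minimal usage sketch of the new surface, assuming a Cambricon PyTorch build that exposes `torch.mlu`; `ToyModel` is a hypothetical stand-in, not part of this patch:

```python
# Sketch only: assumes a Cambricon PyTorch build exposing `torch.mlu`,
# so that nn.Module.mlu() and Tensor.to('mlu') are available.
import torch
import torch.nn as nn

from mmengine.device import is_mlu_available
from mmengine.model import BaseModel
from mmengine.runner import autocast


class ToyModel(BaseModel):  # hypothetical model, not part of the patch

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        return self.linear(inputs)


if is_mlu_available():
    # BaseModel.mlu() now also updates the data preprocessor's device
    # (see the base_model.py / data_preprocessor.py hunks above).
    model = ToyModel().mlu()
    # autocast() dispatches to torch.mlu.amp.autocast on MLU.
    with autocast(device_type='mlu'):
        out = model(torch.randn(4, 2).to('mlu'))
```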