Unverified commit 60b4c199, authored by 黄启元, committed by GitHub

[Feature] Support MLU backend (#1075)

* support mlu device

* support mlu device

* fix lint error

* fix lint error builder.py

* fix lint error in amp.py

* fix lint errors

* fix data type in instance_data.py
parent f22002ec
@@ -414,9 +414,13 @@ def _broadcast_object_list(object_list: List[Any],
     is_nccl_backend = group_backend == torch_dist.Backend.NCCL
     current_device = torch.device('cpu')
     is_hccl_backend = group_backend == 'hccl'
+    is_cncl_backend = group_backend == 'cncl'
     if is_hccl_backend:
         current_device = torch.npu.current_device()
         object_sizes_tensor = object_sizes_tensor.to(current_device)
+    elif is_cncl_backend:
+        current_device = torch.device('mlu', torch.mlu.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
     elif is_nccl_backend:
         # See note about using torch.cuda.current_device() here in
         # docstring. We cannot simply use my_rank since rank == device is
@@ -436,7 +440,7 @@ def _broadcast_object_list(object_list: List[Any],
             dtype=torch.uint8,
         )
-    if is_nccl_backend or is_hccl_backend:
+    if is_nccl_backend or is_hccl_backend or is_cncl_backend:
         object_tensor = object_tensor.to(current_device)
     torch_dist.broadcast(object_tensor, src=src, group=group)
     # Deserialize objects using their stored sizes.
...
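To make the dispatch concrete, here is a minimal sketch (not part of the diff) of the backend-to-device mapping this hunk introduces. `_device_for_backend` is a hypothetical helper; the 'hccl' and 'cncl' branches assume the Ascend (torch_npu) and Cambricon (torch_mlu) extensions are installed, while the fallback runs on stock PyTorch.

import torch

def _device_for_backend(group_backend: str) -> torch.device:
    # Mirrors the dispatch above: hccl -> NPU, cncl -> MLU, nccl -> CUDA;
    # any other backend (e.g. gloo) keeps tensors on the CPU.
    if group_backend == 'hccl':
        return torch.device('npu', torch.npu.current_device())
    if group_backend == 'cncl':
        return torch.device('mlu', torch.mlu.current_device())
    if group_backend == 'nccl':
        return torch.device('cuda', torch.cuda.current_device())
    return torch.device('cpu')

print(_device_for_backend('gloo'))  # device(type='cpu') on any host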
@@ -216,6 +216,20 @@ class BaseModel(BaseModule):
         self._set_device(torch.device(device))
         return super().cuda(device)

+    def mlu(
+        self,
+        device: Union[int, str, torch.device, None] = None,
+    ) -> nn.Module:
+        """Overrides this method to call :meth:`BaseDataPreprocessor.mlu`
+        additionally.
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        device = torch.device('mlu', torch.mlu.current_device())
+        self._set_device(device)
+        return super().mlu()
+
     def npu(
         self,
         device: Union[int, str, torch.device, None] = None,
...
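As a usage sketch (not part of the diff): on a host with Cambricon's torch_mlu extension installed, calling .mlu() moves both the parameters and the attached data preprocessor. ToyModel below is a hypothetical stand-in.

import torch.nn as nn
from mmengine.model import BaseModel

class ToyModel(BaseModel):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        return self.linear(inputs)

model = ToyModel().mlu()  # parameters and data_preprocessor now live on MLU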
@@ -122,6 +122,15 @@ class BaseDataPreprocessor(nn.Module):
         self._device = torch.device(torch.npu.current_device())
         return super().npu()

+    def mlu(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self._device = torch.device(torch.mlu.current_device())
+        return super().mlu()
+
     def cpu(self, *args, **kwargs) -> nn.Module:
         """Overrides this method to set the :attr:`device`
...
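The override matters because the preprocessor uses the recorded :attr:`device` when casting incoming batches. A hedged sketch of the effect, again assuming torch_mlu is installed:

import torch
from mmengine.model import BaseDataPreprocessor

preprocessor = BaseDataPreprocessor().mlu()
data = {'inputs': torch.rand(2, 3, 8, 8)}
moved = preprocessor(data)  # cast_data now targets the recorded MLU device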
@@ -5,7 +5,8 @@ from typing import Union
 import torch
 import torch.nn as nn

-from mmengine.device import is_cuda_available, is_npu_available
+from mmengine.device import (is_cuda_available, is_mlu_available,
+                             is_npu_available)
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -13,6 +14,8 @@ from .optimizer_wrapper import OptimWrapper

 if is_npu_available():
     from torch.npu.amp import GradScaler
+elif is_mlu_available():
+    from torch.mlu.amp import GradScaler
 else:
     from torch.cuda.amp import GradScaler
@@ -65,8 +68,9 @@ class AmpOptimWrapper(OptimWrapper):
                  **kwargs):
         assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
             '`torch.cuda.amp` is only available when pytorch version >= 1.6')
-        assert is_cuda_available() or is_npu_available(), (
-            '``AmpOptimizerWrapper`` is only available training on gpu or npu')
+        assert is_cuda_available() or is_npu_available() or is_mlu_available(
+        ), ('``AmpOptimizerWrapper`` is only available training '
+            'on gpu, npu or mlu')
         super().__init__(**kwargs)
         self._scale_update_param = None
         if loss_scale == 'dynamic':
...
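With this change AmpOptimWrapper transparently picks torch.mlu.amp.GradScaler on MLU hosts. A minimal construction sketch (not from the diff; it assumes an accelerator is present, since the assert above rejects CPU-only hosts, and `model` is a hypothetical nn.Module):

from torch.optim import SGD
from mmengine.optim import AmpOptimWrapper

optim_wrapper = AmpOptimWrapper(
    optimizer=SGD(model.parameters(), lr=0.01),
    loss_scale='dynamic')  # GradScaler implementation follows the device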
@@ -5,7 +5,8 @@ from typing import Optional
 import torch

-from mmengine.device import get_device, is_cuda_available, is_npu_available
+from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
+                             is_npu_available)
 from mmengine.logging import print_log
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -75,9 +76,11 @@ def autocast(device_type: Optional[str] = None,
             digit_version('1.10.0')):
         # If pytorch version is between 1.5.0 and 1.10.0, the default value of
         # dtype for `torch.cuda.amp.autocast` is torch.float16.
-        assert device_type == 'cuda' or device_type is None, (
-            'Pytorch version under 1.10.0 only supports running automatic '
-            'mixed training with cuda')
+        assert (
+            device_type == 'cuda' or device_type == 'mlu'
+            or device_type is None), (
+                'Pytorch version under 1.10.0 only supports running automatic '
+                'mixed training with cuda or mlu')
         if dtype is not None or cache_enabled is not None:
             print_log(
                 f'{dtype} and {device_type} will not work for '
@@ -89,6 +92,9 @@ def autocast(device_type: Optional[str] = None,
     if is_npu_available():
         with torch.npu.amp.autocast(enabled=enabled):
             yield
+    elif is_mlu_available():
+        with torch.mlu.amp.autocast(enabled=enabled):
+            yield
 elif is_cuda_available():
     with torch.cuda.amp.autocast(enabled=enabled):
         yield
...
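Callers need no device-specific changes: on an MLU host the same context manager now routes to torch.mlu.amp.autocast. A short usage sketch (`model` and `inputs` are hypothetical):

from mmengine.runner import autocast

with autocast(device_type='mlu'):
    # Inside the context, eligible ops run in reduced precision on the MLU.
    pred = model(inputs)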
@@ -521,6 +521,16 @@ class BaseDataElement:
             new_data.set_data(data)
         return new_data

+    def mlu(self) -> 'BaseDataElement':
+        """Convert all tensors to MLU in data."""
+        new_data = self.new()
+        for k, v in self.items():
+            if isinstance(v, (torch.Tensor, BaseDataElement)):
+                v = v.mlu()
+                data = {k: v}
+                new_data.set_data(data)
+        return new_data
+
     # Tensor-like methods
     def detach(self) -> 'BaseDataElement':
         """Detach all tensors in data."""
...
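Mirroring the existing cuda()/npu() helpers, the new method copies every tensor field (recursing into nested BaseDataElement values). A short sketch, assuming torch_mlu is installed:

import torch
from mmengine.structures import BaseDataElement

data = BaseDataElement(scores=torch.rand(4))
data_mlu = data.mlu()  # scores is copied to the current MLU device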
@@ -15,6 +15,9 @@ LongTypeTensor: Union[Any]
 if get_device() == 'npu':
     BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
     LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
+elif get_device() == 'mlu':
+    BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor]
+    LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor]
 else:
     BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
     LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
...
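These aliases govern which tensor types InstanceData accepts as indices, so with this branch MLU-resident bool/long tensors pass the checks. The indexing itself is device-agnostic; a CPU sketch:

import torch
from mmengine.structures import InstanceData

inst = InstanceData(bboxes=torch.rand(4, 4), labels=torch.arange(4))
mask = torch.tensor([True, False, True, False])
subset = inst[mask]  # keeps rows 0 and 2 of every field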
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn

 import mmengine
-from mmengine.device import get_device
+from mmengine.device import get_device, is_mlu_available
 from mmengine.runner import autocast
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -14,7 +14,22 @@ from mmengine.utils.dl_utils import TORCH_VERSION
 class TestAmp(unittest.TestCase):

     def test_autocast(self):
-        if not torch.cuda.is_available():
+        if is_mlu_available():
+            device = 'mlu'
+            with autocast(device_type=device):
+                # torch.autocast support mlu mode.
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
+                with autocast(enabled=False, device_type=device):
+                    res = layer(torch.randn(1, 1, 1, 1).to(device))
+                    self.assertEqual(res.dtype, torch.float32)
+            # Test with fp32_enabled
+            with autocast(enabled=False, device_type=device):
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertEqual(res.dtype, torch.float32)
+        elif not torch.cuda.is_available():
             if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
                 # `torch.cuda.amp.autocast` is only support in gpu mode, if
                 # cuda is not available, it will return an empty context and
...
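The new branch can only be exercised on a Cambricon host (e.g. by running pytest with -k test_autocast against this test module); on CUDA-only or CPU-only hosts the test falls through to the pre-existing branches unchanged.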