From 79067e46289ad7dccc46e29687c34a192a7d3f12 Mon Sep 17 00:00:00 2001
From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Date: Tue, 10 Jan 2023 13:38:56 +0800
Subject: [PATCH] [Fix] Add support for Ascend device (#847)

* add npu device support

* add comment for torch.npu.set_compile_mode
---
 mmengine/device/utils.py                    |  4 ++++
 mmengine/model/base_model/base_model.py     | 12 +++++++++++
 .../model/base_model/data_preprocessor.py   | 21 +++++++++++++++++++
 mmengine/structures/base_data_element.py    | 11 ++++++++++
 4 files changed, 48 insertions(+)

diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py
index 1a31af54..44e92f71 100644
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -36,6 +36,10 @@ def is_npu_available() -> bool:
     """Returns True if Ascend PyTorch and npu devices exist."""
     try:
         import torch_npu  # noqa: F401
+
+        # Enable operator support for dynamic shapes
+        # and binary operators on the NPU.
+        torch.npu.set_compile_mode(jit_compile=False)
     except Exception:
         return False
     return hasattr(torch, 'npu') and torch.npu.is_available()
diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py
index f9316506..06bf5b65 100644
--- a/mmengine/model/base_model/base_model.py
+++ b/mmengine/model/base_model/base_model.py
@@ -184,6 +184,18 @@ class BaseModel(BaseModule):
         Returns:
             nn.Module: The model itself.
         """
+
+        # Since PyTorch has not officially merged the
+        # npu-related fields, calling _parse_to directly
+        # would fail to find the NPU device. Rewrite the
+        # input arguments here to avoid errors.
+        if args and isinstance(args[0], str) and 'npu' in args[0]:
+            args = tuple(
+                [list(args)[0].replace('npu', torch.npu.native_device)])
+        if kwargs and 'npu' in str(kwargs.get('device', '')):
+            kwargs['device'] = kwargs['device'].replace(
+                'npu', torch.npu.native_device)
+
         device = torch._C._nn._parse_to(*args, **kwargs)[0]
         if device is not None:
             self._set_device(torch.device(device))
diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
index 14c3db96..8f02a6c5 100644
--- a/mmengine/model/base_model/data_preprocessor.py
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -87,6 +87,18 @@ class BaseDataPreprocessor(nn.Module):
         Returns:
             nn.Module: The model itself.
         """
+
+        # Since PyTorch has not officially merged the
+        # npu-related fields, calling _parse_to directly
+        # would fail to find the NPU device. Rewrite the
+        # input arguments here to avoid errors.
+        if args and isinstance(args[0], str) and 'npu' in args[0]:
+            args = tuple(
+                [list(args)[0].replace('npu', torch.npu.native_device)])
+        if kwargs and 'npu' in str(kwargs.get('device', '')):
+            kwargs['device'] = kwargs['device'].replace(
+                'npu', torch.npu.native_device)
+
         device = torch._C._nn._parse_to(*args, **kwargs)[0]
         if device is not None:
             self._device = torch.device(device)
@@ -101,6 +113,15 @@ class BaseDataPreprocessor(nn.Module):
             self._device = torch.device(torch.cuda.current_device())
         return super().cuda()
 
+    def npu(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+ """ + self._device = torch.device(torch.npu.current_device()) + return super().npu() + def cpu(self, *args, **kwargs) -> nn.Module: """Overrides this method to set the :attr:`device` diff --git a/mmengine/structures/base_data_element.py b/mmengine/structures/base_data_element.py index 042a9df6..7be1ef90 100644 --- a/mmengine/structures/base_data_element.py +++ b/mmengine/structures/base_data_element.py @@ -507,6 +507,17 @@ class BaseDataElement: new_data.set_data(data) return new_data + # Tensor-like methods + def npu(self) -> 'BaseDataElement': + """Convert all tensors to NPU in data.""" + new_data = self.new() + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.npu() + data = {k: v} + new_data.set_data(data) + return new_data + # Tensor-like methods def detach(self) -> 'BaseDataElement': """Detach all tensors in data.""" -- GitLab