From 79067e46289ad7dccc46e29687c34a192a7d3f12 Mon Sep 17 00:00:00 2001
From: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Date: Tue, 10 Jan 2023 13:38:56 +0800
Subject: [PATCH] [Fix] Add support for Ascend device (#847)

* add npu device support

* add comment for torch.npu.set_compile_mode
---
 mmengine/device/utils.py                      |  4 ++++
 mmengine/model/base_model/base_model.py       | 12 +++++++++++
 .../model/base_model/data_preprocessor.py     | 21 +++++++++++++++++++
 mmengine/structures/base_data_element.py      | 11 ++++++++++
 4 files changed, 48 insertions(+)

diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py
index 1a31af54..44e92f71 100644
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -36,6 +36,10 @@ def is_npu_available() -> bool:
     """Returns True if Ascend PyTorch and npu devices exist."""
     try:
         import torch_npu  # noqa: F401
+
+        # Disabling JIT compilation (jit_compile=False) enables
+        # dynamic-shape and binary operator support on the NPU.
+        torch.npu.set_compile_mode(jit_compile=False)
     except Exception:
         return False
     return hasattr(torch, 'npu') and torch.npu.is_available()
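
For context, a minimal usage sketch of is_npu_available(); the fallback
chain below is illustrative and not part of this patch:

    import torch
    from mmengine.device.utils import is_npu_available

    # Prefer the Ascend NPU when torch_npu is importable and a
    # device is present; otherwise fall back to CUDA or CPU.
    if is_npu_available():
        device = 'npu'
    elif torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
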
diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py
index f9316506..06bf5b65 100644
--- a/mmengine/model/base_model/base_model.py
+++ b/mmengine/model/base_model/base_model.py
@@ -184,6 +184,18 @@ class BaseModel(BaseModule):
         Returns:
             nn.Module: The model itself.
         """
+
+        # PyTorch has not officially merged the npu device type,
+        # so passing an 'npu' device string to _parse_to directly
+        # raises an unrecognized-device error. Rewrite the device
+        # string to the backend's native device type first.
+        if args and isinstance(args[0], str) and 'npu' in args[0]:
+            args = (args[0].replace('npu', torch.npu.native_device),
+                    *args[1:])
+        if kwargs and 'npu' in str(kwargs.get('device', '')):
+            kwargs['device'] = kwargs['device'].replace(
+                'npu', torch.npu.native_device)
+
         device = torch._C._nn._parse_to(*args, **kwargs)[0]
         if device is not None:
             self._set_device(torch.device(device))
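
A standalone sketch of the rewrite above, assuming torch_npu is
installed and exposes torch.npu.native_device (the device type that is
actually registered with core PyTorch); _rewrite_npu_device is a
hypothetical helper, not part of this patch:

    import torch

    def _rewrite_npu_device(device: str) -> str:
        # _parse_to only accepts device types registered with core
        # PyTorch, so map 'npu' to the backend's native name, e.g.
        # 'npu:0' -> f'{torch.npu.native_device}:0'.
        return device.replace('npu', torch.npu.native_device)
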
diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
index 14c3db96..8f02a6c5 100644
--- a/mmengine/model/base_model/data_preprocessor.py
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -87,6 +87,18 @@ class BaseDataPreprocessor(nn.Module):
         Returns:
             nn.Module: The model itself.
         """
+
+        # PyTorch has not officially merged the npu device type,
+        # so passing an 'npu' device string to _parse_to directly
+        # raises an unrecognized-device error. Rewrite the device
+        # string to the backend's native device type first.
+        if args and isinstance(args[0], str) and 'npu' in args[0]:
+            args = (args[0].replace('npu', torch.npu.native_device),
+                    *args[1:])
+        if kwargs and 'npu' in str(kwargs.get('device', '')):
+            kwargs['device'] = kwargs['device'].replace(
+                'npu', torch.npu.native_device)
+
         device = torch._C._nn._parse_to(*args, **kwargs)[0]
         if device is not None:
             self._device = torch.device(device)
@@ -101,6 +113,15 @@ class BaseDataPreprocessor(nn.Module):
         self._device = torch.device(torch.cuda.current_device())
         return super().cuda()
 
+    def npu(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self._device = torch.device(torch.npu.current_device())
+        return super().npu()
+
     def cpu(self, *args, **kwargs) -> nn.Module:
         """Overrides this method to set the :attr:`device`
 
diff --git a/mmengine/structures/base_data_element.py b/mmengine/structures/base_data_element.py
index 042a9df6..7be1ef90 100644
--- a/mmengine/structures/base_data_element.py
+++ b/mmengine/structures/base_data_element.py
@@ -507,6 +507,17 @@ class BaseDataElement:
                 new_data.set_data(data)
         return new_data
 
+    # Tensor-like methods
+    def npu(self) -> 'BaseDataElement':
+        """Convert all tensors to NPU in data."""
+        new_data = self.new()
+        for k, v in self.items():
+            if isinstance(v, (torch.Tensor, BaseDataElement)):
+                v = v.npu()
+                data = {k: v}
+                new_data.set_data(data)
+        return new_data
+
     # Tensor-like methods
     def detach(self) -> 'BaseDataElement':
         """Detach all tensors in data."""
-- 
GitLab